class EPO_LP(object):
    """Pre-built cvxpy programs that choose simplex weights ``alpha`` over ``m`` objectives.

    The problems are constructed once in ``__init__``; every call to
    :meth:`get_alpha` refreshes the cvxpy parameters and re-solves either

    * the "balance" program, while the non-uniformity ``mu_rl`` exceeds ``eps``, or
    * a "dominance" program (restricted, or relaxed when ``relax=True``) otherwise.
    """

    def __init__(self, m, n, r, eps=1e-4, softmax_norm=False):
        """Create the cvxpy parameters, the variable ``alpha`` and the problems.

        Args:
            m: number of objectives (length of ``alpha``).
            n: stored as ``self.n``; not read anywhere else in this class.
            r: default preference vector used when ``get_alpha`` receives ``r=None``.
            eps: threshold on ``mu_rl`` that selects balance vs. dominance mode.
            softmax_norm: when true, ``get_alpha`` exponentiates ``r`` and ``l``
                before computing the adjustments.
        """
        # cp.GLPK was the alternative solver left commented out in the original.
        self.solver = cp.GUROBI
        self.m = m
        self.n = n
        self.r = r
        self.eps = eps
        self.last_move = None
        self.a = cp.Parameter(m)        # Adjustments
        self.C = cp.Parameter((m, m))   # C: Gradient inner products, G^T G
        self.Ca = cp.Parameter(m)       # d_bal^TG
        self.rhs = cp.Parameter(m)      # RHS of constraints for balancing

        self.alpha = cp.Variable(m)     # Variable to optimize

        obj_bal = cp.Maximize(self.alpha @ self.Ca)   # objective for balance
        constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Simplex
                           self.C @ self.alpha >= self.rhs]
        self.prob_bal = cp.Problem(obj_bal, constraints_bal)  # LP balance

        obj_dom = cp.Maximize(cp.sum(self.alpha @ self.C))  # obj for descent
        constraints_res = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Restrict
                           self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)),
                           self.C @ self.alpha >= 0]
        constraints_rel = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Relaxed
                           self.C @ self.alpha >= 0]
        self.prob_dom = cp.Problem(obj_dom, constraints_res)  # LP dominance
        self.prob_rel = cp.Problem(obj_dom, constraints_rel)  # LP dominance

        self.gamma = 0     # Stores the latest Optimum value of the LP problem
        self.mu_rl = 0     # Stores the latest non-uniformity

        self.softmax_norm = softmax_norm  # use which normalization to calc. non-uniformity

    def get_alpha(self, l, G, r=None, C=False, relax=False):
        """Solve for the objective weights given losses and gradient information.

        Args:
            l: array of the ``m`` loss values.
            G: gradient matrix with ``m`` rows, or the ``m x m`` matrix of
                gradient inner products when ``C`` is true.
            r: preference vector; defaults to the one given at construction.
            C: true if ``G`` already holds the inner products ``G @ G.T``.
            relax: in dominance mode, solve the relaxed problem instead of the
                restricted one.

        Returns:
            ``self.alpha.value`` after the solve (whatever cvxpy left there).
        """
        r = self.r if r is None else r
        assert len(l) == len(G) == len(r) == self.m, "length != m"

        if self.softmax_norm:
            r = np.exp(r)
            l = np.exp(l)
        rl, self.mu_rl, self.a.value = self.adjustments(l, r)

        self.C.value = G if C else G @ G.T
        self.Ca.value = self.C.value @ self.a.value

        if self.mu_rl > self.eps:
            J = self.Ca.value > 0
            J_star_idx = np.where(rl == np.max(rl))[0]
            # The statement order below is kept exactly as in the original:
            # the parameter is first given a finite copy and then edited in place.
            self.rhs.value = self.Ca.value.copy()
            self.rhs.value[J] = -np.inf     # Not efficient; but works.
            self.rhs.value[J_star_idx] = 0
            self.gamma = self.prob_bal.solve(solver=self.solver, verbose=False, reoptimize=True)
            self.last_move = "bal"
        else:
            problem = self.prob_rel if relax else self.prob_dom
            self.gamma = problem.solve(solver=self.solver, verbose=False, reoptimize=True)
            self.last_move = "dom"
        return self.alpha.value

    def mu(self, rl, normed=False):
        """Return ``sum(l_hat * log(l_hat * m))`` for the normalised vector ``l_hat``.

        ``rl`` is always divided by its sum first.  The original only did so when
        ``normed`` was true and crashed with ``UnboundLocalError`` otherwise; the
        ``normed=True`` result is unchanged.

        Args:
            rl: non-negative array of length ``m``.
            normed: kept for interface compatibility; does not alter the result.

        Raises:
            ValueError: if any entry of ``rl`` is negative.
        """
        if np.any(rl < 0):
            raise ValueError(f"rl<0 \n rl={rl}")
        m = len(rl)
        l_hat = rl / rl.sum()
        # Entries at or below machine epsilon are dropped so log() never sees 0.
        eps = np.finfo(l_hat.dtype).eps
        l_hat = l_hat[l_hat > eps]
        return np.sum(l_hat * np.log(l_hat * m))

    def adjustments(self, l, r=1):
        """Compute the weighted losses, their non-uniformity and the adjustment vector.

        Returns:
            tuple ``(rl, mu_rl, a)`` where ``rl = r * l``, ``mu_rl`` is
            :meth:`mu` of the normalised ``rl`` and
            ``a = r * (log(l_hat * m) - mu_rl)``.
        """
        m = len(l)
        rl = r * l
        l_hat = rl / rl.sum()
        mu_rl = self.mu(l_hat, normed=True)
        a = r * (np.log(l_hat * m) - mu_rl)
        return rl, mu_rl, a
def make_weights_for_balanced_classes(dataset):
    """Return one sampling weight per example so that every class weighs the same.

    Each example of class ``c`` receives ``1 / (count(c) * n_classes)``, hence
    the weights of every class add up to ``1 / n_classes``.

    Args:
        dataset: sized iterable of ``(x, y)`` pairs; ``y`` must be convertible
            with ``int()``.

    Returns:
        A 1-D ``torch`` tensor with ``len(dataset)`` entries, in dataset order.
    """
    labels = [int(target) for _, target in dataset]
    frequency = Counter(labels)
    num_classes = len(frequency)

    class_weight = {
        cls: 1 / (occurrences * num_classes)
        for cls, occurrences in frequency.items()
    }

    weights = torch.zeros(len(dataset))
    for position, cls in enumerate(labels):
        weights[position] = class_weight[cls]
    return weights
def random_pairs_of_minibatches(minibatches):
    """Pair every minibatch with its successor in a random cyclic order.

    A random permutation of the minibatches is drawn; the i-th pair joins the
    i-th permuted minibatch with the (i+1)-th one, wrapping around at the end.
    Both members of a pair are truncated to the shorter of the two lengths.

    Args:
        minibatches: sequence of ``(x, y)`` minibatches supporting ``len`` and slicing.

    Returns:
        List of ``((xi, yi), (xj, yj))`` tuples, one per input minibatch.
    """
    count = len(minibatches)
    order = torch.randperm(count).tolist()

    paired = []
    for pos, first_idx in enumerate(order):
        # (pos + 1) % count wraps the last element back to the first one.
        second_idx = order[(pos + 1) % count]

        first_x, first_y = minibatches[first_idx][0], minibatches[first_idx][1]
        second_x, second_y = minibatches[second_idx][0], minibatches[second_idx][1]

        keep = min(len(first_x), len(second_x))
        paired.append((
            (first_x[:keep], first_y[:keep]),
            (second_x[:keep], second_y[:keep]),
        ))
    return paired
class ParamDict(OrderedDict):
    """Code adapted from https://github.com/Alok/rl_implementations/tree/master/reptile.
    A dictionary where the values are Tensors, meant to represent weights of
    a model. This subclass lets you perform arithmetic on weights directly.

    Every operator works key-wise and returns a new ``ParamDict``.  The other
    operand may be a number (applied to every value) or a dict with the same
    keys; anything else raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments must be forwarded with ``**``; the original used a
        # single ``*`` and therefore broke ``ParamDict(name=value)``.
        super().__init__(*args, **kwargs)

    def _prototype(self, other, op):
        """Apply ``op(self_value, other_value_or_scalar)`` to every key."""
        if isinstance(other, Number):
            return ParamDict({k: op(v, other) for k, v in self.items()})
        elif isinstance(other, dict):
            return ParamDict({k: op(self[k], other[k]) for k in self})
        else:
            raise NotImplementedError

    def __add__(self, other):
        return self._prototype(other, operator.add)

    def __rmul__(self, other):
        return self._prototype(other, operator.mul)

    __mul__ = __rmul__

    def __neg__(self):
        return ParamDict({k: -v for k, v in self.items()})

    def __sub__(self, other):
        # self - other; also works when ``other`` is a plain dict.
        return self._prototype(other, operator.sub)

    def __rsub__(self, other):
        # Reflected operand order: this must yield ``other - self``.
        return self._prototype(other, lambda mine, theirs: theirs - mine)

    def __truediv__(self, other):
        return self._prototype(other, operator.truediv)
def seed_hash(*args):
    """
    Map the given arguments to a deterministic integer in ``[0, 2**31)``.

    The string form of the argument tuple is hashed with MD5 and the digest,
    read as a big-endian integer, is reduced modulo ``2**31``.
    """
    encoded = str(args).encode("utf-8")
    digest = hashlib.md5(encoded).digest()
    return int.from_bytes(digest, "big") % (2**31)
def accuracy(network, loader, weights, device):
    """Return the (optionally weighted) accuracy of ``network`` over ``loader``.

    Predictions come from ``network.predict(x)``.  A single output column is
    treated as a binary logit (positive means class 1); otherwise the arg-max
    over dimension 1 is the predicted class.

    Predictions, targets and weights are flattened to 1-D before comparison.
    The original compared an ``(N, 1)`` prediction against the raw targets, so
    ``(N,)`` targets silently broadcast to an ``(N, N)`` matrix and inflated the
    count of correct examples.

    Args:
        network: object exposing ``eval()``, ``train()`` and ``predict(x)``.
        loader: iterable of ``(x, y)`` batches.
        weights: per-example weights consumed in loader order, or ``None`` for
            uniform weights.
        device: device the batches and weights are moved to.

    Returns:
        ``sum(weight * correct) / sum(weight)`` as a Python float.

    Raises:
        ZeroDivisionError: if the total weight is zero (e.g. an empty loader).
    """
    correct = 0
    total = 0
    weights_offset = 0

    network.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                p = network.predict(x)
                if weights is None:
                    batch_weights = torch.ones(len(x))
                else:
                    batch_weights = weights[weights_offset: weights_offset + len(x)]
                    weights_offset += len(x)
                batch_weights = batch_weights.to(device).reshape(-1)
                targets = y.reshape(-1)
                if p.size(1) == 1:
                    hits = p.reshape(-1).gt(0).eq(targets)
                else:
                    hits = p.argmax(1).eq(targets)
                correct += (hits.float() * batch_weights).sum().item()
                total += batch_weights.sum().item()
    finally:
        # Put the network back in training mode even if evaluation failed.
        network.train()

    return correct / total
class Linear(nn.Module):
    """A single fully connected layer applied to flattened inputs.

    The weight is Xavier-uniform initialised and the bias starts at zero.
    """

    def __init__(self, hidden_dim=1, input_dim=2*14*14):
        super().__init__()
        self.input_dim = input_dim

        layer = nn.Linear(input_dim, hidden_dim)
        nn.init.xavier_uniform_(layer.weight)
        nn.init.zeros_(layer.bias)
        self._main = layer

    def forward(self, input):
        # Collapse everything after the batch dimension into ``input_dim`` features.
        flat = input.view(input.shape[0], self.input_dim)
        return self._main(flat)
class TopMLP(nn.Module):
    """Linear prediction head mapping ``hidden_dim`` features to ``n_targets`` outputs.

    The layer is reachable both as ``self.lin1`` and as the only element of
    ``self._main``; ``self.weights`` lists its weight and bias.  With
    ``fishr=True`` the layer is wrapped with ``extend`` before use.
    ``n_top_layers`` is accepted but not used by this implementation.
    """

    def __init__(self, hidden_dim=390, n_top_layers=1, n_targets=1, fishr=False):
        super().__init__()

        head = nn.Linear(hidden_dim, n_targets)
        if fishr:
            head = extend(head)
        self.lin1 = head

        nn.init.xavier_uniform_(head.weight)
        nn.init.zeros_(head.bias)

        self._main = nn.Sequential(head)
        self.weights = [head.weight, head.bias]

    def forward(self, input):
        return self._main(input)
+ # Load MNIST, make train/val splits, and shuffle train set examples + mnist = datasets.MNIST('./data/MNIST', train=True, download=True) + mnist_train = (mnist.data[:50000], mnist.targets[:50000]) + mnist_val = (mnist.data[50000:], mnist.targets[50000:]) + + rng_state = np.random.get_state() + np.random.shuffle(mnist_train[0].numpy()) + np.random.set_state(rng_state) + np.random.shuffle(mnist_train[1].numpy()) + + # Build environments + def make_environment(images, labels, e): + def torch_bernoulli(p, size): + return (torch.rand(size) < p).float() + def torch_xor(a, b): + return (a-b).abs() # Assumes both inputs are either 0 or 1 + # 2x subsample for computational convenience + images = images.reshape((-1, 28, 28))[:, ::2, ::2] + # Assign a binary label based on the digit; flip label with probability 0.25 + labels = (labels < 5).float() + labels = torch_xor(labels, torch_bernoulli(label_noise_rate, len(labels))) + # Assign a color based on the label; flip the color with probability e + colors = torch_xor(labels, torch_bernoulli(e, len(labels))) + # Apply the color to the image by zeroing out the other color channel + images = torch.stack([images, images], dim=1) + images[torch.tensor(range(len(images))), (1-colors).long(), :, :] *= 0 + + if int_target: + return { + 'images': (images.float() / 255.).cuda(), + 'labels': labels[:, None].long().flatten().cuda() + } + else: + return { + 'images': (images.float() / 255.).cuda(), + 'labels': labels[:, None].cuda() + } + + + envs = [ + make_environment(mnist_train[0][::2], mnist_train[1][::2], trenv1), + make_environment(mnist_train[0][1::2], mnist_train[1][1::2], trenv2)] + + # init 3 test environments [0.1, 0.5, 0.9] + test_envs = [ + make_environment(mnist_val[0], mnist_val[1], 0.9), + make_environment(mnist_val[0], mnist_val[1], 0.1), + make_environment(mnist_val[0], mnist_val[1], 0.5), + ] + return envs, test_envs diff --git a/ColoredMNIST/pair.py b/ColoredMNIST/pair.py new file mode 100644 index 0000000..947f69d --- 
class PAIR(Optimizer):
    r"""
    Implements Pareto Invariant Risk Minimization (PAIR) algorithm.
    It is proposed in the ICLR 2023 paper
    `Pareto Invariant Risk Minimization: Towards Mitigating the Optimization Dilemma in Out-of-Distribution Generalization`
    https://arxiv.org/abs/2206.07766 .

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        optimizer (pytorch optim): inner optimizer
        balancer (str, optional): indicates which MOO solver to use ("EPO" or "SEPO")
        preference (sequence[float], optional): preference of the objectives
        eps (float, optional): precision up to the preference (default: 1e-04)
        coe (float, optional): L2 regularization weight onto the yielded objective weights (default: 0)
        verbose (bool, optional): forwarded to the balancer
    """

    def __init__(self, params, optimizer=required, balancer="EPO", preference=(1e-8, 1 - 1e-8), eps=1e-4, coe=0, verbose=False):
        # TODO: parameter validity checking
        if eps < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        for _pp in preference:
            if _pp < 0.0:
                raise ValueError("Invalid preference: {}".format(preference))

        self.optimizer = optimizer
        # Accept any sequence (list, tuple, ndarray); the arithmetic below needs an array.
        self.preference = np.asarray(preference, dtype=float)

        self.descent = 0   # number of steps in which the balancer reported a "dom" move
        self.losses = []
        self.params = params
        solver_name = balancer.lower()
        if solver_name == "epo":
            self.balancer = EPO(len(self.preference), self.preference, eps=eps, coe=coe, verbose=verbose)
        elif solver_name == "sepo":
            self.balancer = SEPO(len(self.preference), self.preference, eps=eps, coe=coe, verbose=verbose)
        else:
            raise NotImplementedError("Not supported balancer: {}".format(balancer))
        defaults = dict(balancer=balancer, preference=self.preference, eps=eps)
        super(PAIR, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PAIR, self).__setstate__(state)

    def set_losses(self, losses):
        """Store the per-objective losses consumed by the next :meth:`step`."""
        self.losses = losses

    def step(self, closure=None):
        r"""Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the losses (used instead of the stored ones).

        Returns:
            If no losses were stored: ``(-1, 233, alphas)`` after a plain step
            of the inner optimizer, where ``alphas`` puts all weight on the
            first objective.  Otherwise
            ``(pair_loss, moo_losses, mu_rl, alpha)``.  The two shapes differ
            in length; this is preserved from the original interface.
        """
        if len(self.losses) == 0:
            self.optimizer.step()
            alphas = np.zeros(len(self.preference))
            alphas[0] = 1
            return -1, 233, alphas

        losses = self.losses
        if closure is not None:
            losses = closure()

        mu_rl = 0

        # One flattened gradient vector per objective.
        grads = []
        for cur_loss in losses:
            self.optimizer.zero_grad()
            cur_loss.backward(retain_graph=True)
            cur_grad = [
                param.grad.detach().clone().flatten()
                for group in self.param_groups
                for param in group['params']
                if param.grad is not None
            ]
            grads.append(torch.cat(cur_grad))

        G = torch.stack(grads)
        # ``get_grad_sim`` is an opt-in flag callers may set on the instance.
        # It is never initialised by this class, so a missing flag means "off"
        # (the original crashed with AttributeError here).
        if getattr(self, "get_grad_sim", False):
            self.grad_sim = get_grad_sim(G, losses, preference=self.preference, is_G=True)
        GG = G @ G.T
        moo_losses = np.stack([l.item() for l in losses])
        reset_optimizer = False
        try:
            # Calculate the alphas from the LP solver
            alpha, mu_rl, reset_optimizer = self.balancer.get_alpha(moo_losses, G=GG.cpu().numpy(), C=True, get_mu=True)
            if self.balancer.last_move == "dom":
                self.descent += 1
                print("dom")
        except Exception:
            # Deliberate best-effort: report the solver failure and fall back
            # to the normalised preference below.
            print(traceback.format_exc())
            alpha = None
        if alpha is None:  # A patch for the issue in cvxpy
            alpha = self.preference / np.sum(self.preference)

        # ``losses`` may be a 1-D tensor or a sequence of scalar tensors.
        loss_vec = losses if torch.is_tensor(losses) else torch.stack(list(losses))
        scales = torch.from_numpy(alpha).to(device=loss_vec.device, dtype=loss_vec.dtype)
        pair_loss = scales.dot(loss_vec)
        if reset_optimizer:
            self.optimizer.param_groups[0]["lr"] /= 5
        self.optimizer.zero_grad()
        pair_loss.backward()
        self.optimizer.step()

        return pair_loss, moo_losses, mu_rl, alpha
    def get_alpha(self, l, G, r=None, C=False, get_mu=False):
        """Calculate weights for all objectives given the gradient information.

        The cvxpy parameters are refreshed from the inputs and one of the two
        pre-built problems is solved: the balance problem while
        ``self.mu_rl > self.eps``, the dominance problem otherwise.
        ``self.last_move`` records which one ran ("bal" or "dom").

        Args:
            l (ndarray): the values of objective losses
            G (ndarray): gradients of each objective loss w.r.t. params, or
                their inner products when ``C`` is True
            r (ndarray, optional): adopt this preference if specified
            C (bool, optional): True if the input gradients are inner products
            get_mu (bool, optional): return detailed information if True.

        Returns:
            alpha: the objective weights (``self.alpha.value`` after the solve)
            mu_rl (only if get_mu): the ``mu_rl`` value returned by
                ``self.adjustments`` for this call
            reset_optimizer (only if get_mu): whether to reset the inner
                optimizer; always False in this implementation
        """
        r = self.r if r is None else r
        assert len(l) == len(G) == len(r) == self.m, "length != m"
        l_hat, rl, self.mu_rl, self.a.value = self.adjustments(l, r)
        reset_optimizer = False
        self.C.value = G if C else G @ G.T
        self.Ca.value = self.C.value @ self.a.value

        # NOTE(review): ``sum()`` of an ndarray is never None, so this branch
        # looks unreachable -- confirm what initialisation was intended.
        if self.last_alpha.sum() is None:
            self.last_alpha = np.array(r)
        if self.mu_rl > self.eps:
            J = self.Ca.value > 0
            # Indices of the objective(s) with the largest weighted loss.
            J_star_idx = np.where(rl == np.max(rl))[0]
            self.rhs.value = self.Ca.value.copy()
            # it's equivalent to setting no constraints to objectives in J
            # as maximize alpha^TCa would trivially satisfy the non-negativity
            # The parameter value is edited in place after assignment; keep
            # this statement order.
            self.rhs.value[J] = -np.inf
            self.rhs.value[J_star_idx] = 0

            self.gamma = self.prob_bal.solve(solver=self.solver, verbose=False)
            self.last_move = "bal"

            if self.verbose:
                # Debug output only: compares the solution with uniform weights.
                test_alpha = np.ones_like(self.a.value)/self.m
                print(self.last_alpha,self.C.value,self.Ca.value,self.rhs.value)
                print(self.gamma,test_alpha@self.Ca.value, self.alpha.value @ self.C.value)
                print(self.gamma,self.coe*np.linalg.norm(self.alpha.value-self.last_alpha)**2)

        else:
            self.gamma = self.prob_dom.solve(solver=self.solver, verbose=False)
            self.last_move = "dom"
        # Remember the weights of this call.
        # NOTE(review): this rebinds the attribute; whether the already-built
        # cvxpy objective sees the new value should be verified.
        self.last_alpha = np.array(self.alpha.value)

        if get_mu:
            return self.alpha.value, self.mu_rl, reset_optimizer

        return self.alpha.value
    def __init__(self, m, r, eps=1e-4, coe=0, verbose=False):
        """Build the cvxpy parameters, the variable ``alpha`` and all problems.

        Args:
            m: number of objectives (length of ``alpha``).
            r: preference vector; stored normalised to sum to one.
            eps: threshold on ``mu_rl`` stored as ``self.eps``.
            coe: weight of the squared-distance term between ``alpha`` and
                ``last_alpha`` in the regularised objectives.
            verbose: stored as ``self.verbose``.
        """
        # self.solver = cp.GLPK
        self.solver = cp.GUROBI
        # cvxopt.glpk.options["msg_lev"] = "GLP_MSG_OFF"
        self.m = m
        self.r = r/np.sum(r)
        self.eps = eps
        self.last_move = None
        self.a = cp.Parameter(m)        # Adjustments
        self.C = cp.Parameter((m, m))   # C: Gradient inner products, G^T G
        self.Ca = cp.Parameter(m)       # d_bal^TG
        self.rhs = cp.Parameter(m)      # RHS of constraints for balancing

        self.alpha = cp.Variable(m)     # Variable to optimize
        # All entries start at -1, i.e. a value no simplex point can take.
        self.last_alpha = np.zeros_like(r)-1
        self.coe = coe

        # NOTE(review): ``self.last_alpha`` is a plain ndarray when these
        # objectives are built, so it presumably enters them as a constant;
        # later reassignment of the attribute may not reach the problems --
        # confirm whether a cp.Parameter was intended.
        obj_bal = cp.Maximize(self.alpha @ self.Ca-self.coe*cp.sum_squares(self.alpha-self.last_alpha))   # objective for balance
        obj_bal_orig = cp.Maximize(self.alpha @ self.Ca)   # objective for balance
        constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Simplex
                           self.C @ self.alpha >= self.rhs]
        self.prob_bal = cp.Problem(obj_bal, constraints_bal)  # LP balance
        self.prob_bal_orig = cp.Problem(obj_bal_orig, constraints_bal)  # LP balance

        obj_dom = cp.Maximize(cp.sum(self.alpha @ self.C)-self.coe*cp.sum_squares(self.alpha-self.last_alpha))  # obj for descent
        constraints_res = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Restrict
                           self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)),
                           self.C @ self.alpha >= 0]
        constraints_rel = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Relaxed
                           self.C @ self.alpha >= 0]
        self.prob_dom = cp.Problem(obj_dom, constraints_res)  # LP dominance
        self.prob_rel = cp.Problem(obj_dom, constraints_rel)  # LP dominance

        self.gamma = 0     # Stores the latest Optimum value of the LP problem
        self.mu_rl = 0     # Stores the latest non-uniformity

        self.verbose = verbose
def getNumParams(params):
    """Count the elements of the given parameters.

    Args:
        params: iterable of objects exposing ``.data.shape`` and
            ``.requires_grad`` (e.g. ``model.parameters()``).

    Returns:
        ``(total, trainable)``: the number of elements over all parameters and
        over those with ``requires_grad`` set.
    """
    total = trainable = 0
    for tensor in params:
        size = np.prod(tensor.data.shape)
        total += size
        trainable += size if tensor.requires_grad else 0
    return total, trainable
def get_grad_sim(params, losses, preference=None, is_G=False, cosine=True):
    """Weighted similarity between the first loss's gradient and the others.

    Builds a matrix ``G`` with one flattened gradient per loss, optionally
    L2-normalizes its rows, and returns the weighted sum of the inner
    products between row 0 and rows 1..n-1.

    Args:
        params: the parameters to differentiate with respect to, or, when
            ``is_G`` is true, the already-built gradient matrix (one row
            per loss).
        losses: sequence of scalar losses; only its length is used when
            ``is_G`` is true.
        preference: optional weights; entries ``[1:]`` are renormalized to
            sum to one and weight rows 1..n-1. Uniform weights if ``None``.
        is_G: treat ``params`` as the gradient matrix.
        cosine: normalize rows so inner products are cosine similarities.

    Returns:
        The weighted similarity as a Python float.
    """
    num_ood_losses = len(losses) - 1
    if is_G:
        G = params
    else:
        # Materialize once (params may be a generator) and keep only
        # tensors autograd can differentiate with respect to.
        params = [p for p in params if p.requires_grad]
        grads = []
        for cur_loss in losses:
            # autograd.grad leaves every .grad attribute untouched, so the
            # caller's accumulated gradients are not polluted.
            cur_grads = torch.autograd.grad(
                cur_loss, params, retain_graph=True, allow_unused=True
            )
            # Zero-fill parameters this loss does not depend on so every
            # row has the same length; zeros do not change inner products.
            grads.append(torch.cat([
                (g if g is not None else torch.zeros_like(p)).detach().flatten()
                for g, p in zip(cur_grads, params)
            ]))
        G = torch.stack(grads)
    if cosine:
        G = F.normalize(G, dim=1)
    GG = (G @ G.T).detach().cpu()
    if preference is not None:
        preference = np.asarray(preference, dtype=float)
        G_weights = preference[1:] / np.sum(preference[1:])
    else:
        G_weights = np.ones(num_ood_losses) / num_ood_losses
    grad_sim = np.dot(G_weights, GG[0, 1:].numpy())
    return float(grad_sim)
def IRM_penalty_pair(envs_logits, envs_y, scale, lossf):
    """Average, over environments, of the squared gradient of the loss w.r.t. ``scale``.

    Args:
        envs_logits: per-environment logits (indexable, one entry per env).
        envs_y: per-environment labels, indexed in parallel with the logits.
        scale: tensor the logits depend on; the loss is differentiated
            with respect to it.
        lossf: callable ``lossf(logits, labels)`` returning a scalar loss.

    Returns:
        The mean squared-gradient penalty; it stays differentiable because
        the gradients are taken with ``create_graph=True``.
    """
    num_envs = len(envs_logits)
    squared_grads = []
    for env_idx in range(num_envs):
        env_loss = lossf(envs_logits[env_idx], envs_y[env_idx])
        (scale_grad,) = autograd.grad(env_loss, [scale], create_graph=True)
        squared_grads.append(torch.sum(scale_grad ** 2))
    return sum(squared_grads) / num_envs
topmlp(mlp(env['images'])) * scale + env['nll'] = lossf(logits, env['labels']) + env['acc'] = mean_accuracy(logits, env['labels']) + envs_logits.append(logits) + envs_y.append(env['labels']) + erm_losses.append(env['nll']) + + irm_penalty = IRM_penalty_pair(envs_logits, envs_y,scale, lossf) + erm_losses = torch.stack(erm_losses) + vrex_penalty = erm_losses.var() + erm_loss = erm_losses.mean() + alphas = np.array([0]) + device = logits.device + + # Compile loss + losses = torch.stack([erm_loss,irm_penalty,vrex_penalty]).to(device) + + if step >= penalty_anneal_iters: + if step==penalty_anneal_iters: + r = 1e-12 + r2 = 1e10 + r_l2 = r*r2 + preference = np.array([r]+[r_l2,(1-r-r_l2)]) + inner_optimizer = optim.SGD(trainable_params, lr=lr,momentum=0.9) + optimizer = PAIR(topmlp.parameters(),inner_optimizer,preference=preference,eps=1e-1,verbose=hparams['opt_verbose'],coe=hparams['opt_coe']) + print(f"Switch optimizer to {optimizer}") + optimizer.zero_grad() + optimizer.set_losses(losses=losses) + pair_loss, moo_losses, mu_rl, alphas = optimizer.step() + pair_res = np.array([pair_loss, mu_rl, alphas]) + else: + loss = erm_loss + + weight_norm = 0 + for w in [var for var in mlp.parameters()] + [var for var in topmlp.parameters()]: + weight_norm += w.norm().pow(2) + penalty_weight = (penalty_term_weight + if step >= penalty_anneal_iters else anneal_val) + if penalty_weight > 1.0: + # Rescale the entire loss to keep gradients in a reasonable range + weight_norm /= penalty_weight + + loss += l2_regularizer_weight * weight_norm + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if step % eval_steps == 0: + train_loss, train_acc, test_worst_loss, test_worst_acc = \ + validation(topmlp, mlp, envs, test_envs, lossf) + + print_log = [np.int32(step), train_loss, train_acc, \ + losses.detach().cpu().numpy(),alphas,test_worst_loss, test_worst_acc] + log = [np.int32(step), train_loss, train_acc,\ + losses.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + 
logs.append(log) + if verbose: + pretty_print(*print_log) + + return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs + diff --git a/ColoredMNIST/run_exp.py b/ColoredMNIST/run_exp.py new file mode 100644 index 0000000..3c63a59 --- /dev/null +++ b/ColoredMNIST/run_exp.py @@ -0,0 +1,246 @@ +import argparse +import copy +import os +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +from backpack import backpack, extend +from backpack.extensions import BatchGrad +from torch import autograd, nn, optim +from torchvision import datasets + +from models import MLP, TopMLP +from mydatasets import coloredmnist +from train import get_train_func +from utils import (EMA, GeneralizedCELoss, correct_pred, mean_accuracy, + mean_mse, mean_nll, mean_weight, parse_bool, pretty_print, + validation) + + +def main(flags): + if flags.save_dir is not None and not os.path.exists(flags.save_dir): + os.makedirs(flags.save_dir) + flags.freeze_featurizer = False if flags.freeze_featurizer.lower() == 'false' else True + final_train_accs = [] + final_train_losses = [] + final_test_accs = [] + final_test_losses = [] + final_grad_sims = [] + logs = [] + + + for restart in range(flags.n_restarts): + if flags.seed>=0 and restart != flags.seed: + print(f"Jump over seed {restart}") + continue + if flags.verbose: + print("Restart", restart) + + + ### loss function binary_cross_entropy + input_dim = 2 * 14 * 14 + if flags.methods in ['rsc', 'lff']: + n_targets = 2 + lossf = F.cross_entropy + int_target = True + else: + n_targets = 1 + lossf = mean_nll + int_target = False + + + np.random.seed(restart) + torch.manual_seed(restart) + ### load datasets + if flags.dataset == 'coloredmnist025' or flags.dataset == 'coloredmnist25': + envs, test_envs = coloredmnist(0.25, 0.1, 0.2, int_target = int_target) + elif flags.dataset == 'coloredmnist025gray': + envs, test_envs = coloredmnist(0.25, 0.5, 0.5,int_target = int_target) + elif flags.dataset 
== 'coloredmnist01': + envs, test_envs = coloredmnist(0.1, 0.2, 0.25, int_target = int_target) + elif flags.dataset == 'coloredmnist01gray': + envs, test_envs = coloredmnist(0.1, 0.5, 0.5, int_target = int_target) + elif flags.dataset == 'coloredmnist': + envs, test_envs = coloredmnist(flags.flip_p,flags.envs_p[0],flags.envs_p[1], int_target = int_target) + else: + raise NotImplementedError + + + mlp = MLP(hidden_dim = flags.hidden_dim, input_dim=input_dim).cuda() + topmlp = TopMLP(hidden_dim = flags.hidden_dim, n_top_layers=flags.n_top_layers, \ + n_targets=n_targets, fishr= flags.methods in ['fishr']).cuda() + + print(mlp, topmlp) + + if flags.load_model_dir is not None and os.path.exists(flags.load_model_dir): + device = torch.device("cuda") + state = torch.load(os.path.join(flags.load_model_dir,'mlp%d.pth' % restart), map_location=device) + mlp.load_state_dict(state) + + state = torch.load(os.path.join(flags.load_model_dir,'topmlp%d.pth' % restart), map_location=device) + topmlp.load_state_dict(state) + print("Load model from %s" % flags.load_model_dir) + + + if len(flags.group_dirs)>0: + print('load groups') + x = torch.cat([env['images'] for env in envs]) + y = torch.cat([env['labels'] for env in envs]) + #print(x.shape, y.shape) + groups = [np.load(os.path.join(group_dir,'group%d.npy' % restart)) for group_dir in flags.group_dirs] + n_groups = len(groups) + new_envs = [] + + for group in groups: + for val in np.unique(group): + env = {} + env['images'] = x[group == val] + env['labels'] = y[group == val] + + new_envs.append(env) + train_envs = new_envs + + else: + train_envs = envs + + train_func = get_train_func(flags.methods) + params = [mlp, topmlp, flags.steps, train_envs, test_envs,lossf,\ + flags.penalty_anneal_iters, flags.penalty_weight, \ + flags.anneal_val, flags.lr, \ + flags.l2_regularizer_weight, flags.freeze_featurizer, flags.eval_steps, flags.verbose, ] + if flags.methods in ['vrex', 'iga','irm','fishr','gm','lff','erm','dro','pair']: + hparams 
= {} + elif flags.methods in ['clove']: + hparams = {'batch_size': flags.batch_size, 'kernel_scale': flags.kernel_scale} + elif flags.methods in ['rsc']: + hparams = {'rsc_f_drop_factor' : flags.rsc_f, 'rsc_b_drop_factor': flags.rsc_b} + elif flags.methods in ['sd']: + hparams = {'lr_s2_decay': flags.lr_s2_decay} + else: + raise NotImplementedError + # additional exp configs + hparams['opt'] = flags.opt + # hparams['pair_bal'] = flags.pair_bal + hparams['opt_verbose'] = flags.opt_verbose + hparams['opt_coe'] = flags.opt_coe + hparams['pair_sim'] = flags.pair_sim + + res = train_func(*params,hparams) + (train_acc, train_loss, test_worst_acc, test_worst_loss), per_logs = res + + + logs.extend(per_logs) + final_train_accs.append(train_acc) + final_train_losses.append(train_loss) + final_test_accs.append(test_worst_acc) + final_test_losses.append(test_worst_loss) + + if flags.verbose: + + print('Final train acc (mean/std across restarts so far):') + print(np.mean(final_train_accs), np.std(final_train_accs)) + print('Final train loss (mean/std across restarts so far):') + print(np.mean(final_train_losses), np.std(final_train_losses)) + print('Final worest test acc (mean/std across restarts so far):') + print(np.mean(final_test_accs), np.std(final_test_accs)) + print('Final worest test loss (mean/std across restarts so far):') + print(np.mean(final_test_losses), np.std(final_test_losses)) + + results = [np.mean(final_train_accs), np.std(final_train_accs), + np.mean(final_train_losses), np.std(final_train_losses), + np.mean(final_test_accs), np.std(final_test_accs), + np.mean(final_test_losses), np.std(final_test_losses), + ] + + + + if flags.save_dir is not None: + state = mlp.state_dict() + torch.save(state, os.path.join(flags.save_dir,'mlp%d.pth' % restart)) + state = topmlp.state_dict() + torch.save(state, os.path.join(flags.save_dir,'topmlp%d.pth' % restart)) + + with torch.no_grad(): + x = torch.cat([env['images'] for env in envs]) + y = torch.cat([env['labels'] for 
env in envs]) + logits = topmlp(mlp(x)) + group, _ = correct_pred(logits, y) + + pseudolabel = np.copy(y.cpu().numpy().flatten()) + pseudolabel[~group] = 1-pseudolabel[~group] + np.save(os.path.join(flags.save_dir,'group%d.npy' % restart), group) + np.save(os.path.join(flags.save_dir,'pseudolabel%d.npy' % restart), pseudolabel ) + + logs = np.array(logs) + + if flags.save_dir is not None: + np.save(os.path.join(flags.save_dir,'logs.npy'), logs) + + return results, logs + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Colored MNIST & CowCamel') + parser.add_argument('--verbose', type=bool, default=False) + # additional exp name id + parser.add_argument('--exp_id', type=str,default='default') + parser.add_argument('--n_restarts', type=int, default=10) + parser.add_argument('--dataset', type=str, default='coloredmnist025') + parser.add_argument('--hidden_dim', type=int, default=256) + parser.add_argument('--n_top_layers', type=int, default=1) + parser.add_argument('--l2_regularizer_weight', '-l2',type=float,default=0.001) + + parser.add_argument('--opt', type=str, default='adam') + parser.add_argument('--opt_verbose',action='store_true') + parser.add_argument('--pair_sim',action='store_true') + parser.add_argument('--opt_coe', type=float, default=0.0) + parser.add_argument('--lr', type=float, default=0.001) + parser.add_argument('--steps', type=int, default=501) + parser.add_argument('--lossf', type=str, default='nll') + parser.add_argument('--penalty_anneal_iters', '-pi', type=int, default=100) + parser.add_argument('--penalty_weight', '-p', type=float, default=10000.0) + parser.add_argument('--irmx_p2', '-p2', type=float, default=-1) + parser.add_argument('--anneal_val', '-av',type=float, default=1) + + parser.add_argument('--methods', type=str, default='irmv2') + parser.add_argument('--lr_s2_decay', type=float, default=500) + parser.add_argument('--freeze_featurizer', type=str, default='False') + parser.add_argument('--eval_steps', 
type=int, default=5) + parser.add_argument('--seed', type=int, default=-1) # eval at a specific seed + + parser.add_argument('--load_model_dir', type=str, default=None) + parser.add_argument('--save_dir', type=str, default=None) + parser.add_argument('--group_dirs', type=str, nargs='*',default={}) + + #RSC + parser.add_argument('--rsc_f', type=float, default=0.99) + parser.add_argument('--rsc_b', type=float, default=0.97) + + #clove + parser.add_argument('--batch_size', type=int, default=512) + parser.add_argument('--kernel_scale', type=float, default=0.4) + + parser.add_argument('--n_examples', type=int, default=18000) + + parser.add_argument('--flip_p', default=0.25, type=float) + parser.add_argument('--envs_p', nargs='?', default='[0.1,0.2]', help='random seed') + parser.add_argument('--norun',type=parse_bool, default=False) + + parser.add_argument('--no_plot',action='store_true') + flags = parser.parse_args() + flags.envs_p = eval(flags.envs_p) + if flags.norun: + if flags.verbose: + print('Flags:') + for k,v in sorted(vars(flags).items()): + print("\t{}: {}".format(k, v)) + else: + main(flags) + + + + + + diff --git a/ColoredMNIST/train.py b/ColoredMNIST/train.py new file mode 100644 index 0000000..70ccc75 --- /dev/null +++ b/ColoredMNIST/train.py @@ -0,0 +1,954 @@ +import argparse +import copy +from collections import OrderedDict + +import numpy as np +import torch +import torch.nn.functional as F +from backpack import backpack, extend +from backpack.extensions import BatchGrad +from torch import autograd, nn, optim +from torchvision import datasets + +from models import MLP, TopMLP +from mydatasets import coloredmnist +from utils import (EMA, GeneralizedCELoss, correct_pred, mean_accuracy, + mean_mse, mean_nll, mean_weight, pretty_print, validation, + validation_details) +from pair_alg import * + + +def IGA_penalty(envs_logits, envs_y, scale, lossf): + + grads = [] + grad_mean = 0 + for i in range(len(envs_logits)): + + loss = lossf(envs_logits[i], 
def GM_penalty(envs_logits, envs_y, scale, lossf):
    """Negative sum of pairwise inner products between per-environment gradients.

    For every environment the loss is differentiated with respect to
    ``scale`` (with ``create_graph=True`` so the penalty itself can be
    back-propagated), the gradients are flattened into one vector, and the
    penalty is ``-sum_{i<j} <g_i, g_j>``.

    Args:
        envs_logits: per-environment logits (indexable).
        envs_y: per-environment labels, indexed in parallel with the logits.
        scale: tensor or sequence of tensors to differentiate with respect
            to (gm_train passes the list of model parameters).
        lossf: callable ``lossf(logits, labels)`` returning a scalar loss.

    Returns:
        A scalar tensor; zero when fewer than two environments are given.
    """
    grads = []
    for i in range(len(envs_logits)):
        loss = lossf(envs_logits[i], envs_y[i])
        env_grads = autograd.grad(loss, scale, create_graph=True)
        # reshape (not view) also copes with non-contiguous gradients.
        grads.append(torch.cat([g.reshape(-1) for g in env_grads]))

    # Start from a tensor so callers can always call .detach()/.item(),
    # even when there is no pair to accumulate.
    train_penalty = grads[0].new_zeros(()) if grads else torch.zeros(())
    for i, grad_i in enumerate(grads):
        for grad_j in grads[i + 1:]:
            train_penalty = train_penalty - torch.sum(grad_i * grad_j)

    return train_penalty
def vrex_train(mlp, topmlp, steps, envs, test_envs, lossf,
               penalty_anneal_iters, penalty_term_weight, anneal_val,
               lr, l2_regularizer_weight, freeze_featurizer=False,
               eval_steps=5, verbose=True, hparams=None):
    """Train ``topmlp(mlp(x))`` with summed env losses plus a loss-variance penalty.

    Each step computes one loss per environment, uses their sum as the ERM
    term and their variance as the penalty, adds an L2 term over all
    weights, and takes one Adam step.

    Args:
        mlp, topmlp: featurizer and classifier modules.
        steps: number of optimization steps.
        envs: training environments; dicts with 'images' and 'labels'.
            'nll' and 'acc' are written into each dict every step.
        test_envs: environments forwarded to ``validation``.
        lossf: callable ``lossf(logits, labels)`` returning a scalar loss.
        penalty_anneal_iters: step from which ``penalty_term_weight`` is
            used instead of ``anneal_val``.
        penalty_term_weight, anneal_val: penalty weights after/before annealing.
        lr: Adam learning rate.
        l2_regularizer_weight: coefficient of the squared weight norm.
        freeze_featurizer: if true, only ``topmlp`` is optimized and the
            ``mlp`` parameters get ``requires_grad = False``.
        eval_steps: evaluate and log every ``eval_steps`` steps.
        verbose: print each log row with ``pretty_print``.
        hparams: unused; accepted so the function can be called with the
            same positional arguments as the other ``*_train`` functions.

    Returns:
        ``((train_acc, train_loss, test_worst_acc, test_worst_loss), logs)``
        where the tuple holds the values of the last evaluation.
    """
    logs = []

    if freeze_featurizer:
        optimizer = optim.Adam(list(topmlp.parameters()), lr=lr)
        for param in mlp.parameters():
            param.requires_grad = False
    else:
        optimizer = optim.Adam(
            list(mlp.parameters()) + list(topmlp.parameters()), lr=lr)

    for step in range(steps):
        erm_losses = []
        for env in envs:
            logits = topmlp(mlp(env['images']))
            env['nll'] = lossf(logits, env['labels'])
            env['acc'] = mean_accuracy(logits, env['labels'])
            erm_losses.append(env['nll'])

        erm_losses = torch.stack(erm_losses)

        train_penalty = erm_losses.var()
        erm_loss = erm_losses.sum()

        loss = erm_loss.clone()

        weight_norm = 0
        for w in list(mlp.parameters()) + list(topmlp.parameters()):
            weight_norm += w.norm().pow(2)
        loss += l2_regularizer_weight * weight_norm

        penalty_weight = (penalty_term_weight
                          if step >= penalty_anneal_iters else anneal_val)
        loss += penalty_weight * train_penalty
        if penalty_weight > 1.0:
            # Rescale the entire loss to keep gradients in a reasonable range
            loss /= penalty_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % eval_steps == 0:
            train_loss, train_acc, test_worst_loss, test_worst_acc = \
                validation(topmlp, mlp, envs, test_envs, lossf)
            log = [np.int32(step), train_loss, train_acc,
                   train_penalty.detach().cpu().numpy(),
                   test_worst_loss, test_worst_acc]
            logs.append(log)
            if verbose:
                pretty_print(*log)

    return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs
def dro_train(mlp, topmlp, steps, envs, test_envs, lossf,
              penalty_anneal_iters, penalty_term_weight, anneal_val,
              lr, l2_regularizer_weight, freeze_featurizer=False,
              verbose=True, eval_steps=5, hparams=None):
    """Train ``topmlp(mlp(x))`` on the worst (largest) per-environment loss.

    Each step computes one loss per environment, takes the maximum, adds an
    L2 term over all weights, and takes one Adam step.

    Args:
        mlp, topmlp: featurizer and classifier modules.
        steps: number of optimization steps.
        envs: training environments; dicts with 'images' and 'labels'.
            'nll' and 'acc' are written into each dict every step.
        test_envs: environments forwarded to ``validation_details``.
        lossf: callable ``lossf(logits, labels)`` returning a scalar loss.
        penalty_anneal_iters, penalty_term_weight, anneal_val: unused here;
            kept for a signature parallel to the other ``*_train`` functions.
        lr: Adam learning rate.
        l2_regularizer_weight: coefficient of the squared weight norm.
        freeze_featurizer: if true, only ``topmlp`` is optimized and the
            ``mlp`` parameters get ``requires_grad = False``.
        verbose: print each log row with ``pretty_print``.
        eval_steps: evaluate and log every ``eval_steps`` steps.
        hparams: unused.

    Returns:
        ``((train_acc, train_loss, test_losses, test_acces), logs)`` with
        the values of the last evaluation; unlike the other trainers the
        test entries are whatever sequences ``validation_details`` returns.
    """
    if freeze_featurizer:
        optimizer = optim.Adam(list(topmlp.parameters()), lr=lr)
        for param in mlp.parameters():
            param.requires_grad = False
    else:
        optimizer = optim.Adam(
            list(mlp.parameters()) + list(topmlp.parameters()), lr=lr)

    logs = []
    for step in range(steps):
        erm_losses = []
        for env in envs:
            logits = topmlp(mlp(env['images']))
            env['nll'] = lossf(logits, env['labels'])
            env['acc'] = mean_accuracy(logits, env['labels'])
            erm_losses.append(env['nll'])

        weight_norm = 0
        for w in list(mlp.parameters()) + list(topmlp.parameters()):
            weight_norm += w.norm().pow(2)

        # max() returns the very tensor stored in env['nll']; build the
        # total out of place so that tensor is not modified.
        loss = max(erm_losses) + l2_regularizer_weight * weight_norm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % eval_steps == 0:
            train_loss, train_acc, test_losses, test_acces = \
                validation_details(topmlp, mlp, envs, test_envs, lossf)

            log = [np.int32(step), train_loss, train_acc,
                   np.int32(0), *test_losses, *test_acces]
            logs.append(log)
            if verbose:
                pretty_print(*log)

    return (train_acc, train_loss, test_losses, test_acces), logs
def irm_train(mlp, topmlp, steps, envs, test_envs, lossf,
              penalty_anneal_iters, penalty_term_weight, anneal_val,
              lr, l2_regularizer_weight, freeze_featurizer=False,
              verbose=True, eval_steps=5, hparams={}):
    """Train ``topmlp(mlp(x))`` with summed env losses plus ``IRM_penalty``.

    Every step multiplies the logits by a fresh scalar ``scale`` (value 1,
    on CUDA, requiring grad), sums the per-environment losses, adds the L2
    weight term and the weighted penalty computed by ``IRM_penalty`` with
    respect to ``scale``, and takes one Adam step. 'nll' and 'acc' are
    written into each env dict. Evaluation via ``validation`` happens every
    ``eval_steps`` steps; rows are printed when ``verbose`` is true.

    Returns:
        ``((train_acc, train_loss, test_worst_acc, test_worst_loss), logs)``
        with the values of the last evaluation.
    """
    featurizer_params = list(mlp.parameters())
    classifier_params = list(topmlp.parameters())
    all_params = featurizer_params + classifier_params

    if freeze_featurizer:
        optimizer = optim.Adam(classifier_params, lr=lr)
        for frozen in featurizer_params:
            frozen.requires_grad = False
    else:
        optimizer = optim.Adam(all_params, lr=lr)

    logs = []
    for step in range(steps):
        scale = torch.tensor([1.])[0].cuda().requires_grad_()

        envs_logits, envs_y, env_losses = [], [], []
        for env in envs:
            labels = env['labels']
            logits = topmlp(mlp(env['images'])) * scale
            env['nll'] = lossf(logits, labels)
            env['acc'] = mean_accuracy(logits, labels)
            envs_logits.append(logits)
            envs_y.append(labels)
            env_losses.append(env['nll'])

        train_penalty = IRM_penalty(envs_logits, envs_y, scale, lossf)

        weight_norm = sum(w.norm().pow(2) for w in all_params)
        if step >= penalty_anneal_iters:
            penalty_weight = penalty_term_weight
        else:
            penalty_weight = anneal_val

        loss = sum(env_losses).clone()
        loss = loss + l2_regularizer_weight * weight_norm
        loss = loss + penalty_weight * train_penalty
        if penalty_weight > 1.0:
            # Rescale the entire loss to keep gradients in a reasonable range
            loss = loss / penalty_weight

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % eval_steps != 0:
            continue
        train_loss, train_acc, test_worst_loss, test_worst_acc = validation(
            topmlp, mlp, envs, test_envs, lossf)
        log = [np.int32(step), train_loss, train_acc,
               train_penalty.detach().cpu().numpy(),
               test_worst_loss, test_worst_acc]
        logs.append(log)
        if verbose:
            pretty_print(*log)

    return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs
(-(confidence.view(-1,1)-confidence).abs() / kernel_scale).exp() + + + conf_diff = (c - confidence).view(-1,1) * (c -confidence) + + res = conf_diff * k + + return res.sum() / (len(logits)**2) + + pretty_print('step', 'train nll', 'train acc', 'train penalty', 'test acc') + + for step in range(steps): + length = min(len(envs[0]['labels']), len(envs[1]['labels'])) + + idx0 = np.arange(length) + np.random.shuffle(idx0) + idx1 = np.arange(length) + np.random.shuffle(idx1) + idx = [idx0, idx1] + + for i in range(length // batch_size): + + train_penalty = 0 + train_nll = 0 + train_acc = 0 + for j, env in enumerate(envs[0:2]): + x, y = env['images'], env['labels'] + x_batch, y_batch = x[idx[j][i*batch_size:(i+1)*batch_size]], y[idx[j][i*batch_size:(i+1)*batch_size]] + logits = topmlp(mlp(x_batch)) + nll = mean_nll(logits, y_batch) + acc = mean_accuracy(logits, y_batch) + mmce = mmce_penalty(logits, y_batch) + train_penalty += mmce + train_nll += nll + train_acc += acc + + train_acc /=2 + + + weight_norm = torch.tensor(0.).cuda() + for w in mlp.parameters(): + weight_norm += w.norm().pow(2) + + loss = train_nll.clone() + penalty_weight = (penalty_term_weight + if step >= penalty_anneal_iters else anneal_val) + loss += penalty_weight * train_penalty + if penalty_weight > 1.0: + # Rescale the entire loss to keep gradients in a reasonable range + loss /= penalty_weight + + optimizer.zero_grad() + + + loss.backward() + optimizer.step() + + + if step % eval_steps == 0: + train_loss, train_acc, test_worst_loss, test_worst_acc = \ + validation(topmlp, mlp, envs, test_envs, lossf) + log = [np.int32(step), train_loss, train_acc,\ + train_penalty.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + logs.append(log) + if verbose: + pretty_print(*log) + + return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs + +def fishr_train(mlp, topmlp, steps, envs, test_envs, lossf, \ + penalty_anneal_iters, penalty_term_weight, anneal_val, \ + 
lr,l2_regularizer_weight,freeze_featurizer=False, verbose=True, eval_steps=5, hparams={}): + + def compute_grads_variance(features, labels, classifier): + logits = classifier(features) + loss = bce_extended(logits, labels) + + with backpack(BatchGrad()): + loss.backward( + inputs=list(classifier.parameters()), retain_graph=True, create_graph=True + ) + + dict_grads = OrderedDict( + [ + (name, weights.grad_batch.clone().view(weights.grad_batch.size(0), -1)) + for name, weights in classifier.named_parameters() + ] + ) + dict_grads_variance = {} + for name, _grads in dict_grads.items(): + grads = _grads * labels.size(0) # multiply by batch size + env_mean = grads.mean(dim=0, keepdim=True) + + dict_grads_variance[name] = (grads).pow(2).mean(dim=0) + + return dict_grads_variance + + def l2_between_grads_variance(cov_1, cov_2): + assert len(cov_1) == len(cov_2) + cov_1_values = [cov_1[key] for key in sorted(cov_1.keys())] + cov_2_values = [cov_2[key] for key in sorted(cov_2.keys())] + return ( + torch.cat(tuple([t.view(-1) for t in cov_1_values])) - + torch.cat(tuple([t.view(-1) for t in cov_2_values])) + ).pow(2).sum() + + if freeze_featurizer: + optimizer = optimizer = optim.Adam( [var for var in topmlp.parameters()], lr=lr) + for param in mlp.parameters(): + param.requires_grad = False + else: + optimizer = optimizer = optim.Adam([var for var in mlp.parameters()] + \ + [var for var in topmlp.parameters()], lr=lr) + logs = [] + + bce_extended = extend(nn.BCEWithLogitsLoss()) + for step in range(steps): + for edx, env in enumerate(envs): + features = mlp(env['images']) + logits = topmlp(features) + env['nll'] = mean_nll(logits, env['labels']) + env['acc'] = mean_accuracy(logits, env['labels']) + if edx in [0, 1]: + # True when the dataset is in training + optimizer.zero_grad() + env["grads_variance"] = compute_grads_variance(features, env['labels'], topmlp) + + train_nll = torch.stack([envs[0]['nll'], envs[1]['nll']]).sum() + train_acc = torch.stack([envs[0]['acc'], 
envs[1]['acc']]).mean() + + weight_norm = torch.tensor(0.).cuda() + for w in mlp.parameters(): + weight_norm += w.norm().pow(2) + + loss = train_nll.clone() + loss += l2_regularizer_weight * weight_norm + + dict_grads_variance_averaged = OrderedDict( + [ + ( + name, + torch.stack([envs[0]["grads_variance"][name], envs[1]["grads_variance"][name]], + dim=0).mean(dim=0) + ) for name in envs[0]["grads_variance"] + ] + ) + fishr_penalty = ( + l2_between_grads_variance(envs[0]["grads_variance"], dict_grads_variance_averaged) + + l2_between_grads_variance(envs[1]["grads_variance"], dict_grads_variance_averaged) + ) + train_penalty = fishr_penalty + + + penalty_weight = (penalty_term_weight + if step >= penalty_anneal_iters else anneal_val) + loss += penalty_weight * train_penalty + if penalty_weight > 1.0: + # Rescale the entire loss to keep gradients in a reasonable range + loss /= penalty_weight + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if step % eval_steps == 0: + train_loss, train_acc, test_worst_loss, test_worst_acc = \ + validation(topmlp, mlp, envs, test_envs, lossf) + log = [np.int32(step), train_loss, train_acc,\ + train_penalty.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + logs.append(log) + if verbose: + pretty_print(*log) + + return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs + +def gm_train(mlp, topmlp, steps, envs, test_envs, lossf, \ + penalty_anneal_iters, penalty_term_weight, anneal_val, \ + lr,l2_regularizer_weight, freeze_featurizer=False, verbose=True, eval_steps=5, hparams={}): + + if freeze_featurizer: + optimizer = optimizer = optim.Adam( [var for var in topmlp.parameters()], lr=lr) + else: + optimizer = optimizer = optim.Adam([var for var in mlp.parameters()] + \ + [var for var in topmlp.parameters()], lr=lr) + logs = [] + for step in range(steps): + + train_penalty = 0 + envs_logits = [] + envs_y = [] + for env in envs: + logits = topmlp(mlp(env['images'])) + env['nll'] = lossf(logits, 
env['labels']) + env['acc'] = mean_accuracy(logits, env['labels']) + envs_logits.append(logits) + envs_y.append(env['labels']) + + + + train_penalty = GM_penalty(envs_logits, envs_y, [var for var in mlp.parameters()] + [var for var in topmlp.parameters()], lossf) + + erm_loss = (envs[0]['nll'] + envs[1]['nll']) + + + loss = erm_loss.clone() + + + weight_norm = 0 + for w in [var for var in mlp.parameters()] + [var for var in topmlp.parameters()]: + weight_norm += w.norm().pow(2) + + loss += flags.l2_regularizer_weight * weight_norm + + + penalty_weight = (flags.penalty_weight + if step >= flags.penalty_anneal_iters else flags.anneal_val) + loss += penalty_weight * train_penalty + if penalty_weight > 1.0: + # Rescale the entire loss to keep gradients in a reasonable range + loss /= penalty_weight + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + if step % eval_steps == 0: + train_loss, train_acc, test_worst_loss, test_worst_acc = \ + validation(topmlp, mlp, envs, test_envs, lossf) + log = [np.int32(step), train_loss, train_acc,\ + train_penalty.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + logs.append(log) + if verbose: + pretty_print(*log) + + return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs + +def lff_train(mlp, topmlp, steps, envs, test_envs, lossf, \ + penalty_anneal_iters, penalty_term_weight, anneal_val, \ + lr,l2_regularizer_weight, freeze_featurizer=False, verbose=True,eval_steps=5, hparams={}): + + if freeze_featurizer is False: + raise NotImplementedError + + x = torch.cat([envs[i]['images'] for i in range(len(envs))]) + y = torch.cat([envs[i]['labels'] for i in range(len(envs))]) + + y = y.long().flatten() + logs = [] + if penalty_anneal_iters > 0: + optimizer = torch.optim.Adam([var for var in mlp.parameters()] \ + + [var for var in topmlp.parameters()], + lr=lr, weight_decay=l2_regularizer_weight,) + + for step in range(penalty_anneal_iters): + logits = topmlp(mlp(x)) + loss = F.cross_entropy(logits, y) + 
optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_penalty = torch.tensor([0]).cuda()[0] + + if step % 5 == 0: + train_loss, train_acc, test_worst_loss, test_worst_acc = \ + validation(topmlp, mlp, envs, test_envs, lossf) + log = [np.int32(step), train_loss, train_acc,\ + train_penalty.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + logs.append(log) + if verbose: + pretty_print(*log) + + + _mlp = copy.deepcopy(mlp) + _topmlp = copy.deepcopy(topmlp) + + model_b = torch.nn.Sequential(_mlp, _topmlp) + + model_d = torch.nn.Sequential(mlp, topmlp) + + optimizer_b = torch.optim.Adam( + model_b.parameters(), + lr=lr / 100, + weight_decay=l2_regularizer_weight, + ) + optimizer_d = torch.optim.Adam( + model_d.parameters(), + lr=lr / 100, + weight_decay= l2_regularizer_weight, + ) + lossf = nn.CrossEntropyLoss(reduction='mean') + criterion = nn.CrossEntropyLoss(reduction='none') + bias_criterion = GeneralizedCELoss(q = penalty_term_weight) + + sample_loss_ema_b = EMA(y.cpu().numpy(), alpha=0.7) + sample_loss_ema_d = EMA(y.cpu().numpy(), alpha=0.7) + + index = np.arange(len(y)) + for step in range(penalty_anneal_iters, steps): + + logit_b = model_b(x) + logit_d = model_d(x) + + loss_b = criterion(logit_b, y).cpu().detach() + loss_d = criterion(logit_d, y).cpu().detach() + + sample_loss_ema_b.update(loss_b,index) + sample_loss_ema_d.update(loss_d,index) + + loss_b = sample_loss_ema_b.parameter[index].clone().detach() + loss_d = sample_loss_ema_d.parameter[index].clone().detach() + + # mnist target has one class, so I can do in this way. 
+ label_cpu = y.cpu() + num_classes = 2 + for c in range(num_classes): + class_index = np.where(label_cpu == c)[0] + max_loss_b = sample_loss_ema_b.max_loss(c) + max_loss_d = sample_loss_ema_d.max_loss(c) + loss_b[class_index] /= max_loss_b + loss_d[class_index] /= max_loss_d + + loss_weight = loss_b / (loss_b + loss_d + 1e-8) + + loss_b_update = bias_criterion(logit_b, y) + loss_d_update = criterion(logit_d, y) * loss_weight.cuda() + loss = loss_b_update.mean() + loss_d_update.mean() + + optimizer_b.zero_grad() + optimizer_d.zero_grad() + loss.backward() + optimizer_b.step() + optimizer_d.step() + + train_penalty = torch.tensor([0]).cuda()[0] + + if step % eval_steps == 0: + train_loss, train_acc, test_worst_loss, test_worst_acc = \ + validation(topmlp, mlp, envs, test_envs, lossf) + log = [np.int32(step), train_loss, train_acc,\ + train_penalty.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + logs.append(log) + if verbose: + pretty_print(*log) + return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs + +def erm_train(mlp, topmlp, steps, envs, test_envs, lossf, \ + penalty_anneal_iters, penalty_term_weight, anneal_val, \ + lr,l2_regularizer_weight, freeze_featurizer=False, verbose=True,eval_steps=5, hparams={}): + + + x = torch.cat([envs[i]['images'] for i in range(len(envs))]) + y = torch.cat([envs[i]['labels'] for i in range(len(envs))]) + + if freeze_featurizer: + optimizer = optimizer = optim.Adam( [var for var in topmlp.parameters()], lr=lr) + for param in mlp.parameters(): + param.requires_grad = False + print('freeze_featurizer') + + else: + optimizer = optimizer = optim.Adam([var for var in mlp.parameters()] + \ + [var for var in topmlp.parameters()], lr=lr) + + logs = [] + for step in range(steps): + + logits = topmlp(mlp(x)) + #print(logits) + loss = lossf(logits, y) + #print(loss) + #0/0 + weight_norm = 0 + for w in [var for var in mlp.parameters()] + [var for var in topmlp.parameters()]: + weight_norm += w.norm().pow(2) + + loss 
+= l2_regularizer_weight * weight_norm + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_penalty = torch.tensor([0]).cuda()[0] + + if step % eval_steps == 0: + train_loss, train_acc, test_worst_loss, test_worst_acc = \ + validation(topmlp, mlp, envs, test_envs, lossf) + log = [np.int32(step), train_loss, train_acc,\ + train_penalty.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + logs.append(log) + if verbose: + pretty_print(*log) + + return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs + +def syn_train(mlp, topmlp, steps, envs, test_envs, lossf, \ + penalty_anneal_iters, penalty_term_weight, anneal_val, \ + lr,l2_regularizer_weight, verbose=True,eval_steps=50, hparams={}): + + x = torch.cat([envs[i]['images'] for i in range(len(envs))]) + y = torch.cat([envs[i]['labels'] for i in range(len(envs))]) + + optimizer = optim.Adam([var for var in mlp.parameters()] \ + +[var for var in topmlp.parameters()], lr=lr) + logs = [] + ntasks = hparams['ntasks'] + for step in range(steps): + logits = topmlp(mlp(x)) + + per_logits_size = logits.shape[1] // ntasks + per_y_size = y.shape[1] // ntasks + loss = 0 + for i in range(ntasks): + + loss += lossf(logits[:, i*per_logits_size:(i+1)*per_logits_size],y[:,i*per_y_size:(i+1)*per_y_size]) + loss /= ntasks + + + weight_norm = 0 + for w in [var for var in mlp.parameters()] + [var for var in topmlp.parameters()]: + weight_norm += w.norm().pow(2) + + + loss += l2_regularizer_weight * weight_norm + + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_penalty = torch.tensor([0]).cuda()[0] + + + if step % eval_steps == 0: + + with torch.no_grad(): + for j, env in enumerate(envs + test_envs): + logits = topmlp(mlp(env['images'])) + loss = 0 + acc = 0 + for i in range(ntasks): + per_logits = logits[:, i*per_logits_size:(i+1)*per_logits_size] + + if j < len(envs): + per_y = env['labels'][:,i*per_y_size:(i+1)*per_y_size] + else: + per_y = env['labels'] + loss += 
lossf(per_logits,per_y) + + acc += mean_accuracy(per_logits, per_y) + + loss /= ntasks + acc /=ntasks + + env['nll'] = loss + env['acc'] = acc + + test_worst_loss = torch.stack([env['nll'] for env in test_envs]).max() + test_worst_acc = torch.stack([env['acc'] for env in test_envs]).min() + train_loss = torch.stack([env['nll'] for env in envs]).mean() + train_acc = torch.stack([env['acc'] for env in envs]).mean() + + train_loss, train_acc, test_worst_loss, test_worst_acc = train_loss.detach().cpu().numpy(), \ + train_acc.detach().cpu().numpy(), \ + test_worst_loss.detach().cpu().numpy(),\ + test_worst_acc.detach().cpu().numpy() + log = [np.int32(step), train_loss, train_acc,\ + train_penalty.detach().cpu().numpy(),test_worst_loss, test_worst_acc] + logs.append(log) + + if verbose: + pretty_print(*log) + + return (train_acc, train_loss, test_worst_acc, test_worst_loss), logs + +def get_train_func(methods): + assert methods in ['rsc', 'vrex', 'iga','sd','irm','clove','fishr','gm','lff','erm','dro','syn','pair'] + return eval("%s_train" % methods) diff --git a/ColoredMNIST/utils.py b/ColoredMNIST/utils.py new file mode 100644 index 0000000..03bf54b --- /dev/null +++ b/ColoredMNIST/utils.py @@ -0,0 +1,161 @@ +import numpy as np +from torch import nn +import torch +import torch.nn.functional as F + +import numpy as np + +def parse_bool(v): + if v.lower()=='true': + return True + elif v.lower()=='false': + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + +def pretty_print(*values): + col_width = 13 + def format_val(v): + if not isinstance(v, str): + v = np.array2string(v, precision=5, floatmode='fixed') + return v.ljust(col_width) + str_values = [format_val(v) for v in values] + print(" ".join(str_values)) + +# Define loss function helpers +def mean_weight(weights): + + weight = copy.deepcopy(weights[0]) + for key in weight: + for val in weights[1:]: + weight[key] += val[key] + + for key in weight: + weight[key] /= len(weights) + + 
return weight + + +def mean_nll(logits, y, reduction='mean'): + return nn.functional.binary_cross_entropy_with_logits(logits, y,reduction=reduction) + +def mean_mse(logits, y, reduction = 'mean'): + if reduction == 'mean': + return ((logits - (2*y-1))**2).mean()/2 + elif reduction == 'none': + return ((logits - (2*y-1))**2)/2 + +def mean_accuracy(logits, y, reduction = 'mean'): + if logits.size(1) == 1: + preds = (logits > 0.).float() + if reduction == 'mean': + return ((preds - y).abs() < 1e-2).float().mean() + else: + return ((preds - y).abs() < 1e-2).float() + else: + if reduction == 'mean': + return (logits.argmax(1).eq(y).float()).mean() + else: + return (logits.argmax(1).eq(y).float()) + +def correct_pred(logits, y): + if logits.size(1) == 1: + preds = (logits > 0.).float() + correct = ((preds - y).abs() < 1e-2).float().cpu().detach().numpy().flatten().astype(bool) + + else: + correct = logits.argmax(1).eq(y).float().cpu().detach().numpy().flatten().astype(bool) + + return correct, ~correct + +def validation(topmlp, mlp, envs, test_envs, lossf): + + with torch.no_grad(): + for env in envs + test_envs: + logits = topmlp(mlp(env['images'])) + + env['nll'] = lossf(logits, env['labels']) + env['acc'] = mean_accuracy(logits, env['labels']) + + test_worst_loss = torch.stack([env['nll'] for env in test_envs]).max() + test_worst_acc = torch.stack([env['acc'] for env in test_envs]).min() + train_loss = torch.stack([env['nll'] for env in envs]).mean() + train_acc = torch.stack([env['acc'] for env in envs]).mean() + + return train_loss.detach().cpu().numpy(), train_acc.detach().cpu().numpy(), \ + test_worst_loss.detach().cpu().numpy(),test_worst_acc.detach().cpu().numpy() + +def validation_details(topmlp, mlp, envs, test_envs, lossf): + + with torch.no_grad(): + for env in envs + test_envs: + logits = topmlp(mlp(env['images'])) + + env['nll'] = lossf(logits, env['labels']) + env['acc'] = mean_accuracy(logits, env['labels']) + + train_loss = torch.stack([env['nll'] for 
env in envs]).mean() + train_acc = torch.stack([env['acc'] for env in envs]).mean() + + return train_loss.detach().cpu().numpy(), train_acc.detach().cpu().numpy(), \ + [env['nll'].detach().cpu().numpy() for env in test_envs], \ + [env['acc'].detach().cpu().numpy() for env in test_envs] + +def validation2(model, envs, test_envs, lossf): + + with torch.no_grad(): + for env in envs + test_envs: + logits = model(env['images']) + + env['nll'] = lossf(logits, env['labels']) + env['acc'] = mean_accuracy(logits, env['labels']) + + test_worst_loss = torch.stack([env['nll'] for env in test_envs]).max() + test_worst_acc = torch.stack([env['acc'] for env in test_envs]).min() + train_loss = torch.stack([env['nll'] for env in envs]).mean() + train_acc = torch.stack([env['acc'] for env in envs]).mean() + + return train_loss.detach().cpu().numpy(), train_acc.detach().cpu().numpy(), \ + test_worst_loss.detach().cpu().numpy(),test_worst_acc.detach().cpu().numpy() + + + + +# from https://github.com/alinlab/LfF/blob/e66796ec117ea52d2e44176055b7ef7959680a1b/module/loss.py#L8 +class GeneralizedCELoss(nn.Module): + + def __init__(self, q=0.7): + super(GeneralizedCELoss, self).__init__() + self.q = q + + def forward(self, logits, targets): + p = F.softmax(logits, dim=1) + if np.isnan(p.mean().item()): + raise NameError('GCE_p') + Yg = torch.gather(p, 1, torch.unsqueeze(targets, 1)) + # modify gradient of cross entropy + loss_weight = (Yg.squeeze().detach()**self.q)*self.q + if np.isnan(Yg.mean().item()): + raise NameError('GCE_Yg') + + loss = F.cross_entropy(logits, targets, reduction='none') * loss_weight + + return loss + +# https://github.com/alinlab/LfF/blob/e66796ec117ea52d2e44176055b7ef7959680a1b/util.py#L33 +class EMA: + + def __init__(self, label, alpha=0.9): + self.label = label + self.alpha = alpha + self.parameter = torch.zeros(label.shape[0]) + self.updated = torch.zeros(label.shape[0]) + + def update(self, data, index): + self.parameter[index] = self.alpha * 
self.parameter[index] + (1-self.alpha*self.updated[index]) * data + self.updated[index] = 1 + + def max_loss(self, label): + label_index = np.where(self.label == label)[0] + return self.parameter[label_index].max() + diff --git a/DomainBed/model_selection.py b/DomainBed/model_selection.py new file mode 100644 index 0000000..71e40c5 --- /dev/null +++ b/DomainBed/model_selection.py @@ -0,0 +1,483 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved + +import itertools +from tkinter.filedialog import test +import numpy as np +from numpy.linalg import norm + +top_percentile = 0.9 +adjust_coe = 1 + +def get_kl_div(losses, preference): + pair_score = losses.dot(preference) + return pair_score + +def get_losses(record): + if "groupdro" in record['args']['output_dir'] and 'penalty' in record.keys(): + record['loss'] = record['loss'] + record.pop("penalty") + if 'nll' in record.keys(): + erm_loss = record['nll'] + if 'mu_rl' in record.keys(): + pass + if "vrex_penalty" in record.keys() and "IRM_penalty" in record.keys(): + losses = np.array([erm_loss,record["IRM_penalty"],record["vrex_penalty"]]) + elif "nvrex_penalty" in record.keys() and "nIRM_penalty" in record.keys(): + losses = np.array([erm_loss,record["nIRM_penalty"],record["nvrex_penalty"]]) + elif 'penalty' in record.keys(): + ood_loss = record['penalty'] + losses = np.array([erm_loss,ood_loss]) + else: + if 'disc_loss' in record.keys() and 'gen_loss' not in record.keys(): + losses = np.array([(2-record['disc_loss']) if record['disc_loss']<=2 else 3]) + losses = np.array([1e9]) + elif 'gen_loss' in record.keys(): + losses = np.array([1e9]) + if np.abs(record['gen_loss'])>=50: + losses = np.array([1e9]) + else: + losses = np.array([(1e9+record['gen_loss']) if record['gen_loss']>=-1e9 else 1e9]) + else: + losses = np.array([record['loss']]) + + return losses + +def get_pair_score(record, get_loss=False,preference_base=1e-6): + if "groupdro" in record['args']['output_dir']and 'penalty' in 
record.keys(): + record['loss'] = record['loss'] + record.pop("penalty") + if 'nll' in record.keys(): + erm_loss = record['nll'] + if 'mu_rl' in record.keys(): + pass + if "vrex_penalty" in record.keys() and "IRM_penalty" in record.keys(): + losses = np.array([erm_loss,record["IRM_penalty"],record["vrex_penalty"]]) + if record["IRM_penalty"] < 0: + losses[1] *=-adjust_coe + preference = np.array([preference_base,1e-2,1]) + elif "nvrex_penalty" in record.keys() and "nIRM_penalty" in record.keys(): + losses = np.array([erm_loss,record["nIRM_penalty"],record["nvrex_penalty"]]) + if record["nIRM_penalty"] < 0: + losses[1] *=-adjust_coe + preference = np.array([preference_base,1e-2,1]) + elif 'penalty' in record.keys(): + ood_loss = record['penalty'] + losses = np.array([erm_loss,ood_loss]) + if record["penalty"] < 0: + losses[1] *=-adjust_coe + preference = np.array([preference_base,1]) + else: + if 'disc_loss' in record.keys() and 'gen_loss' not in record.keys(): + losses = np.array([(2-record['disc_loss']) if record['disc_loss']<=2 else 3]) + elif 'gen_loss' in record.keys(): + if np.abs(record['gen_loss'])>=50: + losses = np.array([1e9]) + else: + losses = np.array([(1e9+record['gen_loss']) if record['gen_loss']>=-1e9 else 1e9]) + else: + losses = np.array([record['loss']]) + preference = np.array([1]) if len(losses)==1 else np.array([preference_base,1]) + + pair_score = get_kl_div(losses,preference) + if get_loss: + return -pair_score, losses + return -pair_score + +def get_test_records(records): + """Given records with a common test env, get the test records (i.e. 
the + records with *only* that single test env and no other test envs)""" + return records.filter(lambda r: len(r['args']['test_envs']) == 1) + +class SelectionMethod: + """Abstract class whose subclasses implement strategies for model + selection across hparams and timesteps.""" + + def __init__(self): + raise TypeError + + @classmethod + def run_acc(self, run_records): + """ + Given records from a run, return a {val_acc, test_acc} dict representing + the best val-acc and corresponding test-acc for that run. + """ + raise NotImplementedError + + @classmethod + def hparams_accs(self, records): + """ + Given all records from a single (dataset, algorithm, test env) pair, + return a sorted list of (run_acc, records) tuples. + """ + return (records.group('args.hparams_seed') + .map(lambda _, run_records: + ( + self.run_acc(run_records), + run_records + ) + ).filter(lambda x: x[0] is not None) + .sorted(key=lambda x: x[0]['val_acc'])[::-1] + ) + + @classmethod + def sweep_acc(self, records): + """ + Given all records from a single (dataset, algorithm, test env) pair, + return the mean test acc of the k runs with the top val accs. + """ + _hparams_accs = self.hparams_accs(records) + if len(_hparams_accs): + return _hparams_accs[0][0]['test_acc'] + else: + return None + +class OracleSelectionMethod(SelectionMethod): + """Like Selection method which picks argmax(test_out_acc) across all hparams + and checkpoints, but instead of taking the argmax over all + checkpoints, we pick the last checkpoint, i.e. 
no early stopping.""" + name = "test-domain validation set (oracle)" + + @classmethod + def run_acc(self, run_records): + run_records = run_records.filter(lambda r: + len(r['args']['test_envs']) == 1) + if not len(run_records): + return None + test_env = run_records[0]['args']['test_envs'][0] + test_out_acc_key = 'env{}_out_acc'.format(test_env) + test_in_acc_key = 'env{}_in_acc'.format(test_env) + chosen_record = run_records.sorted(lambda r: r['step'])[-1] + return { + 'val_acc': chosen_record[test_out_acc_key], + 'test_acc': chosen_record[test_in_acc_key] + } + +class IIDAccuracySelectionMethod(SelectionMethod): + """Picks argmax(mean(env_out_acc for env in train_envs))""" + name = "training-domain validation set" + + @classmethod + def _step_acc(self, record): + """Given a single record, return a {val_acc, test_acc} dict.""" + test_env = record['args']['test_envs'][0] + val_env_keys = [] + for i in itertools.count(): + if f'env{i}_out_acc' not in record: + break + if i != test_env: + val_env_keys.append(f'env{i}_out_acc') + test_in_acc_key = 'env{}_in_acc'.format(test_env) + return { + 'val_acc': np.mean([record[key] for key in val_env_keys]), + 'test_acc': record[test_in_acc_key] + } + + @classmethod + def run_acc(self, run_records): + test_records = get_test_records(run_records) + if not len(test_records): + return None + return test_records.map(self._step_acc).argmax('val_acc') + +class LeaveOneOutSelectionMethod(SelectionMethod): + """Picks (hparams, step) by leave-one-out cross validation.""" + name = "leave-one-domain-out cross-validation" + + @classmethod + def _step_acc(self, records): + """Return the {val_acc, test_acc} for a group of records corresponding + to a single step.""" + test_records = get_test_records(records) + if len(test_records) != 1: + return None + + test_env = test_records[0]['args']['test_envs'][0] + n_envs = 0 + for i in itertools.count(): + if f'env{i}_out_acc' not in records[0]: + break + n_envs += 1 + val_accs = np.zeros(n_envs) - 
1 + # it implicitly assumes there is a test env, and n-1 training env + # hence given n envs, it does the eval with all 2-test-env combinations + for r in records.filter(lambda r: len(r['args']['test_envs']) == 2): + val_env = (set(r['args']['test_envs']) - set([test_env])).pop() + val_accs[val_env] = r['env{}_in_acc'.format(val_env)] + + val_accs = list(val_accs[:test_env]) + list(val_accs[test_env+1:]) + if any([v==-1 for v in val_accs]): + return None + val_acc = np.sum(val_accs) / (n_envs-1) + return { + 'val_acc': val_acc, + 'test_acc': test_records[0]['env{}_in_acc'.format(test_env)] + } + + @classmethod + def run_acc(self, records): + step_accs = records.group('step').map(lambda step, step_records: + self._step_acc(step_records) + ).filter_not_none() + if len(step_accs): + return step_accs.argmax('val_acc') + else: + return None + +from domainbed.lib.query import Q +class PAIRIIDAccuracySelectionMethod(SelectionMethod): + """Model selection according to PAIR score from + Pareto Invariant Risk Minimization.""" + name = "pair training-domain validation set" + preference_base=1e-6 + + @classmethod + def _step_acc(self, record): + """Given a single record, return a {val_acc, test_acc} dict.""" + test_env = record['args']['test_envs'][0] + val_env_keys = [] + for i in itertools.count(): + if f'env{i}_out_acc' not in record: + break + if i != test_env: + val_env_keys.append(f'env{i}_out_acc') + test_in_acc_key = 'env{}_in_acc'.format(test_env) + + pair_score,losses = get_pair_score(record=record,get_loss=True,preference_base=self.preference_base) + return { + 'losses': losses, + 'pair_score': pair_score, + 'val_acc': np.mean([record[key] for key in val_env_keys]), + 'test_acc': record[test_in_acc_key] + } + + @classmethod + def run_acc(self, run_records): + """ + Given records from a run, return a {val_acc, test_acc} dict representing + the best val-acc and corresponding test-acc for that run. 
+ """ + + test_records = get_test_records(run_records) + if not len(test_records): + return None + num_records = len(test_records) + + test_records = test_records.map(self._step_acc) + # filter out worst top_percentile% records in val acc to avoid trivial case + # return test_records.argmax('val_acc') + train_accs = [r['val_acc'] for r in test_records] + train_acc_bar = (np.max(train_accs)-np.min(train_accs))*0.8+np.min(train_accs) + pair_scores = [r['pair_score'] for r in test_records] + pair_score_bar = (np.max(pair_scores)-np.min(pair_scores))*0.9+np.min(pair_scores) + + if "coloredmnist" in run_records[0]['args']['dataset'].lower()or ("irm" in run_records[0]['args']['output_dir']): + test_records = Q(test_records[-5:]) + else: + test_records = Q(test_records[-10:]) + + return test_records.argmax(lambda x: x['val_acc']*(-1 if x['pair_score']0: + tmp_records.append(r) + self.preference_base = 10**int(np.log10(np.mean([np.min([np.abs(get_losses(r)[-1]) for r in rr]) for rr in tmp_records]))-2) + records = (records.group('args.hparams_seed') + .map(lambda _, run_records: + ( + self.run_acc(run_records), + run_records + ) + ).filter(lambda x: x[0] is not None) + ) + + num_records = len(records) + # filter out worst top_percentile% records in val acc to avoid trivial case + train_accs = [r[0]['val_acc'] for r in records] + train_acc_bar = (np.max(train_accs)-np.min(train_accs))*0.5+np.min(train_accs) + pair_scores = [r[0]['pair_score'] for r in records] + pair_score_bar = (np.max(pair_scores)-np.min(pair_scores))*0.9+np.min(pair_scores) + if "dann" not in records[0][1][0]['args']['output_dir'] and "groupdro" not in records[0][1][0]['args']['output_dir']: + return records.sorted(key=lambda x: x[0]['pair_score']*(1e8 if x[0]['val_acc'] 1: + return None + test_env = record['args']['test_envs'][0] + test_out_acc_key = 'env{}_out_acc'.format(test_env) + test_in_acc_key = 'env{}_in_acc'.format(test_env) + train_accs = [] + for i in range(1,10): + if i == test_env or 
'env{}_out_acc'.format(i) not in record.keys(): + continue + train_accs.append(record['env{}_out_acc'.format(i)]) + pair_score,losses = get_pair_score(record=record,get_loss=True,preference_base=self.preference_base) + return { + 'losses': losses, + 'train_acc': np.mean(train_accs), + 'pair_score': pair_score, + 'val_acc': record[test_out_acc_key], + 'test_acc': record[test_in_acc_key] + } + + @classmethod + def run_acc(self, run_records): + """ + Given records from a run, return a {val_acc, test_acc} dict representing + the best val-acc and corresponding test-acc for that run. + """ + run_records = run_records.filter(lambda r: + len(r['args']['test_envs']) == 1) + if not len(run_records): + return None + + test_records = get_test_records(run_records) + if not len(test_records): + return None + num_records = len(test_records) + test_records = test_records.map(self._step_acc) + train_acc_bar = 0 + train_accs = [r['train_acc'] for r in test_records] + train_acc_bar = (np.max(train_accs)-np.min(train_accs))*0.1+np.min(train_accs) + + erm_bar = 1 + erm_losses = [r['losses'][0] for r in test_records] + erm_bar = (np.max(erm_losses)-np.min(erm_losses))*0.8+np.min(erm_losses) + + pair_scores = [r['pair_score'] for r in test_records] + pair_score_bar = (np.max(pair_scores)-np.min(pair_scores))*0.9+np.min(pair_scores) + + for r in test_records: + r['train_bar']=train_acc_bar + r['erm_bar']=erm_bar + r['pair_score_bar']=pair_score_bar + + if "dann" in run_records[0]['args']['output_dir']: + test_records = Q(test_records[-5:]) + else: + test_records = Q(test_records[-10:]) + + return test_records.argmax('pair_score') + return test_records[-1] + + @classmethod + def hparams_accs(self, records): + """ + Given all records from a single (dataset, algorithm, test env) pair, + return a sorted list of (run_acc, records) tuples. 
+ """ + tmp_records = [] + for r in records.group('args.hparams_seed'): + r = get_test_records(r[1]) + if len(r)>0: + tmp_records.append(r) + + self.preference_base = 10**int(np.log10(np.mean([np.min([np.abs(get_losses(r)[-1]) for r in rr]) for rr in tmp_records]))-2) + return (records.group('args.hparams_seed') + .map(lambda _, run_records: + ( + self.run_acc(run_records), + run_records + ) + ).filter(lambda x: x[0] is not None) + .sorted(key=lambda x: x[0]['val_acc'])[::-1] + ) + + @classmethod + def sweep_acc(self, records): + """ + Given all records from a single (dataset, algorithm, test env) pair, + return the mean test acc of the k runs with the top val accs. + """ + _hparams_accs = self.hparams_accs(records) + if len(_hparams_accs): + return _hparams_accs[0][0]['test_acc'] + else: + return None + +class PAIRLeaveOneOutSelectionMethod(SelectionMethod): + """Model selection according to PAIR score from + Pareto Invariant Risk Minimization.""" + name = "pair leave-one-domain-out cross-validation" + + @classmethod + def _step_acc(self, records): + """Return the {val_acc, test_acc} for a group of records corresponding + to a single step.""" + test_records = get_test_records(records) + if len(test_records) != 1: + return None + + test_env = test_records[0]['args']['test_envs'][0] + n_envs = 0 + for i in itertools.count(): + if f'env{i}_out_acc' not in records[0]: + break + n_envs += 1 + val_accs = np.zeros(n_envs) - 1 + pair_scores = np.zeros(n_envs) - 1 + # it implicitly assumes there is a test env, and n-1 training env + # hence given n envs, it does the eval with all 2-test-env combinations + for r in records.filter(lambda r: len(r['args']['test_envs']) == 2): + val_env = (set(r['args']['test_envs']) - set([test_env])).pop() + val_accs[val_env] = r['env{}_in_acc'.format(val_env)] + pair_scores[val_env] = get_pair_score(r) + + val_accs = list(val_accs[:test_env]) + list(val_accs[test_env+1:]) + if any([v==-1 for v in val_accs]): + return None + val_acc = 
class PAIR(Optimizer):
    r"""
    Implements Pareto Invariant Risk Minimization (PAIR) algorithm.
    It is proposed in the ICLR 2023 paper
    `Pareto Invariant Risk Minimization: Towards Mitigating the Optimization Dilemma in Out-of-Distribution Generalization`
    https://arxiv.org/abs/2206.07766 .

    The class wraps an inner optimizer. On every ``step`` it back-propagates
    each registered loss separately, asks a multi-objective balancer
    (``EPO`` or ``SEPO``) for one weight per loss, and lets the inner
    optimizer take a single step on the weighted sum of the losses.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining parameter groups
        optimizer (pytorch optim): inner optimizer
        balancer (str, optional): indicates which MOO solver to use ("EPO" or "SEPO", case-insensitive)
        preference (list[float], optional): preference of the objectives
        eps (float, optional): precision up to the preference (default: 1e-04)
        coe (float, optional): L2 regularization weight onto the yielded objective weights (default: 0)
        verbose (bool, optional): forwarded to the balancer (default: False)
    """

    def __init__(self, params, optimizer=required, balancer="EPO",
                 preference=[1e-8, 1 - 1e-8], eps=1e-4, coe=0, verbose=False):
        # The list default is kept for interface compatibility; it is only
        # read (copied into an array below) and never mutated.
        if eps < 0.0:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        pref = np.asarray(preference, dtype=float)
        if pref.ndim != 1 or pref.size == 0 or np.any(pref < 0.0) or not np.sum(pref) > 0:
            raise ValueError("Invalid preference: {}".format(preference))

        solver_name = balancer.lower()
        if solver_name not in ("epo", "sepo"):
            raise NotImplementedError("Not supported balancer: {}".format(balancer))

        self.optimizer = optimizer
        # Lists, tuples and arrays are all accepted and stored as a float array.
        self.preference = pref

        self.descent = 0        # number of steps in which the balancer reported a "dom" move
        self.losses = []
        self.params = params
        # step() consults this flag; it used to be read without ever being set.
        self.get_grad_sim = False
        self.grad_sim = None    # last value computed when get_grad_sim is enabled

        balancer_cls = EPO if solver_name == "epo" else SEPO
        self.balancer = balancer_cls(len(self.preference), self.preference,
                                     eps=eps, coe=coe, verbose=verbose)
        defaults = dict(balancer=balancer, preference=self.preference, eps=eps)
        super(PAIR, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(PAIR, self).__setstate__(state)

    def set_losses(self, losses):
        """Register the per-objective losses that the next ``step`` will balance."""
        self.losses = losses

    def _flat_grad(self, loss):
        """Return the gradient of ``loss`` w.r.t. all parameters as one flat tensor.

        Parameters that require grad but received none from this loss are
        zero-filled so that every loss yields a vector of the same length.
        """
        self.optimizer.zero_grad()
        loss.backward(retain_graph=True)
        chunks = []
        for group in self.param_groups:
            for param in group['params']:
                if param.grad is not None:
                    chunks.append(param.grad.detach().clone().flatten())
                elif param.requires_grad:
                    chunks.append(torch.zeros_like(param).flatten())
        return torch.cat(chunks)

    def step(self, closure=None):
        r"""Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the losses; when given, its result replaces the
                losses registered through ``set_losses``.

        Returns:
            ``(pair_loss, moo_losses, mu_rl, alpha)``: the weighted loss tensor,
            the loss values as a NumPy array, the balancer's non-uniformity
            value and the weights that were applied.

            When no losses are registered the inner optimizer is stepped as-is
            and the 3-tuple ``(-1, 233, alphas)`` is returned instead.
        """
        if len(self.losses) == 0:
            self.optimizer.step()
            alphas = np.zeros(len(self.preference))
            alphas[0] = 1
            # NOTE(review): this branch returns three values while the main
            # path returns four; kept unchanged for existing callers. -1 and
            # 233 are placeholders.
            return -1, 233, alphas

        losses = self.losses
        if closure is not None:
            losses = closure()

        # One flattened gradient per objective, stacked row-wise.
        G = torch.stack([self._flat_grad(cur_loss) for cur_loss in losses])
        if getattr(self, "get_grad_sim", False):
            self.grad_sim = get_grad_sim(G, losses, preference=self.preference, is_G=True)
        GG = G @ G.T
        moo_losses = np.array([cur_loss.item() for cur_loss in losses])

        mu_rl = 0
        reset_optimizer = False
        alpha = None
        try:
            # Calculate the alphas from the LP solver
            alpha, mu_rl, reset_optimizer = self.balancer.get_alpha(
                moo_losses, G=GG.cpu().numpy(), C=True, get_mu=True)
            if self.balancer.last_move == "dom":
                self.descent += 1
                print("dom")
        except Exception:
            # Deliberate best effort: report the solver failure and fall back
            # to the normalised preference below.
            print(traceback.format_exc())
            alpha = None
        if alpha is None:   # A patch for the issue in cvxpy
            alpha = self.preference / np.sum(self.preference)

        # Losses may arrive as a 1-D tensor or as a list of scalar tensors.
        loss_vec = losses if torch.is_tensor(losses) else torch.stack(list(losses))
        scales = torch.as_tensor(np.asarray(alpha), dtype=loss_vec.dtype,
                                 device=loss_vec.device)
        pair_loss = scales.dot(loss_vec)
        if reset_optimizer:
            # Only the first parameter group's learning rate is reduced.
            self.optimizer.param_groups[0]["lr"] /= 5
        self.optimizer.zero_grad()
        pair_loss.backward()
        self.optimizer.step()

        return pair_loss, moo_losses, mu_rl, alpha
class EPO(object):
    r"""
    The original EPO solver proposed in ICML2020
    https://proceedings.mlr.press/v119/mahapatra20a.html

    ``get_alpha`` takes the current loss values, the gradients (or their Gram
    matrix) and a preference vector, solves a small cvxpy program over the
    probability simplex and returns one weight per objective. It solves the
    "balance" problem while the non-uniformity ``mu_rl`` exceeds ``eps`` and
    the "dominance" (descent) problem otherwise.
    """

    def __init__(self, m, r, eps=1e-4, coe=0, verbose=False):
        """Build the cvxpy problems.

        Args:
            m (int): number of objectives
            r (array-like): preference vector; stored normalised to sum to 1
            eps (float, optional): threshold on ``mu_rl`` that switches phases
            coe (float, optional): weight of the L2 term tying the descent
                solution to the previously returned weights
            verbose (bool, optional): print intermediate quantities
        """
        # The source kept ``cp.GLPK`` as a commented-out alternative solver.
        self.solver = cp.GUROBI
        r = self._as_float_array(r)
        self.m = m
        self.r = r / np.sum(r)
        self.eps = eps
        self.last_move = None
        self.a = cp.Parameter(m)        # Adjustments
        self.C = cp.Parameter((m, m))   # C: Gradient inner products, G^T G
        self.Ca = cp.Parameter(m)       # d_bal^TG
        self.rhs = cp.Parameter(m)      # RHS of constraints for balancing

        self.alpha = cp.Variable(m)     # Variable to optimize
        # -1 everywhere marks "no previous solution yet".
        self.last_alpha = np.zeros_like(r) - 1
        # The regulariser must see the *current* last_alpha at solve time. A
        # plain NumPy array would be frozen into the problem as a constant.
        self._last_alpha_param = cp.Parameter(m, value=self.last_alpha.copy())
        self.coe = coe

        obj_bal = cp.Maximize(self.alpha @ self.Ca)   # objective for balance
        constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Simplex
                           self.C @ self.alpha >= self.rhs]
        self.prob_bal = cp.Problem(obj_bal, constraints_bal)  # LP balance

        obj_dom = cp.Maximize(
            cp.sum(self.alpha @ self.C)
            - self.coe * cp.sum_squares(self.alpha - self._last_alpha_param))  # obj for descent
        constraints_dom = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Restrict
                           # -neg(max(Ca)) equals min(max(Ca), 0)
                           self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)),
                           self.C @ self.alpha >= 0]
        self.prob_dom = cp.Problem(obj_dom, constraints_dom)  # LP dominance

        self.gamma = 0     # Stores the latest Optimum value of the LP problem
        self.mu_rl = 0     # Stores the latest non-uniformity

        self.verbose = verbose

    @staticmethod
    def _as_float_array(values):
        """Return ``values`` as a floating ndarray, keeping an existing float dtype."""
        arr = np.asarray(values)
        return arr if np.issubdtype(arr.dtype, np.floating) else arr.astype(float)

    def get_alpha(self, l, G, r=None, C=False, get_mu=False):
        """calculate weights for all objectives given the gradient information

        Args:
            l (ndarray): the values of objective losses
            G (ndarray): gradients of each objective loss w.r.t. params, or
                their inner products when ``C`` is True
            r (ndarray, optional): adopt this preference if specified
            C (bool, optional): True if the input gradients are inner products
            get_mu (bool, optional): return detailed information if True.

        Returns:
            alpha: the objective weights (``None`` if the solver produced no solution)
            mu_rl (optional): the non-uniformity of the preference-weighted losses
            reset_optimizer (optional): whether to reset the inner optimizer;
                this solver always reports False
        """
        r = self.r if r is None else r
        assert len(l) == len(G) == len(r) == self.m, "length != m"
        G = np.asarray(G)
        l_hat, rl, self.mu_rl, self.a.value = self.adjustments(l, r)
        reset_optimizer = False
        self.C.value = G if C else G @ G.T
        self.Ca.value = self.C.value @ self.a.value

        # A failed solve leaves alpha.value as None, which np.array() stores as
        # an object array; restart the history from the preference then.
        if self.last_alpha.dtype == object:
            self.last_alpha = np.array(r, dtype=float)
        self._last_alpha_param.value = np.asarray(self.last_alpha, dtype=float)

        if self.mu_rl > self.eps:
            J = self.Ca.value > 0
            J_star_idx = np.where(rl == np.max(rl))[0]
            self.rhs.value = self.Ca.value.copy()
            # it's equivalent to setting no constraints to objectives in J
            # as maximize alpha^TCa would trivially satisfy the non-negativity
            self.rhs.value[J] = -np.inf
            self.rhs.value[J_star_idx] = 0

            self.gamma = self.prob_bal.solve(solver=self.solver, verbose=False)
            self.last_move = "bal"

            if self.verbose:
                test_alpha = np.ones_like(self.a.value) / self.m
                print(self.last_alpha, self.C.value, self.Ca.value, self.rhs.value)
                print(self.gamma, test_alpha @ self.Ca.value, self.alpha.value @ self.C.value)
                print(self.gamma, self.coe * np.linalg.norm(self.alpha.value - self.last_alpha) ** 2)
        else:
            self.gamma = self.prob_dom.solve(solver=self.solver, verbose=False)
            self.last_move = "dom"
        # NOTE(review): indentation was lost in the source; the history update
        # is assumed to run after either phase.
        self.last_alpha = np.array(self.alpha.value)

        if get_mu:
            return self.alpha.value, self.mu_rl, reset_optimizer

        return self.alpha.value

    def mu(self, rl, normed=False):
        """Return ``sum(l_hat * log(l_hat * m))`` over entries above machine epsilon.

        Args:
            rl (array-like): non-negative values, one per objective
            normed (bool, optional): True if ``rl`` already sums to 1

        Raises:
            ValueError: if any entry of ``rl`` is negative
        """
        rl = self._as_float_array(rl)
        if np.any(rl < 0):
            raise ValueError(f"rl<0 \n rl={rl}")
        m = len(rl)
        l_hat = rl if normed else rl / rl.sum()
        eps = np.finfo(rl.dtype).eps
        # Near-zero entries are dropped so that log() stays finite.
        l_hat = l_hat[l_hat > eps]
        return np.sum(l_hat * np.log(l_hat * m))

    def adjustments(self, l, r=1):
        """Compute the balancing direction for losses ``l`` under preference ``r``.

        Returns:
            ``(l_hat, rl, mu_rl, a)``: the normalised weighted losses, the
            weighted losses ``r * l``, their non-uniformity, and the
            adjustment ``r * (log(l_hat * m) - mu_rl)``.
        """
        l = self._as_float_array(l)
        div_r = np.array(r)
        m = len(l)
        rl = div_r * l

        l_hat = rl / rl.sum()
        mu_rl = self.mu(l_hat, normed=True)
        uniformity_div = np.log(l_hat * m) - mu_rl
        a = div_r * uniformity_div

        if self.verbose:
            print(a, rl, div_r, uniformity_div, l_hat, a.dot(l))
        return l_hat, rl, mu_rl, a
class SEPO(object):
    r"""
    A smoothed variant of EPO, with two adjustments for unrobust OOD objectives:
    a) normalization: unrobust OOD objective can yield large loss values that dominate the solutions of the LP,
        hence we adopt the normalized OOD losses in the LP to resolve the issue
    b) regularization: solutions yielded by the LP can change sharply among steps, especially when switching descending phases
        hence we incorporate a L2 regularization in the LP to resolve the issue
    """

    # Below this non-uniformity the preference on the first objective is shrunk.
    _UNIFORM_MU = 0.1
    # Divisor and floor used when shrinking the first preference entry.
    _R0_SHRINK = 10000
    _R0_FLOOR = 1e-15

    def __init__(self, m, r, eps=1e-4, coe=0, verbose=False):
        """Build the cvxpy problems.

        Args:
            m (int): number of objectives
            r (array-like): preference vector; stored normalised to sum to 1
            eps (float, optional): threshold on ``mu_rl`` that switches phases
            coe (float, optional): weight of the L2 term tying a solution to
                the previously returned weights
            verbose (bool, optional): print intermediate quantities
        """
        # The source kept ``cp.GLPK`` as a commented-out alternative solver.
        self.solver = cp.GUROBI
        r = self._as_float_array(r)
        self.m = m
        self.r = r / np.sum(r)
        self.eps = eps
        self.last_move = None
        self.a = cp.Parameter(m)        # Adjustments
        self.C = cp.Parameter((m, m))   # C: Gradient inner products, G^T G
        self.Ca = cp.Parameter(m)       # d_bal^TG
        self.rhs = cp.Parameter(m)      # RHS of constraints for balancing

        self.alpha = cp.Variable(m)     # Variable to optimize
        # -1 everywhere marks "no previous solution yet".
        self.last_alpha = np.zeros_like(r) - 1
        # The regulariser must see the *current* last_alpha at solve time. A
        # plain NumPy array would be frozen into the problems as a constant.
        self._last_alpha_param = cp.Parameter(m, value=self.last_alpha.copy())
        self.coe = coe

        smooth = self.coe * cp.sum_squares(self.alpha - self._last_alpha_param)

        obj_bal = cp.Maximize(self.alpha @ self.Ca - smooth)   # objective for balance
        obj_bal_orig = cp.Maximize(self.alpha @ self.Ca)       # objective for balance, unregularised
        constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Simplex
                           self.C @ self.alpha >= self.rhs]
        self.prob_bal = cp.Problem(obj_bal, constraints_bal)            # LP balance
        self.prob_bal_orig = cp.Problem(obj_bal_orig, constraints_bal)  # LP balance

        obj_dom = cp.Maximize(cp.sum(self.alpha @ self.C) - smooth)  # obj for descent
        constraints_res = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Restrict
                           # -neg(max(Ca)) equals min(max(Ca), 0)
                           self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)),
                           self.C @ self.alpha >= 0]
        constraints_rel = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Relaxed
                           self.C @ self.alpha >= 0]
        self.prob_dom = cp.Problem(obj_dom, constraints_res)  # LP dominance
        self.prob_rel = cp.Problem(obj_dom, constraints_rel)  # LP dominance (not solved by get_alpha)

        self.gamma = 0     # Stores the latest Optimum value of the LP problem
        self.mu_rl = 0     # Stores the latest non-uniformity

        self.verbose = verbose

    @staticmethod
    def _as_float_array(values):
        """Return ``values`` as a floating ndarray, keeping an existing float dtype."""
        arr = np.asarray(values)
        return arr if np.issubdtype(arr.dtype, np.floating) else arr.astype(float)

    def get_alpha(self, l, G, r=None, C=False, get_mu=False):
        """calculate weights for all objectives given the gradient information

        Args:
            l (ndarray): the values of objective losses
            G (ndarray): gradients of each objective loss w.r.t. params, or
                their inner products when ``C`` is True
            r (ndarray, optional): adopt this preference if specified
            C (bool, optional): True if the input gradients are inner products
            get_mu (bool, optional): return detailed information if True.

        Returns:
            alpha: the objective weights (``None`` if the solver produced no solution)
            mu_rl (optional): the non-uniformity of the preference-weighted losses
            reset_optimizer (optional): whether to reset the inner optimizer;
                this solver always reports False
        """
        use_stored_r = r is None
        r = self.r if use_stored_r else r
        assert len(l) == len(G) == len(r) == self.m, "length != m"
        G = np.asarray(G)
        l_hat, rl, self.mu_rl, self.a.value = self.adjustments(l, r)
        reset_optimizer = False
        if self.mu_rl <= self._UNIFORM_MU:
            # Losses are already close to the preference ray: shrink the
            # weight on the first objective and renormalise the preference.
            self.r[0] = max(self._R0_FLOOR, self.r[0] / self._R0_SHRINK)
            self.r = self.r / self.r.sum()
            print(f"pua preference {self.r}")
            if use_stored_r:
                # Recompute with the renormalised preference, not the stale
                # local copy that was taken before the update.
                r = self.r
                l_hat, rl, self.mu_rl, self.a.value = self.adjustments(l, r)

        # NOTE(review): a zero adjustment vector or a zero row in G makes the
        # normalisations below divide by zero.
        a_norm = np.linalg.norm(self.a.value)
        G_norm = np.linalg.norm(G, axis=1)
        G_unit = G / np.expand_dims(G_norm, axis=1)   # rows scaled to unit length
        Ga = G.T @ self.a.value
        self.C.value = G if C else (G_unit @ G.T) / a_norm
        # NOTE(review): with C=True, G is already a Gram matrix, yet Ca is
        # still built from the row-normalised G times G.T @ a. Confirm intent.
        self.Ca.value = (G_unit @ Ga.T) / a_norm

        # A failed solve leaves alpha.value as None, which np.array() stores as
        # an object array; restart the history from the preference then.
        if self.last_alpha.dtype == object:
            self.last_alpha = np.array(r, dtype=float)
        self._last_alpha_param.value = np.asarray(self.last_alpha, dtype=float)

        if self.mu_rl > self.eps:
            J = self.Ca.value > 0

            J_star_idx = np.where(rl == np.max(rl))[0]
            self.rhs.value = self.Ca.value.copy()
            # it's equivalent to setting no constraints to objectives in J
            # as maximize alpha^TCa would trivially satisfy the non-negativity
            self.rhs.value[J] = -np.inf     # Not efficient; but works.
            # Element-wise maximum so that ties in rl are handled as well.
            self.rhs.value[J_star_idx] = np.maximum(0, self.Ca.value[J_star_idx] / 2)

            if self.last_alpha.sum() < 0:
                # No previous solution yet: solve without the smoothing term.
                self.gamma = self.prob_bal_orig.solve(solver=self.solver, verbose=False)
            else:
                self.gamma = self.prob_bal.solve(solver=self.solver, verbose=False)

            self.last_move = "bal"

            if self.verbose:
                test_alpha = np.ones_like(self.a.value) / self.m
                print(self.last_alpha, self.C.value, self.Ca.value, self.rhs.value)
                print(self.gamma, test_alpha @ self.Ca.value, self.alpha.value @ self.C.value)
                print(self.gamma, self.coe * np.linalg.norm(self.alpha.value - self.last_alpha) ** 2)
        else:
            self.gamma = self.prob_dom.solve(solver=self.solver, verbose=False)
            self.last_move = "dom"
        # NOTE(review): indentation was lost in the source; the history update
        # is assumed to run after either phase.
        self.last_alpha = np.array(self.alpha.value)

        if get_mu:
            return self.alpha.value, self.mu_rl, reset_optimizer

        return self.alpha.value

    def mu(self, rl, normed=False):
        """Return ``sum(l_hat * log(l_hat * m))`` over entries above machine epsilon.

        Args:
            rl (array-like): non-negative values, one per objective
            normed (bool, optional): True if ``rl`` already sums to 1

        Raises:
            ValueError: if any entry of ``rl`` is negative
        """
        rl = self._as_float_array(rl)
        if np.any(rl < 0):
            raise ValueError(f"rl<0 \n rl={rl}")
        m = len(rl)
        l_hat = rl if normed else rl / rl.sum()
        eps = np.finfo(rl.dtype).eps
        # Near-zero entries are dropped so that log() stays finite.
        l_hat = l_hat[l_hat > eps]
        return np.sum(l_hat * np.log(l_hat * m))

    def adjustments(self, l, r=1):
        """Compute the balancing direction for losses ``l`` under preference ``r``.

        Returns:
            ``(l_hat, rl, mu_rl, a)``: the normalised weighted losses, the
            weighted losses ``r * l``, their non-uniformity, and the
            adjustment ``r * (log(l_hat * m) - mu_rl)``.
        """
        l = self._as_float_array(l)
        div_r = np.array(r)
        m = len(l)
        rl = div_r * l

        l_hat = rl / rl.sum()
        mu_rl = self.mu(l_hat, normed=True)
        uniformity_div = np.log(l_hat * m) - mu_rl
        a = div_r * uniformity_div

        if self.verbose:
            print(a, rl, div_r, uniformity_div, l_hat, a.dot(l))
        return l_hat, rl, mu_rl, a
def getNumParams(params):
    """Count the elements of all parameters and of those that require grad.

    Args:
        params (iterable): objects exposing ``.data.shape`` and ``.requires_grad``

    Returns:
        ``(numParams, numTrainable)`` as plain ints.
    """
    numParams, numTrainable = 0, 0
    for param in params:
        count = int(np.prod(param.data.shape))
        numParams += count
        if param.requires_grad:
            numTrainable += count
    return numParams, numTrainable


def get_kl_div(losses, preference):
    """Return the preference-weighted sum (dot product) of ``losses``.

    Despite its name, the function computes ``losses . preference``.
    """
    if not hasattr(losses, "dot"):
        losses = np.asarray(losses, dtype=float)
    return losses.dot(preference)


def pair_selection(losses, val_accs, test_accs, anneal_iter=0, val_acc_bar=-1, pood=None):
    """Select the step with the lowest preference-weighted loss among steps
    whose validation accuracy reaches ``val_acc_bar``.

    Args:
        losses (array-like): one row of objective losses per step
        val_accs (array-like): validation accuracy per step
        test_accs (sequence): test accuracy per step
        anneal_iter (int, optional): number of leading steps to ignore
        val_acc_bar (float, optional): accuracy threshold; when negative it is
            set to the minimum plus 5% of the range of the considered ``val_accs``
        pood (array-like, optional): preference that overrides the built-in one

    Returns:
        ``(index, val_acc, test_acc)`` where ``index`` refers to the original,
        un-sliced sequences.
    """
    losses = np.asarray(losses, dtype=float)[anneal_iter:]
    val_accs = np.asarray(val_accs)[anneal_iter:]
    test_accs = test_accs[anneal_iter:]
    if val_acc_bar < 0:
        lowest = np.min(val_accs)
        val_acc_bar = (np.max(val_accs) - lowest) * 0.05 + lowest

    try:
        # Two orders of magnitude below the mean of the last objective, floored at 1e-12.
        preference_base = 10.0 ** max(-12, int(np.log10(np.mean(losses[:, -1])) - 2))
    except (ValueError, OverflowError, IndexError, TypeError) as e:
        # e.g. a zero or NaN mean: fall back to the smallest base
        print(e)
        preference_base = 1e-12

    num_objectives = len(losses[0])
    if num_objectives == 2:
        preference = np.array([preference_base, 1])
    elif num_objectives == 4:
        preference = np.array([1e-12, 1e-4, 1e-2, 1])
    elif num_objectives == 5:
        preference = np.array([1e-12, 1e-6, 1e-4, 1e-2, 1])
    else:
        preference = np.array([1e-12, 1e-2, 1])

    if pood is not None:
        preference = pood
    print(f"Use preference: {preference}, validation acc bar: {val_acc_bar}")

    # Steps below the accuracy bar get a large sentinel score so they are never chosen first.
    pair_score = np.array([
        np.dot(step_losses, preference) if acc >= val_acc_bar else 1e9
        for acc, step_losses in zip(val_accs, losses)
    ])
    sel_idx = int(np.argmin(pair_score))
    return sel_idx + anneal_iter, val_accs[sel_idx], test_accs[sel_idx]


def get_grad_sim(params, losses, preference=None, is_G=False, cosine=True):
    """Weighted similarity between the first loss's gradient and the others'.

    Args:
        params: model parameters, or the stacked per-loss gradient matrix when
            ``is_G`` is True
        losses: the losses (only their count is used when ``is_G`` is True)
        preference (array-like, optional): weights; entries after the first
            are normalised and applied to the remaining losses. Uniform
            weights are used when omitted.
        is_G (bool, optional): True if ``params`` already holds the gradients
        cosine (bool, optional): L2-normalise each gradient row first

    Returns:
        float: the weighted sum of inner products between row 0 and rows 1..
    """
    num_ood_losses = len(losses) - 1
    if is_G:
        G = params
    else:
        # Materialise once: a generator would be exhausted by SGD() below.
        params = list(params)
        pesudo_opt = SGD(params, lr=1e-6)   # only used for zero_grad()
        grads = []
        for cur_loss in losses:
            pesudo_opt.zero_grad()
            cur_loss.backward(retain_graph=True)
            cur_grad = []
            for param in params:
                if param.grad is not None:
                    cur_grad.append(param.grad.detach().clone().flatten())
                elif param.requires_grad:
                    # keep every row the same length when a loss skips a parameter
                    cur_grad.append(torch.zeros_like(param).flatten())
            grads.append(torch.cat(cur_grad))
        G = torch.stack(grads)
    if cosine:
        G = F.normalize(G, dim=1)
    GG = (G @ G.T).detach().cpu()
    if preference is not None:
        ood_pref = np.asarray(preference, dtype=float)[1:]
        G_weights = ood_pref / np.sum(ood_pref)
    else:
        G_weights = np.ones(num_ood_losses) / num_ood_losses
    grad_sim = G_weights.dot(GG[0, 1:].numpy())
    return grad_sim.item()
"import torch as pt\n", + "from torch import nn\n", + "import matplotlib.pyplot as plt\n", + "import matplotlib.patches as patches\n", + "import matplotlib.cm as cm\n", + "from torch.utils.data import TensorDataset, DataLoader\n", + "import sys\n", + "import torch\n", + "import argparse\n", + "from tqdm import tqdm\n", + "from torch.autograd import Variable\n", + "from pair import PAIR\n", + "import os\n", + "\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = '1'\n", + "\n", + "# prepare data generating hyper-parameters\n", + "seed = 2022\n", + "sample_size = [5000, 1000]\n", + "batch_size = 512\n", + "coffe = 2\n", + "n_epoch = 10000\n", + "anneal_epoch = 150\n", + "algorithm = \"erm\"\n", + "penalty_weight = 1\n", + "opt = \"Adam\"\n", + "sampling = \"gaussian\"\n", + "is_uniform = False\n", + "if is_uniform:\n", + " x1_l = -3\n", + " x1_r = 1\n", + " y1_l = -3\n", + " y1_r = -2\n", + "\n", + " x2_l = -1\n", + " x2_r = 3\n", + " y2_l = 2\n", + " y2_r = 3\n", + "else:\n", + " mean1 = (-0.9, -2.2)\n", + " cov1 = [[0.9, 0.11], [0.11, 0.1]]\n", + " mean2 = (1, 2)\n", + " cov2 = [[1, -0.3], [-0.3, 0.1]]\n", + "\n", + "# Choosing and saving a random seed for reproducibility\n", + "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n", + "if seed == -1:\n", + " seed = int(torch.randint(0, 2 ** 32 - 1, (1,)).item())\n", + "torch.manual_seed(seed)\n", + "np.random.seed(seed)\n", + "torch.cuda.manual_seed_all(seed)\n", + "torch.manual_seed(seed)\n", + "torch.backends.cudnn.deterministic = True\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "xFuvbvXQA_yI", + "outputId": "e51a217e-21a3-4f01-81ff-9354fce46901" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "def f_eval(x):\n", + " x = np.asarray(x, dtype=np.float32)\n", + " x = np.delete(x, 1, 1)\n", + " return np.sin(coffe * x) + 1\n", + "\n", + "def f_eval2(x):\n", + " x = np.asarray(x, dtype=np.float32)\n", + " return np.sin(coffe * x[0]) + 1\n", + "\n", + "\n", + "\n", + "class MLP(nn.Module):\n", + " def __init__(self):\n", + " super(MLP, self).__init__()\n", + " self.lin1 = nn.Linear(2, 128)\n", + " self.lin2 = nn.Linear(128, 128)\n", + " self.lin3 = nn.Linear(128, 1)\n", + " for lin in [self.lin1, self.lin2, self.lin3]:\n", + " nn.init.xavier_uniform_(lin.weight)\n", + " nn.init.zeros_(lin.bias)\n", + " self._main = nn.Sequential(self.lin1, nn.ReLU(True), self.lin2, nn.ReLU(True), self.lin3)\n", + " self.optimizer= pt.optim.Adam(self.parameters(), lr=1e-3)\n", + "\n", + " self.eval()\n", + " r2, r = 1e4, 1e-12\n", + " self.preference = np.array([r]*1+[(1-1*r-r2*r),r2*r])\n", + " self.eps =1e-1\n", + " self.n_tasks = 3\n", + " self.pair_optimizer = PAIR(self.parameters(),self.optimizer,preference=self.preference,eps=self.eps)\n", + " self.descent = 0\n", + "\n", + " def reset_parameters(self):\n", + " self.lin1.reset_parameters()\n", + " self.lin2.reset_parameters()\n", + " self.lin3.reset_parameters()\n", + " self.optimizer= pt.optim.Adam(self.parameters(), lr=1e-3)\n", + "\n", + " def update_preference(self,preference):\n", + " self.preference=preference\n", + " self.pair_optimizer = PAIR(self.parameters(),self.optimizer,preference=self.preference,eps=self.eps)\n", + " self.descent = 0\n", + "\n", + " def forward(self, x, to_numpy=False):\n", + " x = pt.as_tensor(x, dtype=pt.float32).to(device)\n", + " out = self._main(x)\n", + " if to_numpy:\n", + " out = out.to('cpu').detach().numpy()\n", + " return out\n", + " \n", + " def train_step(self, x, y):\n", + " self.train()\n", + " x = pt.as_tensor(x, dtype=pt.float32).to(device)\n", + " y = 
pt.as_tensor(y, dtype=pt.float32).to(device)\n", + " \n", + " self.optimizer.zero_grad()\n", + " y_pred = self.forward(x)\n", + " loss = nn.functional.mse_loss(y_pred, y)\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " self.eval()\n", + " return loss.item()\n", + " \n", + " def train_step_ood(self, x1, y1, x2, y2, apply_ood_obj=False, penalty_weight=1e4):\n", + " self.train()\n", + " x1 = pt.as_tensor(x1, dtype=pt.float32).to(device)\n", + " y1 = pt.as_tensor(y1, dtype=pt.float32).to(device)\n", + " x2 = pt.as_tensor(x2, dtype=pt.float32).to(device)\n", + " y2 = pt.as_tensor(y2, dtype=pt.float32).to(device)\n", + "\n", + " self.optimizer.zero_grad()\n", + " y1_pred = self.forward(x1)\n", + " loss1 = nn.functional.mse_loss(y1_pred, y1)\n", + " y2_pred = self.forward(x2)\n", + " loss2 = nn.functional.mse_loss(y2_pred, y2)\n", + "\n", + " import torch.autograd as autograd\n", + "\n", + " scale = torch.tensor(1.).to(device).requires_grad_()\n", + " losses1 = nn.functional.mse_loss(y1_pred*scale, y1,reduction='none')\n", + " losses2 = nn.functional.mse_loss(y2_pred*scale, y2,reduction='none')\n", + " grad_1 = autograd.grad(losses1.mean(), [scale], create_graph=True)[0]\n", + " grad_2 = autograd.grad(losses2.mean(), [scale], create_graph=True)[0]\n", + " irm_penalty = pt.stack([pt.sum(grad_1**2), pt.sum(grad_2**2)]).mean()\n", + " vrex_penalty = pt.stack([loss1, loss2]).var()\n", + "\n", + " erm_loss=pt.stack([loss1, loss2]).mean()\n", + "\n", + " if algorithm.lower() == 'vrex':\n", + " penalty = vrex_penalty\n", + " elif algorithm.lower() == 'pair':\n", + " penalty = vrex_penalty\n", + " else:\n", + " penalty = irm_penalty\n", + " losses = torch.stack([erm_loss,irm_penalty,vrex_penalty]).to(device)\n", + " if apply_ood_obj and algorithm.lower() == 'pair':\n", + " self.pair_optimizer.zero_grad()\n", + " self.pair_optimizer.set_losses(losses=losses)\n", + " loss, moo_losses, mu_rl, alphas = self.pair_optimizer.step()\n", + " else:\n", + " loss = 
erm_loss+penalty*(penalty_weight if apply_ood_obj else 1)\n", + " if apply_ood_obj and penalty_weight>0:\n", + " loss /= penalty_weight\n", + "\n", + " loss.backward()\n", + " self.optimizer.step()\n", + " loss = loss.item()\n", + " self.eval()\n", + " return loss, losses\n", + " \n", + " def valid_loss(self, x, y):\n", + " self.eval()\n", + " with pt.no_grad():\n", + " x = pt.as_tensor(x, dtype=pt.float32).to(device)\n", + " y = pt.as_tensor(y, dtype=pt.float32).to(device)\n", + "\n", + " y_pred = self.forward(x)\n", + " loss = nn.functional.mse_loss(y_pred, y)\n", + " self.train()\n", + " return y_pred.detach(), loss.item()\n", + "\n", + "def prepare_data(gen):\n", + " train_x = []\n", + " train_y = []\n", + " if is_uniform:\n", + " t_x1 = gen.uniform(x1_l, x1_r, [int(sample_size[0]/2), 1])\n", + " t_y1 = gen.uniform(y1_l, y1_r, [int(sample_size[0]/2), 1])\n", + " train_x.append(np.hstack((t_x1, t_y1)))\n", + " t_x2 = gen.uniform(x2_l, x2_r, [int(sample_size[0]/2), 1])\n", + " t_y2 = gen.uniform(y2_l, y2_r, [int(sample_size[0]/2), 1])\n", + " train_x.append(np.hstack((t_x2, t_y2)))\n", + " else:\n", + " t_xy1 = gen.multivariate_normal(mean1, cov1, size=int(sample_size[0]/2))\n", + " train_x.append(t_xy1)\n", + " t_xy2 = gen.multivariate_normal(mean2, cov2, size=int(sample_size[0]/2))\n", + " train_x.append(t_xy2)\n", + " valid_x = gen.uniform(-4, 4, [sample_size[1], 2])\n", + " train_y.append(f_eval(train_x[0]))\n", + " train_y.append(f_eval(train_x[1]))\n", + " valid_y = f_eval(valid_x)\n", + " return train_x, valid_x, train_y, valid_y\n", + "\n", + "def print_truth():\n", + " grid_c = 200\n", + " x = np.linspace(-4, 4, grid_c)\n", + " y = np.linspace(-4, 4, grid_c)\n", + " xy = np.meshgrid(x, y)\n", + " f = f_eval2(xy)\n", + " fig, ax = plt.subplots()\n", + " ax.set_title(\"ground truth\")\n", + " ax.set_xlabel(\"$x_1$\")\n", + " ax.set_ylabel(\"$x_2$\")\n", + "\n", + " im = ax.imshow(np.array(f), cmap = cm.jet, vmin=0., vmax=2., origin = 'lower', extent=[-4, 
4, -4, 4])\n", + " if is_uniform:\n", + " eps = 0.05\n", + " rect1 = patches.Rectangle((x1_l+eps,y1_l+eps),(x1_r-x1_l-2*eps),(y1_r-y1_l-2*eps),linewidth=1,edgecolor='r',facecolor='none')\n", + " rect2 = patches.Rectangle((x2_l+eps,y2_l+eps),(x2_r-x2_l-2*eps),(y2_r-y2_l-2*eps),linewidth=1,edgecolor='r',facecolor='none')\n", + " ax.add_patch(rect1)\n", + " ax.add_patch(rect2)\n", + " else:\n", + " w, v = np.linalg.eig(cov1)\n", + " theta = np.sign(cov1[0][1]) * np.degrees(np.arccos(v[0,0]))\n", + " ellip1 = patches.Ellipse(mean1, 6*np.sqrt(np.abs(w[0])),6*np.sqrt(np.abs(w[1])), angle=theta,linewidth=1,edgecolor='r',facecolor='none')\n", + " w, v = np.linalg.eig(cov2)\n", + " theta = np.sign(cov2[0][1]) * np.degrees(np.arccos(v[0,0]))\n", + " ellip2 = patches.Ellipse(mean2, 6*np.sqrt(np.abs(w[0])),6*np.sqrt(np.abs(w[1])), angle=theta,linewidth=1,edgecolor='r',facecolor='none')\n", + " ax.add_patch(ellip1)\n", + " ax.add_patch(ellip2)\n", + " # plt.tight_layout()\n", + " fig.colorbar(im, shrink=0.4, aspect=9)\n", + " plt.show()\n", + " # plt.savefig(f'extrapolate_truth_s{sampling}.png')\n", + "\n", + "def print_samples(train_x):\n", + " fig, ax = plt.subplots()\n", + " ax.set_title(\"training samples\")\n", + " plt.plot(np.vstack(train_x)[:, 0], np.vstack(train_x)[:, 1], '.', alpha=0.5)\n", + " if is_uniform:\n", + " eps = 0.05\n", + " rect1 = patches.Rectangle((x1_l+eps,y1_l+eps),(x1_r-x1_l-2*eps),(y1_r-y1_l-2*eps),linewidth=1,edgecolor='r',facecolor='none')\n", + " rect2 = patches.Rectangle((x2_l+eps,y2_l+eps),(x2_r-x2_l-2*eps),(y2_r-y2_l-2*eps),linewidth=1,edgecolor='r',facecolor='none')\n", + " ax.add_patch(rect1)\n", + " ax.add_patch(rect2)\n", + " else:\n", + " w, v = np.linalg.eig(cov1)\n", + " theta = np.sign(cov1[0][1]) * np.degrees(np.arccos(v[0,0]))\n", + " ellip1 = patches.Ellipse(mean1, 6*np.sqrt(np.abs(w[0])),6*np.sqrt(np.abs(w[1])), angle=theta,linewidth=1,edgecolor='r',facecolor='none')\n", + " w, v = np.linalg.eig(cov2)\n", + " theta = 
np.sign(cov2[0][1]) * np.degrees(np.arccos(v[0,0]))\n", + " ellip2 = patches.Ellipse(mean2, 6*np.sqrt(np.abs(w[0])),6*np.sqrt(np.abs(w[1])), angle=theta,linewidth=1,edgecolor='r',facecolor='none')\n", + " ax.add_patch(ellip1)\n", + " ax.add_patch(ellip2)\n", + " plt.show()\n", + "\n", + "@torch.no_grad()\n", + "def print_predict(model, valid_loss, exp_name=None):\n", + " model.eval()\n", + " grid_c = 200\n", + " x = np.linspace(-4, 4, grid_c)\n", + " y = np.linspace(-4, 4, grid_c)\n", + " f = []\n", + " for _y in y:\n", + " f_r = []\n", + " for _x in x:\n", + " f_r.append(model.forward(x=[_x,_y], to_numpy=True))\n", + " f.append(f_r)\n", + " fig, ax = plt.subplots()\n", + " if exp_name != None:\n", + " ax.set_title(f\"{exp_name} {valid_loss}\")\n", + " else:\n", + " ax.set_title(f\"valid_loss {valid_loss}\")\n", + "\n", + " im = ax.imshow(np.array(f).squeeze(-1), cmap = cm.jet, vmin=0., vmax=2., origin = 'lower', extent=[-4, 4, -4, 4])\n", + " if is_uniform:\n", + " eps = 0.05\n", + " rect1 = patches.Rectangle((x1_l+eps,y1_l+eps),(x1_r-x1_l-2*eps),(y1_r-y1_l-2*eps),linewidth=1,edgecolor='r',facecolor='none')\n", + " rect2 = patches.Rectangle((x2_l+eps,y2_l+eps),(x2_r-x2_l-2*eps),(y2_r-y2_l-2*eps),linewidth=1,edgecolor='r',facecolor='none')\n", + " ax.add_patch(rect1)\n", + " ax.add_patch(rect2)\n", + " else:\n", + " w, v = np.linalg.eig(cov1)\n", + " theta = np.sign(cov1[0][1]) * np.degrees(np.arccos(v[0,0]))\n", + " ellip1 = patches.Ellipse(mean1, 6*np.sqrt(np.abs(w[0])),6*np.sqrt(np.abs(w[1])), angle=theta,linewidth=1,edgecolor='r',facecolor='none')\n", + " w, v = np.linalg.eig(cov2)\n", + " theta = np.sign(cov2[0][1]) * np.degrees(np.arccos(v[0,0]))\n", + " ellip2 = patches.Ellipse(mean2, 6*np.sqrt(np.abs(w[0])),6*np.sqrt(np.abs(w[1])), angle=theta,linewidth=1,edgecolor='r',facecolor='none')\n", + " ax.add_patch(ellip1)\n", + " ax.add_patch(ellip2)\n", + " # plt.tight_layout()\n", + " fig.colorbar(im, shrink=0.4, aspect=9)\n", + " plt.show()\n", + " # 
plt.savefig(f'extrapolate_{algorithm}_s{sampling}_{opt}_p{penalty_weight}.png')\n", + " model.train()\n", + "\n", + "@torch.no_grad()\n", + "def print_predict3D(model, valid_loss, exp_name=None):\n", + " model.eval()\n", + " grid_c = 200\n", + " x = np.linspace(-4, 4, grid_c)\n", + " y = np.linspace(-4, 4, grid_c)\n", + " f = []\n", + " for _y in y:\n", + " f_r = []\n", + " for _x in x:\n", + " f_r.append(model.forward(x=[_x,_y], to_numpy=True)[0])\n", + " f.append(f_r)\n", + " fig = plt.figure()\n", + " ax = plt.axes(projection='3d')\n", + " if exp_name != None:\n", + " ax.set_title(f\"{exp_name} {valid_loss}\")\n", + " else:\n", + " ax.set_title(f\"valid_loss {valid_loss}\")\n", + " X_grid, Y_grid = np.meshgrid(x, y)\n", + " f_grid = np.array(f)\n", + " print(f_grid.shape)\n", + " surf = ax.plot_surface(X_grid, Y_grid, np.array(f), cmap=cm.jet, vmin=0., vmax=2., edgecolors='None', antialiased=True, rcount=70, ccount=70)\n", + " fig.colorbar(surf, shrink=0.4, aspect=9)\n", + " ax.set_box_aspect((np.ptp(X_grid), np.ptp(Y_grid), np.ptp(f_grid)))\n", + " plt.clabel(surf)\n", + " plt.show()\n", + " model.train()\n", + "\n", + "rng = np.random.default_rng(seed)\n", + "train_x, valid_x, train_y, valid_y = prepare_data(rng)\n", + "model = MLP().cuda()\n", + "train_xx = np.vstack(train_x)\n", + "train_yy = np.vstack(train_y)\n", + "data_loader = DataLoader(TensorDataset(\n", + " pt.as_tensor(train_xx, dtype=pt.float32),\n", + " pt.as_tensor(train_yy, dtype=pt.float32)\n", + "), batch_size=batch_size, shuffle=True)\n", + "\n", + "data_loader1 = DataLoader(TensorDataset(\n", + " pt.as_tensor(train_x[0], dtype=pt.float32),\n", + " pt.as_tensor(train_y[0], dtype=pt.float32)\n", + "), batch_size=int(batch_size/2), shuffle=True)\n", + "data_loader2 = DataLoader(TensorDataset(\n", + " pt.as_tensor(train_x[1], dtype=pt.float32),\n", + " pt.as_tensor(train_y[1], dtype=pt.float32)\n", + "), batch_size=int(batch_size/2), shuffle=True)\n", + "\n", + "print_truth()" + ] + }, + { + 
"cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 281 + }, + "id": "c5uGpIPG6oNd", + "outputId": "8a2f1d7f-ae52-40d9-faa3-d1971497bab6" + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "print_samples(train_x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TjBlN98I7uJL" + }, + "source": [ + "# ERM" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 307 + }, + "id": "CuFAW2bkBTwy", + "outputId": "8c7b5f0c-2b81-455e-a55c-4297b2380d03" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[10000|10000] train_loss=3.57042e-02, valid_loss=8.02226e-01: 100%|██████████| 10000/10000 [11:01<00:00, 15.12it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(200, 200)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "\n", + "def train_ERM():\n", + " pbar = tqdm(range(n_epoch))\n", + " for ep in pbar:\n", + " train_loss = 0\n", + " for sub_x, sub_y in data_loader:\n", + " sub_x=sub_x.to(device)\n", + " sub_y=sub_y.to(device)\n", + " loss = model.train_step(sub_x, sub_y)\n", + " train_loss += len(sub_y) * loss\n", + " train_loss /= len(train_y)\n", + "\n", + " y_pred, valid_loss = model.valid_loss(valid_x, valid_y)\n", + "\n", + " pbar.set_description(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}, valid_loss={valid_loss:0.5e}')\n", + " \n", + " exp_name = f\"{algorithm}_s{sampling}_{opt}_p{penalty_weight}\"\n", + " # print_predict(model,valid_loss,exp_name=exp_name)\n", + " print_predict3D(model,valid_loss,exp_name=exp_name)\n", + "train_ERM()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "X-838VHgBVoI" + }, + "outputs": [], + "source": [ + "def train_OOD():\n", + " losses = []\n", + " pbar = tqdm(range(n_epoch))\n", + " for ep in pbar:\n", + " train_loss = 0\n", + " cur_losses = [0]*3\n", + " if ep == anneal_epoch:\n", + " if algorithm=='pair':\n", + " if opt.lower() == 'adam':\n", + " model.optimizer = pt.optim.Adam(model.parameters(), lr=1e-3)\n", + " elif opt.lower() == 'sgd':\n", + " model.optimizer = pt.optim.SGD(model.parameters(), lr=2e-3, momentum=0.9)\n", + " else:\n", + " model.optimizer = pt.optim.Adam(model.parameters(), lr=1e-3)\n", + " for _, ((sub_x, sub_y), (sub_x2, sub_y2)) in enumerate(zip(data_loader1, data_loader2)):\n", + " if ep < anneal_epoch:\n", + " loss, sep_losses = model.train_step_ood(sub_x, sub_y, sub_x2, sub_y2)\n", + " else:\n", + " loss, sep_losses = model.train_step_ood(sub_x, sub_y, sub_x2, sub_y2, apply_ood_obj=True, penalty_weight=penalty_weight)\n", + " train_loss += len(sub_y) * loss\n", + " cur_losses = [l1+len(sub_y)*l2.item() for (l1,l2) in zip(cur_losses,sep_losses)]\n", + " 
train_loss /= len(train_y)\n", + " losses.append([l/len(train_y) for l in cur_losses])\n", + "\n", + " _, valid_loss = model.valid_loss(valid_x, valid_y)\n", + " losses[-1][0] = valid_loss\n", + "\n", + " pbar.set_description(f'[{ep+1}|{n_epoch}] train_loss={train_loss:0.5e}, valid_loss={valid_loss:0.5e}')\n", + " \n", + " \n", + " exp_name = f\"{algorithm}_s{sampling}_{opt}_p{penalty_weight}\"\n", + " print_predict(model,valid_loss,exp_name=exp_name)\n", + " # print_predict3D(model,valid_loss,exp_name=exp_name)\n", + " plt.close()\n", + " num_epochs = len(losses)\n", + " fig, ax1 = plt.subplots()\n", + " ax1.set_title(exp_name+f\" {valid_loss}\")\n", + " ax1.set_xlabel(\"epoch\")\n", + " ax1.set_ylabel(\"val loss\")\n", + " # heuristic approach to beautify the visualization\n", + " erm_vis_max = np.max([log_i[0] for log_i in losses[140:200]])+1e9\n", + " erm_pens = np.array([min(log_i[0],erm_vis_max) for log_i in losses])\n", + " ax1.plot(np.arange(num_epochs),erm_pens,label=f'erm_pen')\n", + " ax2 = ax1.twinx()\n", + " ax2.set_ylabel(\"penalty\")\n", + " if len(losses[0])>=3:\n", + " irm_pens = np.array([min(log_i[-2],1) for log_i in losses])\n", + " vrex_pens = np.array([min(log_i[-1],1) for log_i in losses])\n", + " ax2.plot(np.arange(num_epochs),irm_pens,label=f'irm_pen',c='r',alpha=0.2)\n", + " ax2.plot(np.arange(num_epochs),vrex_pens,label=f'vrex_pen',c='g',alpha=0.2)\n", + " else:\n", + " irm_pens = np.array([log_i[-1] for log_i in losses])\n", + " ax2.plot(np.arange(num_epochs),irm_pens,label=f'irm_pen',c='r',alpha=0.2)\n", + " plt.legend()\n", + " plt.show()\n", + " # plt.savefig(f\"{exp_name}_s{sampling}.png\")\n", + " plt.close()\n", + "\n", + "# if __name__ == '__main__':\n", + "# print_truth()\n", + "# if algorithm.lower() in ['vrex','irm','pair']:\n", + "# train_VREx()\n", + "# else:\n", + "# train_ERM()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NS2kBnBn7zJT" + }, + "source": [ + "# IRM" + ] + }, + { + "cell_type": "code", + 
"execution_count": 6, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 861 + }, + "id": "n87poM15BxWo", + "outputId": "8574a66c-8eac-4b58-8406-efaf089c378d" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[10000|10000] train_loss=4.81506e+00, valid_loss=1.01370e+00: 100%|██████████| 10000/10000 [16:49<00:00, 9.90it/s]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# IRM\n", + "algorithm = \"irm\"\n", + "penalty_weight = 1e-2\n", + "model = MLP().to(device)\n", + "train_OOD()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "U5Sya9VA71ml" + }, + "source": [ + "# VREX" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 861 + }, + "id": "jltOpblmB1CZ", + "outputId": "7a61b33e-e969-442b-9c91-3347a01fec60" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[10000|10000] train_loss=2.12006e-01, valid_loss=6.44612e-01: 100%|██████████| 10000/10000 [15:56<00:00, 10.46it/s]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAacAAAEWCAYAAADCeVhIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAABeF0lEQVR4nO2deXxU1dn4v0/2QAKBsO+7sgqCK9pi61KVqq17q1XrXrVqW99Xa6vW1r7dfrVUrYqte9Uq7tZWbStaF0SQfQ8QIEAgBBISQkKW5/fHuZPcmcxyk8xkZjLn+/nMZ+bee+655y5znvs85znPI6qKxWKxWCyJRFq8G2CxWCwWSyBWOFksFosl4bDCyWKxWCwJhxVOFovFYkk4rHCyWCwWS8JhhZPFYrFYEg4rnCwhEZFVIjIr3u3wISIjRERFJCPebbFYLLHFCidLSFR1oqrOj/VxRGSWI3T+N9bHijYi0ltEXhWRAyKyRUS+FabsSSLyvohUikixh7qnishiEalxvqdGKH+RiKxx2rJRRE4MUuYu51qfHOJcykTko4D13UTkTyKyx2n7h17PSUR+LiIrRKRBRO4J2PZjEal2fQ6KSJOI9HG2XyAinzjnPz9I3eki8gsR2SEiVSKyREQKnG2XOddsv4iUiMhv3C81AcetFpFGEXnAtf2rIrLWOfb7IjLc6/WKdK1F5GQR+cK5TyUicoFr21ecbftFZJOIXBOs7lTACqcwiMFeo9hzGbAX+E68G9IOHgIOAf2BbwMPi8jEEGUPAI8Dt0WqVESygNeBZ4FewFPA6876YOVPAX4NXAHkA18CNgWUGQ2cD+wMcdhfA2uCrJ8L9AbGO9+3tuGcioD/Af4euEFVf6mqeb6Pc/z5qrrHKbIX+APwqxB1/ww4HjgO6AFcCtQ627oBtwB9gGOArwI/ch3bfdwBwEHgJQBHOL4C/NQ530XA34IcP9T1CnmtRWQC8BxwJ9ATOAJY7GzLBF4FHnW2XQj8XkSOCHH+XRtVTakP8L/AvIB1c4A/Or/nA/cBH2Me2DHA4cB7mD/LOuACp+xoZ92RzvIgoAyYFaENl2M6jipgM/BtZ3068P+APc76GwEFMpztV2D+DFXO/tcG1PlRwHEUGOP8PgNY7ey7HfiRs74P8BZQ4ZzLf4E0Z1sxcLLz+2jgU6fcTuBBICvgWNcBG5wyDwHi4X50d9p0EaaTn+Halg78zrkem4Ab2nA9ZgElmI5xt9Pmc5zrsN451x97aN89wDxM51QFfAEc4Wr7IWCcq/wzwK8i1HkyUByhzKnOfRLXuq3A10KU/wS4MkKd/3TOv/m+urYd79zfK9zPEebZ3w/06Mg5YYTsPWG2i3MPLwuy7SqM0HKv6wVUA6M9/u9/ALwZYttlzrHFWb4G+CTgGT0IHB7pekW61hjB9PMQ7ejvPN/dXOs+By72co5d7ZOKWsELwBkikg/GNABcgHlofFyKeUDzMcLmPWd7P0wn+icRmaCqGzHC7lkR6QY8ATylYUxhItId+CNwuqrmYx7ypc7mq4HTganAkZjO1M1uYDbmLfEK4H4ROdLjef8F03nnA5OA/zjrf4jpxPti/hw/xvxBAmnEvDH3wbypfhX4XkCZ2cBRwBTMNT3NQ7u+ielkXgLewXQUPq526pwGzADOC9g30vUYAOQAg4G7gMeAS4DpwInAT0VkpIc2nu20rzfmOXjNecsdBzSo6npX2WVAKM2pLUwElqvTQzksD1a38wzPAPqKSJFjKnpQRHJdZc4H6lT17RD7P0jLy5Cbo4EtwM8cs94KETm3oycXhBMx/6+XPZafDDQA54lIqYisF5EbwpT/ErAqxLbLgKdd13oi5j4CoKoHgI3O+kjXK
+y1Bo51yqwQkZ0i8qyI9HaOswt4HrjCMVkeBwwHgpoNuzopJ5xUdQvm7fcbzqqvADWqusBV7ElVXaWqDcDXMG+ET6hqg6ouwfyBznfqewxjuvgMGIhR1yPRBEwSkVxV3amqvj/NBcAcVS1R1X0EmDNU9e+qulENHwDvYv7UXqgHJohID1Xdp6pfuNYPBIarar2q/jegQ/Qde7GqLnCuQTHG9PDlgGK/UtUKVd0KvI8RspG4DPibqjZiOv6LnI4fzPX4g6puU9W9wP8FtCnS9agH7lPVesxLSR/M9a1yrvlqjFklEotVdZ5Tz+8xAu9YIA+jVbipxLzUdJQ8py4vdfcHMjHC+0TMdZ8G/ATAeRH7JXBziGN9H/hMVRcH2TYE8zJTibEM3Ag8JSLj23AuXrgMY9Go9lh+CMb0NQ4YiTn3exzzph8i8l2M8P5dkG3DMc/xU67Vka59yOvl4VoPwbz8nguMBXKBB1zbn8e8SNVhrBh3quq2EHV1aVJOODk8B1zs/P4W/loTgPthGA4cIyIVvg9mbGGAq8xjmD/wA6paF+7AzlvYhRgT2E4R+buIHO5sHhRwbL+HUkROF5EFIrLXaccZmA7XC+c65beIyAfOWxnAbzHC9V1nAPb2YDuLyDgRect5S92P+QMGHrvU9bsG8ycPiYgMBU4C/uqseh3T8Z/pLAdejy0B+0e6HuWO0ANjlgHY5dp+MFIbHZrboKpNGE1zEEbj6xFQtgfG/NdR2lK379wecF529mCE6BnO+nuAZ5yXCj9EZBCmsw31UnUQI+R/oaqHnJeA9zFmx6jgWB3Ox19ARMJ3zveq6kFVXY5jFQmo+xzMS83p2jKW5eZSjFlus2tdyGvv4XrdQ4hr7Wr3E6q63hHEv/S12ekHXsCMvWZhNLX/EZEzQ9TVpUlV4fQSMEtEhmA0qEDh5NYctgEfqGqB65OnqtcDiEgeZtD2L5g3t96RDq6q76jqKRiNZS1GuIEZFxniKjrU90NEsjEa2++A/qpaALyNsdWDGZju5irvFp6o6ueqejbGdPIa8KKzvkpVf6iqo4CzgB+IyFeDNPthp61jVbUHxvwnQcq1hUsxz+CbIlKKsfvn0GLa24nrGgDDXOcX6XpEE/d9SMPcox2YsasMERnrKnsEoc1HbWEVMEVE3OczJVjdjpZdgv9z6/79VeD7zotFKeZ8XhTjHXk05jlc7WybAxztlE3HmBJbHbID5xWMb2DGAOe3YR9fu0KdMyLyNcx/6+uquiJEPd+htVBchUujdkzxo531ka5XuGvta3eoNk8C1jv9Q5OqrsM4kpwe6iJ0ZVJSOKlqGeaP8ASwWVWDetw4vAWME5FLRSTT+RzlMmvMARap6lWYB+mRcMcWkf4icrbzwNdh3tKanM0vAjeLyGAxLrFu1+osIBszBtYgIqfj//a6DJgoxv04B/MG5ztmloh8W0R6Oqap/b5jishsERnjdIKVmLGlJlqT7+xX7bzhXR/uPD1yGcbjaqrrcy5mTLAQcz2+LyJDRKQX4NbqIl2PaDJdRL4pxhX5Fsx9W+Bowa8A94pIdxGZiRmfeiZYJSKS5tybTLMoORLC+w7zfDZizj9bRG501v8nRPkngJtEpJ9zrW7FPLtgOsxJtFzjHcC1GKeVfwAjXNvuApYAUx2t80OMI8YdIpLhnONJmPHBiOfk/F9yMH1NhrM9PaDtgWM+vn3TnX0zgDRn30wAZ7z3v8CdzvUZjxkPfsvZ9ysYjfxcVV0Y7IKJyPGY8ciXAja9ijG7n+sc/y7M+N9aD9cr3LUGc5+uEJFRjsZ4Oy33aQkwVow7uYjx+JtN8BeErk9bvCe60gfz1q7AbQHr5wNXBaw7DCN4yoByTAcxFdMRbQd6O+XyMCayb4c57kDgA4wgqHCON8HZlgHc7xxjM6aDqafFi+gGjFmqAtMBvoAxt/jqvhPj2bYNM/CvGG/DLIz30D6MgPkcOMHZ51aMR9EBzNv3T131FdPirfclj
OZUjekU7sXfq6vZM9BZftLdtiDX4ViM22/fINtWYcY2Aq9HoLdeyOuB463nqjPD2XeEa91HwCURnpN78PfWW4Ljnels743RRA9gOvFvubadCFS7lmc5bXB/5oc59jSMm/FBzDjpNNe2HwP/cC1nAn9yrkUpxukmJ0S9zfc1yLbLae31ORHjmXYAM073Da/n5DwHgdsvd20fjHFsGBOiLYH7Phmw7z8xz2Sgt+b7Tr3Vrs8/Aup/FGOCC3YdTsY87wcx/9ERXq9XpGuNeSErcz7PAL1c2y4AVjrPWgnGXT2tPX1csn98nZ4lAXG0gUdUdXi825KqiJk4OkZVL4l3WyyWVCIlzXqJiojkisgZjvlkMHA3xsRgsVgsKYUVTjFCWodH8X3CuX4LRuXfhzEfrcHYtJMWEXkkxHUIOzbXmYjIP0K08cfxbpvFkqpYs57FYrFYEg6rOVksFosl4Ui61ANpaWmam5sbuaDFYrFYmqmpqVFVTRqFJOmEU25uLgcOHIh3MywWiyWpEJGDkUslDkkjRS0Wi8WSOljhZLFYLJaEwwoni8VisSQcSTfmZLFYujb19fWUlJRQW1sbubClFTk5OQwZMoTMzMzIhRMYK5wsFktCUVJSQn5+PiNGjMA/KLslEqpKeXk5JSUljBzpJY9m4mLNehaLJaGora2lsLDQCqZ2ICIUFha2S+sUkcdFZLeIrAyxXUTkj2KyLS8X71m424UVThaLJeGwgqn9dODaPYnJ/B2K0zHZe8cC12ByvMWMlDHrHdy/l60lKxmc0Zu8KidZbUEB66qaOJDTjSMLs6CiArKyoLYWCgogPR0aG6G+Hvbvh759obwchg6F7t2hrAz27TNlc3PN7/JyGDgQevY06wB27YLMTBAx9ZeVwbBhcOCAWc7ONr9FoK4OevQwxwZTVgT6uBK8lpebupuaIM9LItdOprwcevWCNOfdp7bWXMP8aGQvTwLq6lruYzhUYe9eKCzsnHZZLGFQ1Q9FZESYImfTkndrgYgUiMhAVd0Zi/akjHCq2L+L6p1bWcdWphc4eQIrKvjhAx8B8NZNJ/jvUF3dupKtW8332rUwfXrLclVA5uzt281n+nSzXFLSuq6MDCO0wJRbu7ZlW8+eMGYMHDrUcoxevYzA2r8fiotbyvqOkShUVpr21dQYIQ6wyknemmhtjRUrHatIpPMtLYUdO8zLR++ICZQtlo6SISKLXMtzVXVuG/YfjMkV56PEWWeFU0fI7zcU0mpgTbikt22gsrJj+9fXh9526JD5dgfl9f1ubOzYcWONr30NDfFtRzLgewYS/Z6mIMcffzyffPJJvJsRbRpUdUa8G+EVO+bUXoqK4t0Ci8USI4IJpgb7wrUdGOpaHuKsiwkpozl1WQ4dMuNWFktXZNs2OBjlkHC5uS0m5xDk5eVRXV3N/Pnz+elPf0qvXr1Yu3Ytc+fO5e6776agoIAVK1ZwwQUXMHnyZObMmcPBgwd57bXXGD16dNA6L7/8cnJycli0aBH79+/n97//PbNnz6axsZHbb7+d+fPnU1dXxw033MC1117L/Pnzueeee+jTpw8rV65k+vTpPPvss/F0FnkDuFFEXgCOASpjNd4EVjglNl5yba1YAUccYcawUoHaWiOQIzkbWCxR4osvvmDlypWMHDmS+fPns2zZMtasWUPv3r0ZNWoUV111FQsXLmTOnDk88MAD/OEPfwhZV3FxMQsXLmTjxo2cdNJJFBUV8fTTT9OzZ08+//xz6urqmDlzJqeeeioAS5YsYdWqVQwaNIiZM2fy8ccfc8IJJ4SsvyOIyPPALKCPiJRgMnFnAqjqI8DbwBlAEVADXBGThjikSI/WxWlsTB3hlGrOFalOBA2nMzj66KP9JrQeddRRDBw4EIDRo0c3C5LJkyfz/vvvh63rggsuIC0tjbFjxzJq1CjWrl3Lu+++y/Lly5k3bx4AlZWVbNiwgaysLI4++miGDBkCwNSpUykuLo6ZcFLViyNsV
+CGmBw8CCnSo1ksFkv76N69u99ydnZ28++0tLTm5bS0tIjjUoEmORFBVXnggQc47bTT/LbNnz/f71jp6ekpNe5lHSIcbLp6i8USa1566SWamprYuHEjmzZt4rDDDuO0007j4Ycfpt7x3ly/fr3NWYfVnJppbFIy0u2sdIvFEjuGDRvG0Ucfzf79+3nkkUfIycnhqquuori4mCOPPBJVpW/fvrz22mvxbmrcscLJwSpOFovFR7UzCX/WrFnMmjWreX3g8vz580NuC8bJJ5/MI4884rcuLS2NX/7yl/zyl7/0Wx9Y34MPPtimc0h2rFnPoYkUlk7FxbB4cbxbYbFYLM3ETHMSkceB2cBuVZ0UZPu3gf8FBKgCrlfVZbFqTzAam1oEUlOyqE6xaGd5efTrtFhSlPvuu4+XXnrJb93555/Pk08+GZ8GJSmxNOs9CTwIPB1i+2bgy6q6T0ROB+ZiJnZ1Gm4niKammBzAxE2zWCwpw5133smdd94Z72YkPTEz66nqh8DeMNs/UdV9zuICTCiMTsWthDS1QSGpqg0TF8/N9phF9rBYLJYuTaKMOV0J/CPURhG5RkQWiciiaPr5+48zeZNOnxTt4eLHPmPVDg+BX8MFd20vVhOzWCwpQNyFk4ichBFO/xuqjKrOVdUZqjojI4qRENyaU6NH1WnFdiOUNuwOklIjGDt2tLVZbaesLPbHsFgslk4krq7kIjIF+DNwuqp2+qh8TV2LFlZb723Qyae4eJq029QEOz3ERfSlyGgvdsKexWLpYsRNcxKRYcArwKWquj4ebbj+2S+af1/19KIwJVsQYmBWW7Ei+nVaLJaYkEohhOJJzISTE+H2U+AwESkRkStF5DoRuc4pchdQCPxJRJYGZGjsFA7UtyPJm082OYrTutIqDtbbh9Vi6SrcfvvtPPTQQ83L99xzD7/73e848cQTOeuss5gwYQKNjY3cdtttHHXUUUyZMoVHH30UgPvvv5/vfve7AKxYsYJJkyZRU1MT9Dj33HMPl156Kccddxxjx47lsccea97229/+trnuu+++GzARzcePH8/VV1/NxIkTOfXUUzkY7XQiCUTMzHoeItxeBVwVq+PHitpDRqBV1NSz/2A9P3xpGceO7M1PZk+Ic8sslq7HtsptHGyIbgecm5HL0J6ho51feOGF3HLLLdxwgwnA/eKLL3LHHXf4pc6YO3du0DQXN998M7NmzeLVV1/lvvvu49FHH6Vbt24hj7V8+XIWLFjAgQMHmDZtGmeeeSYrV65kw4YNLFy4EFXlrLPO4sMPP2TYsGFs2LCB559/nscee4wLLriAl19+mUsuuSSq1ydRsOGL2sg7q3cB8PKS7cw+woTNX7ljP6WVtQzomRPPplksligwbdo0du/ezY4dOygrK6NXr14MHTrUL3VGqDQXI0eO5Mknn2TKlClce+21zJw5M+yxzj77bHJzc8nNzeWkk05i4cKFfPTRR7z77rtMmzYNMKGUNmzYwLBhwxg5ciRTp04FYPr06RQXF8fsOsQbK5w6QMk+80ZXXdfAVU8v4q2bYpNnxWJJVcJpOLHk/PPPZ968eZSWlnLhhRcC/qkzQqW5ANiwYQN5eXns8OCpGyqFxh133MG1117rt624uLhVCo2ubNaLuyt5MvPT11fFpuJwc5nq6mJzTIvF0syFF17ICy+8wLx58zj//PNbbQ+V5qKyspLvf//7fPjhh5SXlzdrVqF4/fXXqa2tpby8nPnz53PUUUdx2mmn8fjjjzcHn92+fTu7d++O/kkmOFZzagPPLCiOdxNs1AmLpROYOHEiVVVVDB48mIEDB7Ju3Tq/7aHSXNx6663ccMMNjBs3jr/85S+cdNJJfOlLX6Jfv35BjzNlyhROOukk9uzZw09/+lMGDRrEoEGDWLNmDccddxwAeXl5PPvss6Snp8f8vBMJK5zawJtLPcxZcpMswWQtFksrVrimeASmrwiV5uLxxx9v/j106FCKiorCHmPKl
Ck8/XTr8KM333wzN998c6v1K1eubP79ox/9KOI5JDPWrOeQ6SHRYKOGn6j7zT99HK3mWCwWS0pjNSeP1Dc0UdcQXhM61Kg89Ukx3zluuBnoDDd2FIuBzKYmSEsz41IikJUV/WNYLJY28cQTTzBnzhy/dTNnzvSbS2VpTUoJp0ONjaQ1Gu2nvtF7jow3l27n0f9u9lT2pcUlnHXEIHp1jyAYop2jo7TUjEcdcQT4VP/p06N7DIvF0mauuOIKrrjiing3I+lIKbPeN//0Kfe8uRqAmkMB0SHCKEVeBZOPv6/YwSdFe9ravI7hSxgYi0joFksn4yl2pSUoXeXapZTmBGZOEkBagMWtvi0JnSLwwuclALw1Y0zU6oyITaVh6SLk5ORQXl5OYWFhq3lAlvCoKuXl5eTkJH9AgJQTTi14e+i3lNuI3xZLZzJkyBBKSkoos6lg2kVOTg5DhnR67taokzLCqb3RxDfvab9w2rW/lv7t3ttiSU0yMzObwwRZUpeUGnNys3J7hadyaR0wK7yypKTd+1osFksqk7LC6aMN3hwWOmLz3rGvtt37WiwWSyqTssLJq5Vv6bZ97T5EF3GasVgslk4nZcacWuESHH3zsjhiaEHQYu+s2tXuQ5Tuj0PEYCsRLRZLFyB1NScXJkx96/Vl1R0zy63fVd2h/S1xYPFiG1zXYkkAUlc4if/PYBPXrnii45njZz/wERt2VbVtp5oacAWd9ISdDxI9Skvj3QKLJS6IyNdEZJ2IFInI7UG2DxOR90VkiYgsF5EzYtWW1BVOLlkkIuECRHSYW19c1vadDh2KfkMsFoslBCKSDjwEnA5MAC4WkQkBxX4CvKiq04CLgD/Fqj2pK5zcmpNAU2eP1dixIYvFklgcDRSp6iZVPQS8AJwdUEaBHs7vnkDkdL/tJGWFU0a6OfU+eVkmlFGMZcVHgbH2rGZksVg6lwwRWeT6XBOwfTCwzbVc4qxzcw9wiYiUAG8DN8WqsakrnNLMqTc0KghEOUY4AOoaB/rtP9cCUNfQSG19Y6hdLBaLJVY0qOoM12duO+q4GHhSVYcAZwDPiEhM5EjKCqdhvXMBmDmm0HGIaNv+Y/vlAfDNaYOZMrhnxPKNCqt2VPKtxz7jvEc+bWtzLRaLJdZsB4a6loc469xcCbwIoKqfAjlAn1g0JmbCSUQeF5HdIrIyxPbDReRTEakTkU7PN9w3LxuAUycOIA1B22DXu+GkMXxpnLkf50wbxE+/Pt7Tfv/78grqGmKho7mwY1kWi6V9fA6MFZGRIpKFcXh4I6DMVuCrACIyHiOcYhKhN5aTcJ8EHgSeDrF9L/B94JwYtiEkZdV1zb9FIEIGdj9OnzQAVeWMyQPJzkinrsGa6SwWS3Kjqg0iciPwDpAOPK6qq0TkXmCRqr4B/BB4TERuxYzUX64xSiAVM+Gkqh+KyIgw23cDu0XkzFi1IRyPOQkEt5YfcFzJvV3fyYONo4qIkJ2RDtD8HVOKi2Ho0IjFLBaLpb2o6tsYRwf3urtcv1cDMzujLUkRvsjxKrkGICsrQvrzCFTk5PstN6kRNF5zDf7yG5M7dPx2U1kJjSE0tOpq6Natc9tjsVgsMSQpHCJUda7PwyQjo2PytCardYZI41PnL52qaoOnO+/UzJxt0ZZthAiLxdKFSArhFEvSxBlzCpADnxSVx6dBFovFYrHCCUyW3MAIEQ+8X9SmOl689ljumu3Na88zVhuyWCwpSszGnETkeWAW0MeZTXw3kAmgqo+IyABgESYURpOI3AJMUNX9sWqTm755WZRVH2JgQS4i0OjBW+/+C44Iua1bVgaTh0Se75TUVFVBfn7kchaLxdJBYumtd3GE7aWYSV5x4YzJg3jq02JGFHZnw25vqS3G9g/fMedmer+cP3ppKTeeNJYRfbqHLhRoa2xo8Fx/M/X1RgPr4FgdZWWwdSuMGgW9enWsLovFYolACpv1TMfvxXL263Mn89ZNJ0T16GtLq7nx+SVt2
6m2Hfmlli+HZe2Iih7q2DYmoMVi6QRSVjj5dJI0D/naczK9z2M6dXx/7po9nt+e583lfM3OSmY/8BFPf1rs+RgWi8XS1UlZ4dTU5F1zaotH9/dPHsvRIwtJ8xgLcf46E/njxUUl3g9i6RrYUFPxpbbWmL0tCUnKCqfKg+ahTAsinT4v3tvh+sf1z/NU7sMNeyIXCsJby3fw09eChi20WCxeWLWq7RmnLZ1GygknQVFV3ly+0ywH0Zx+9ubqjh9HhFevPz5iuara4E4OW8sPUF1ntj2/cCvrS/1TvT/ywSaWbKvocDstccROFYg/VntNWJIifFG0aXTFKoplxIc++dn0yctiT3XbnQi+99wS8tc0cvyhXbyzahd//Wxr1J0yLBaLJVFJOc0JzMvSaRP6k57gL65fbN3HO6t2NS//+b+beOLj4vg1yBJd7Ft7aKq9Te+wdF1STjgpQhMKIuTnZHrbpwN9yE/OnOC57KMfbKSuoZFQEehfW7qDl7+wjhOWLk5pKaxbB/s7ZT6+JUFJObNeTUa28dRTJd0RzT1zM6g82I4Jrh4Y08+bYwTAm8t38ubynRw3qnfYcrMf+KijzWo/dpzEEmt8c+qsJ11Kk3KaE056jEbV5vGmL43tS/esTsjJ5JFPN3n3FtxZcTD0xqqq0Nss0aWhATZuDJ3WxGKxtImUEU5ux4cmhSbVZjdyEQIzZgSQuGMDVz+zOPTGWAgnO04SnF27oKICdu/2vo/VQi2WkKSMcHLTpEpTk0s4IZ0ifq778qhOOIqLnTs793iWtmEFvcUSktQVTkrzmJMIntO0t4fTJvbn5q+OaVMYpLZw52tmImHR7uqQzhRRI1nf9tesgXKbo8tiSRZSUjhpk9IEvhS4iBA2TXtH+/ubvjKWUyYMYOaYwo5VFIJl2yr5bFM5t/xtKRPueofHP9rcutD+/WZcpMlDbpBkoKkJiou9R2qvqTHlLRZLUpCSwqkRRZuUdFrMer5xpeUlFa3K9+zmzeU8ErmZGbx+w8yo1BXIK0u2B/3dzIYNJjr5mjUxOX6nU15uPjt2xLslFoslBqSkcFI1USLS0locInza0YZd/pP/HvrWNPr3yInasdPTYmMWq6jxGIWiPWk3LBaLpZNJuXlO4FiEyms4cMhnEpKQprvhhWGSASYQ6ckwFnToEGRlxbsVFoslCUhJzalJldL9tc1BV2PtEBHImZMHRL3OdbuSYE7TihXtG8CrqoLKyui3x2KxJCwpK5zcpLnMeocaYu8wcP2sMTEL4ppfdyAm9UaN9gin9euhqCj6bbFYLAlLSgqnQHdr9xzcvy7c2untiSb9DuyLdxO6LnV1xuvPYrHEnJQUTo2BypEEjwGR6FHLLQFs2ABbY/hysXJl1/F2tFgSnJQUToHjSxLCIaItQVsTjZhPxk1E9u+HsrJ4t8JisUSBlBROgZqTT0EK7NB7d4+tZ9lpE/rHrO5UlE0Wi6XrEDPhJCKPi8huEVkZYruIyB9FpEhElovIkbFqSyCtxpzEt96/3KkTo+9V5+amr46NWd2fbS43qUEsramutvO9LJYEJ5aa05PA18JsPx0Y63yuAR6OYVv88PXZ4xyz3aFGsyKwKz9qRPi8SonMfW+v5ayHPmbhZhtPrhXr1sGqVd7L790Li8NEf7dYuggi8jURWecoDbeHKHOBiKwWkVUi8lys2hIz4aSqHwLhEhOdDTythgVAgYgMjGF7mn83qZKXncFhA/IBmLfYZJddvSN+c2nu/rr3jLlt4d631jD7gY/YtT9AU7Cag3diHTA2GSZQW7o8IpIOPIRRHCYAF4vIhIAyY4E7gJmqOhG4JVbtieeY02Bgm2u5xFnXChG5RkQWiciiBq+BPsOgqqhqqz6h5lCUE8V57HSmDy/w09Ie+tY0vn30sOblL4/rA8CIwm5MGdyzXU3ZutflAl1WZjSH6urQO/jYscNEdkhU7OCaxRItjgaKVHWTqh4CXsAoEW6uBh5S1X0Aq
tqGBGZtIynCF6nqXGAuQPfu3TvcGzU2Kaq+gK8tBE7O7Qzck3HPnz6EblkZDC/szvDC7lx8xQQaV66i5lAjH6zfw9lHDGJwr1z+5+UVbT6O35n65urU1kJeniOogwjS2lqTE6qiAvLz23zMpCNegs4KWEvnkCEii1zLc52+1UcwheGYgDrGAYjIx0A6cI+q/jMmjY1FpR7ZDgx1LQ9x1sWcJow7eWB/XBpo+upkLjt+hP+K6mrS04T8nIxmIaaq3PyVMcz5T3QiJsxft5vfPfgxcy6ayuhQhWznGRlrmrMkPg2qOqODdWRg/ARmYfrsD0VksqpWdLDeVsTTrPcG8B3Ha+9YoFJVOyV1q6pxikgL6FAaomzV6zBBJpSKCKe4vAgfv9w8a+lN4Rt/z5ur+XTjHhoC/Oj/tWYXAGt27jcrGhvteFR76IgA7wzBVlFhnDrsvbWExovCUAK8oar1qroZWI8RVlEnlq7kzwOfAoeJSImIXCki14nIdU6Rt4FNQBHwGPC9WLUlkKYmZ8wpYH16tK9GJ2gc/fJNOg8v3dt9b6/lnD994rfO18Tm/deuNeNR27cbQRWO6mrT4R1I8Hh+FtjnhLWy4ZcsofkcGCsiI0UkC7gIo0S4eQ2jNSEifTBmvk2xaExEs56I3Aw8AVQBfwamAber6rvh9lPViyNsV+AG702NHrv21xq38YA31vycpBiC6zguodn8y3ctfG/WpaWRO7L9+1u+uydHahGLxRIcVW0QkRuBdzDjSY+r6ioRuRdYpKpvONtOFZHVQCNwm6qGdGcVkcJw28PhpTf+rqrOEZHTgF7ApcAzQFjhlMg8OL+IzDRpbU1JtaGVLVsoKDXONkE1LzvWZLGkFKr6Nsaq5V53l+u3Aj9wPl5YICJLMQrOP7QNcdW8GLJ8/dYZwDOqugpvVqSERtWkyvBbF+2DJMEgeW61mdsVl6ZWV8OePe3bN1Gu7b59yZ0qXtVk37RYYsM4jKf1pcAGEfmliIzzsqMX4bRYRN7FCKd3RCQf4/CWtDSkpaO07t/Kquqaf08bWtCpbWorPzxlHNd9eVTzsqh2MEdUHDr7detgy5bYHmPTppbxlljVv3Nn4gjLtrJ2LSxZEu9WWLooTpCF95xhnquBy4CFIvKBiBwXbl8vZr0rganAJlWtEZHewBUdbXQ8+fK4vnyydler7vjlL0qaf0fdOSLKnHR4v+bff7hwKuXVdWFK+1NWXUdf34KjLtY3NCb2ZNv2sm9fbIWTj2Q1gYYbV2xqgl27YMCA5BW+lrgiIoXAJRjNaRdwE8bJYirwEjAy1L5euuDjgHWqWiEilwA/AZI6Z/aRw3oBrSfh+mLsAaQH2vwSmDH98jhmVKHn8t946OPm33sOGKE297+bjSZjsfjYtcuYLHfHLAiApevzKdADOEdVz1TVV1S1QVUXAY+E29GLcHoYqBGRI4AfAhuBpzva4njS6MnGnjzCyc2kQT08ldtSbty//UI2dYbmtGdP8moZqYbvf2Lvl6X9/ERVf66qzWYpETkfQFV/HW5HL8KpwfGwOBt4UFUfApIulk31oZY4clvKjSnjpcXbQhUnPRpmjL3h4t7Ghl+dO4VHLomcfeSG5+I0zrBtW+zfxPftM+NAFosl3gSLbH6Hlx29jDlVicgdGJvhiSKSBmS2oXEJx+4qM5enwXkx7JGTwf7aBsb0zaOozAixb04PGoM2KRjSq5uncs8uiLEzQiiiELw3LJucOYEDOxDkfu9e6J28KVMslngiIqdjnOgGi8gfXZt6AJ46AC+a04VAHWa+UykmpMVv29jWhCLQSnHLySb6hk8wARw+wJt5LJl54fNtrNlZ1bxskxO6KCmJXMaNdRiwWNzsABYDtc637/MGcJqXCiJqTqpaKiJ/BY4SkdnAQlVN6jGnQLIy0uPdhITguYVbuOTYEfFuhsXH5s1mLtjkyfFuiWHvXhOdPjOpDSeWTkBVlwHLRORZVW2XqSSi5iQiFwALg
fOBC4DPROS89hwsUUh1/aDgYFXQ9e+vLevklnQhYuE0sHdv57n3b90Ka9aE3t7YaITlhg2d0x5LUiMiK0RkOfCFiCwP/Hipw8uY053AUb6kUiLSF/gXMK/dLY8D/1m7q/l3bqaRySP7mHhwzRG5uxDXfXkUj3ywicEFOWyv8BaJeldVHQ2NTWQk+iQvS/Qpi/Bi4hO+9fWxb4ulKzC7oxV4EU5pAdkOy4lvqo128cnGltiDvs730mNNttnANBJdgdlTBnHKhP5kpKVx7TOLUVV2VUWeqPuftbs51ZWSw2KxWNqKqnbY28qLcPqniLwDPO8sX0hAYMBko95JBeHL59TYRe182c5Y2p8vMzmfZj/wUcR96uoTLamVR9rqwGAJj6qd32TpME6uvgeA8UAWJtr5AVWN6HHmxSHiNhE5F5jprJqrqq92oL1xxyeMfFEgUuU/2C8/m90RtKc9BzyYbZqaTEqNRAoYumtX5DKxpCt56zU0wLJl8W6FpWvwICYv1EvADOA7OKneI+HJPKeqL6vqD5xPUgsmaHGZbsmEmxrS6Sdnjo9Y5uUvSti8pzp8oV27zCTXSOMUEPsB/VR5s+hM7LiSJYqoahGQrqqNqvoE8DUv+4XUnESkiuC9tpjjRVbLEpXGZuFklgP7tzMnd80xl1F98zyV21Jew8g+eeytPkTQaahtCWuzcWPkMjt3mgmv2dme2peQJIKQXLwY+veHIUPi3RKLxUeNk1V3qYj8BtiJR6UoZCFVzVfVHkE++cksmAAa1XSuaY50agroWM6b3nX/3F+fEjlqwu/eXc8H68s495FPWLMzTIxfLx2yl7TgO3ZAUVHkcpbIxNu8abH4cylGztwIHACGAud62THpvO7ai3tEoMF58/eZ9QL72D55SfwGH4HzPQre375jIpRvKjsQy+a0EIvxq0QaE+so69Z5M6NaLAmEqm5R1VpV3a+qP3OGhjy9iaaMcHJLJ59DhE84BWpO0pUGtwPonZfNYQO8mfdENXajceEmfHrByz2qjjB2lkxUV5uJshZLEiEiM0XkPRFZLyKbfB8v+3pxJe9yrNxeSQYtwulQV/UlD0G3TG+3PbvxELX1AdpHMFPe3r3GNDdtmvdGeDH3JRNd6YUm8B53pXOzdDZ/AW7FxNVr0zyV1NGcXDT5zHrO2U8d2jOOrUlc0puaePKTYnbv948w8Yd/rWf2Ax+x74DjiVfrbK+vhyonNFKsI49bLJZkoFJV/6Gqu1W13PfxsmNI4SQiVSKyP8inSkS6RLwfn7fe8N7eUkykKt/+82d+y88vNOalf68Nkpdpzx7zfSDGY1WNSTpZOBkIpil1pjdiIng+WqLF+yLyWxE5TkSO9H287BjSvqOqHU4oKCJfA+ZgZgX/WVV/FbB9OPA40BfYC1zizpgYa3wJBQf2zO2sQyYU9541kWGF3Xh/7W6e+rTt0Ub+vWYX50wdFJ1YfG01HcW6A6uvN+M8ed7G52yHarEE5Rjne4ZrnQJfibSj515FRPqJyDDfx0P5dOAh4HRgAnCxiEwIKPY74GlVnQLcC/yf1/Z0BF83Io7qlJmRWtZNdzfaJy+b82cMjbjPna+uoKrW31S3bd9BXl+6I8qtSyDWrYt3CyyWpEZVTwryiSiYwFvKjLNEZAOwGfgAKAb+4aHuo4EiVd2kqoeAFzCp3t1MAP7j/H4/yPaYEpVU7EnI9OEFAAzomeOpvKD8Z+1u3li6nbIq//GnvQc6KaVDPImmVmRj1nknRf+fXQkR6S8ifxGRfzjLE0TkSi/7elEZfg4cC6xX1ZHAV4EFHvYbDGxzLZc469wsA77p/P4GkC8ihYEVicg1IrJIRBY1tHOg3d0f5DgpM9JS9OE/Z+pg/nrVMQwqaJs5U4HT5/zXb93ry/w1J1Xlo6I9zVE4Ugavz9IXXyR+TqRUEJ51kSP0W6LCk8A7wCBneT1wi5cdvQinese7Ik1E0lT1ffzthx3hR8CXRWQJ8GVgO0HcD
VV1rqrOUNUZGRkd935vDl8U5OwnD07q4BeeEBF65vpnM83PCX1dFUERFm/ZF3R7WXWtr2L+vXY3v/rHWt5clgDmvmh1stF+iakKnuwxJMFeyCoqotKUpKGyMrrz1rZti1zGEg36qOqLQBOAkxXXkzeTF+FUISJ5wIfAX0VkDiYMRSS2Y0JV+BjirGtGVXeo6jdVdRomqSGqWuGl4R2hPiBChJtvHtl1QxeF47iRrRTWVmzYHbxzuOKJRc2/fe7lgea/hGF3EA/DcKxZE39NorTUf3nvXhOzsK3n4pVwwjhe1oaiIjsGmJwccKxhCs0pNMLERGvBi3A6G6jBTKT6J7AR+LqH/T4HxorISCfw30XAG+4CItJHRHxtuAPjuRcT3P2LOn+w+obWnU5qGvrg+pNG8/QVR7V7/9kPfER5VS2rnazCrySqo0Rb35hrahJvzpYvarjXiO87dxqB1hFWr+7Y/pZU5QeYfn+UiHwMPA3c5GVHLzaya4G/qep24CmvLVLVBhG5EWNvTAceV9VVInIvsEhV3wBmAf8nIorRzG7wWn9bUb/fZmlAz64bQ6+tZKan0TtETEHxGMTotgffY3dxcNOfH52VkqHc01y/rs+OKLwoJJqAtiQLq4FXMQpOFfAaZtwpIl6EUz7wrojsBf4GvKSqnkIfq+rbBGTNVdW7XL/nAfO81NVhXP1rQ1oNIvldOoZee/nRqeP43buenp1WbNlWhicXi84SGrGeCGxJXux/v7N4GtgP/NJZ/hbwDHB+pB0jmvWcSLITMVrNQOADEflX+9saH0Rcclg1Zc13kZh1WD9+fvZEv3VZjQ30qakAoM+BCvoeqAi6b05DnDygtrR9AnGnUV6efOlA4j3GFu/jpzAi8jURWSciRSJye5hy54qIikgk57hJqnqVqr7vfK4GJkbYB2hbbL3dQClQDvRrw34Jgdfn3f4tYNqwXrx03bFBt2U0eQsblKbKvS8sZPYDH7FwcznPfdZGAeILTxTpxqm2hExKRIqLjaeZxeKFvXvjFk3fY+AERCQfuBn4LHBbEL5wnCB8+x4DLApTvhkvk3C/JyLzgX8DhcDVTkSHpCKwi7OaU3hyMzP44Slj272/qLLwn58CcO9ba3huYRscEQ4ehKVLzR+1q75FNza2BMy1WHxs3hxPr0QvgRPAzH39NeDlAZ4OfCIixSJSDHwKHCUiK0RkebgdvYw5DQVuUdWlHsomLuq/ECpLhhVaLXxpXD/+33ttmDAqEh1h4kunsX8/FBR0vD4vqBqh2FmsX5+4aUPseExXJUNE3FrLXFWd61oOFjjhGNcyTtDWoar6dxG5zcMxv9buxkYqoKp3tLfyRMJrl9lF39PbRXqacMbkAby9ojRyYQ9sLKtmdF+PgVQ7m+3bOzfFeXsFU01Nx93CLZ1LRYXJyty7d7xb0qCq7Q6g4Ez7+T1wudd9VLXdA8KpFfHUQe2boWeieaVufmFpFGuLMsmSNXfNGu+CraamfZ6RiWJK7Sr/040bjbku8YkUOCEfmATMd0x0xwJveHCKaBepkwk3Uf5wSUaPnMyg68+ZOoglWyvYsjeGpqmKCutM0BHWrIl3C7oOjY2Qnh7vVsSa5sAJGKF0Ecb1GwBVrQT6+JYdX4QfqaonB4e2kpKak8U7F8wYyg0njW5ePnV8f56/+liuOnEUh/Vve8qv5SUV3gs3Noae/BmNt2pVWLUq8edDRfvFat8+M55n8caOHcZBp4tPRHbi3vkCJ6wBXvQFThCRszq7PSkjnKze1D4yM9I4fdJARvbpDsCQ3rnNQWIzMvwFhHjoRH/86kq27WvRtqpq65n9wEd8tL6sfQ3syPhLebnxmFu7tv11RJv2doBtEWCbNiV+ZPR4snkzrFzZsrzPiXrSxYUTmMAJqjpOVUer6n3OuruciD6BZWfFSmuCVBJOrv9umnZ59TzqHDGkZ6t1Ewa2L4L79c9+wYPvm4mpWx2z4CtfxCABcqQO2502IZz21
FlefHv2wLJlnes1GIyuMtYTjnDnuHevTamRAKSOcHL9TiODwQXBE+11z7KCKxjN2YNd/+lZh/XjiSvaNxb6z5XGA9BJRkyTHRNsMbWl4vynmhpYsiQ6cRdra1PzGnYxUkY4ter8Al6c7r9gKudPH8KEQa01BAt8aWxfAGYM93eH7ZvnLZtuMB79YCO+G7GmtCpkvqh246RGSTqqquI7DhaPF4Vdu8z9Wr6847mqVq0yH0tSkzLCya06KUp6gHQa2z+Py44f0bltSiIOG5DPWzedwNDe3VptO+uIge2q883lO9m137zhKnDuw5/w3mqPc428RNpu9BZqqdPw6tZdVpZY42CdYebzy2ljtWhLKgknFwcyu9mI5FHksuNHcNtph/Gd44a3eV9fBHTf3bj66UXeOqdoCJ7O7ARVTZw9S3RYvNjMH7J0WVJGOLm7oZ35hSkx5ttZZGek8+VxfTlt4oAO1zV6z7bQWpF9o/Yn1a9Hoqaqb2w08fGsU0WHSBnh5IcET9Fu6Rg9czP95kS1h9yGOuprW/7UW8sPUN/gjB3FKi25xTsd+d+oJv6csmhQUWEijuzcGe+WJDUpI5w04C3TyqbYMm1oQZv3yT1kxp++8adP2FNdx4qSCr733BIuemyBKZCsDg7JRqz+HCUlZizNetJZPJBCwsl/Od1Kp5gwdUgBABcfPbTVtsy08Nd89N6WuU6XP/E5q3ca1+q6hiaqarv+BMg209DQ9rE3L899rMyFvpiAKTCZ1dJxUkY4BWJlU2wYWJDLWzedwIRBPenfI9tv22kT+7eprsamlk7y4scW+C1bMJNFfS7TnrNpRijnVXCo2hBIjY3WMSOGpI5wChBGVjjFnr9cdhSPXz6jOdzRmUcMClk2s7F1p3iw3l8ruHDup7y3OjrpO4D2awiqrTvxeJkcozFp1c2KFd7KlZaaEEhVVdE9fmcRjQ7ApwkmqmNGkpMywimwH0oToejANjYeiEHYHEsz/fJz+ONFU7nzjMMZ2qsbb910gud9X13i77W3pbyGOf8u4sP2xuGLFjt2mDBDbpNaZ+aC6giROmWvQtY3buQTjg0N8R1Lqq6GQ4c695j2DTempI5wcv0WMhGEyvpqKuqT9M0vieibn8Nxo/tELuiR37yzjqraKGsMbSFcINBQ2lhXd/tesyY+URnq61tct93BWi1JT+oIJ7/OIY20lDnzxOO17x3f4To27EqS5IDRoLPTubdHI+io1tJe4b18uUln0ZE6LAlJTLtoEfmaiKwTkSIRuT3I9mEi8r6ILBGR5SJyRqzaEhhbz85zih8Z6WncecbhHarjrjdSKHZaZ2fptZ28N2wfElNiJpxEJB14CDgdmABcLCITAor9BJPQahom6+KfYtWeIO3rrENZgpCT2fbo7/2r/XM3LS+paDV/LalJxnNJxjZ3FvbadIhYak5HA0WquklVDwEvAGcHlFHAlxSoJ+Ahmmf7ED93PQ103rN0MlPbMUk3kB+/utJ7oFiLxZJUxFI4DQa2uZZLnHVu7gEuEZES4G3gpmAVicg1IrJIRBY1tHMCX+A7zL6aTvbssfgRLc31j/8ponhPCoTEsVhSjHi7BVwMPKmqQ4AzgGdEpFWbVHWuqs5Q1RkZGRntOpC/+UcpLu/kQWZLK+ZcNDUq9dz4/JL27ZhoZhdrak4u7P2KKbEUTtsBdwybIc46N1cCLwKo6qdADhA9n2MX/qIpjm7IlmZG983jt+dPiXczOgcvgrAjwjLRBK2bPXugsjLerYg9qjYSeRSJpXD6HBgrIiNFJAvj8PBGQJmtwFcBRGQ8RjjFZIal+7/bIHGexGlpZvyAHswcXQjA2H55/OjUcbxwzbFtrqe+0QaFTTh8JvgtW6CoKL5tCSQWWs/OnWaulRVQUaF9NjIPqGqDiNwIvAOkA4+r6ioRuRdYpKpvAD8EHhORWzHKzeXapdyvLF6444zxrNxeyeED8slIN+9LQ3rlULLPRBz4yuF9+c/a8
C8UVz+9iP857TDue3sNlQcbeOGaY8nLjvB42+jY3ti2LXKZQOrqTEc9tHUA4E6noQHaORzQJnyhnGyqjKgQ0zEnVX1bVcep6mhVvc9Zd5cjmFDV1ao6U1WPUNWpqvpurNqSId0CPPYsicSkwT2bBRPAnIumNf/u1S0r4v57qg/xPy+voPKgeVv/f++ui3zQZI0LF0g8onxH0jx82kMoc97mzS2RNmJJfb0JNbVzpzGfRCODsg875hRT4u0Q0YkI6do33o2weCQ7o2Ue1PkzhvLDU8byi3Mmed7/8+J9rCutYkt5CnjyxSNskFdCGUL27oVNmyKX6yi+yBUVFSaXlC+ahCXhSRnhZJ59azFMRvKyMzjp8P5MHVrAMSN7e97vhy8t44bn2unJFw3cHW4sO8WGhtCp7WPNwYOdH3C1vcQ6DJSXSB6Nja013e3b4YsvYtOmJCZlhJOla3DrKePavM+GXdEx35VV17KxLEIHFK/xBq8p7KNtAty1K3yajUimr1inbe8s05vXaO5Llxozo5vS0sT2towTKSacrI04mXjs0uk8dul0v3URnRyCcOuLyyIX8sAVTyzi5heWtqyIdi6lzqCz8k75tKlIne7atbFvSyDREoglNt1OLOkEF5bEQANMel85rC9g3Y8TmYEFufFuQnjWRXC6iJep5tAhYypqL2vWRC4TSujs2wdZWcZ9PFFZuxYKCzteT6pnAo4xqSOcAv5LOyoPAtlBy1oSm3SBUX3zyM1MZ/l2b5M7F2wqp6FJOWFMTOZ4R4doZVTdti10XdEaH9q7N/j6TZuS04utsRHS0yOvizYbNsS2/iQmxYRTy59mbWk1VjglJ6/faLLpNjYpf/7vJt5cHnmc5xd/N9rAtV8axdeddPH7DhwiIz2tOY28Z5J5kqXXsamOkAzjJ24B2thoxoIGDIDBrvCfmzfDmDGxbYfVvkKSMmNOipJGt3g3wxJF0tOE/j1z2rTPox9uYn2pcZC49PGFXP7EZ7FomiWZ8DmJBGqD0Qi5lAyC2oWHHHw/EJHVTv69f4vI8Fi1JXWEU3I9IxaPfH3KoDbv84OXlrHSMQfWNdgHwxJDAj3zEhiPOfiWADNUdQowD/hNrNqTMsLJ0jVJTxPmBnj0eeH2V8K4P7eBNTv3J2/6lXjNjeqMyBDtIdhY2eLFrde1xeMxmhEpYk/EHHyq+r6q+iaMLcAE9I4JKSOcAt+Pv330sLi0wxJ9BnWyV9/rS7ezp9qMO902bzm3vbS8YxXGy4Eg0edkJQpus0tTEyxZ0nE38rKyeJhzMnx58ZzPNQHbveTgc3Ml8I9oN9JHCjlEKCBMHtKDK08YxYUNeSyu8OAya0kqbjhpNCMKu3PbvOVkpEGDh5fc+sYmMl1x/Q7WN5CZluYX68/NY//dzPx1Zdx/4VQASvd3MICstTnHjmhfW58mtKuDGZi3bm09Ty72gYgbVHVGNCoSkUuAGcCXo1FfMFJGcwIQ0khDGNM3L95NsUSZ62eN5s4zDuf0SQMZP7AHcy6ayv9901uuqG/86RO/5fMfWcBv/hl+DlPNoQaqaqM4CTea5p9wZqe0OP7lw4X38SpEvIQIcuOb5Ou1/oMHW6/zeWdG2xwZeM/dnnvV1fEwCXrJwYeInAzcCZylqjFzXU0Z4WRfTrs2Z04eyHGjW+Ywje6bx9h+eWR4fMLXlO5nbWlL5/DJpnK/7R+u90/Zsb2ilppDLZ2He992EU339HDuyR39I8T7j9SWSPLtESbh5h2FO/aePW0/VjjWrfOfDF1a2nFtLTIRc/CJyDTgUYxgiql9NnWEk/OdjPMDLe0jIz2Nedcdz6nj+0cse9tLy/lRiLGjTWXV/Oad1prUBy6B9cu/d8BEbB9K77TlWrmjnscaLzmvgjlXhMP9wrJ9e8zDJalqA+DLwbcGeNGXg09EznKK/RbIA14SkaUiEphANmqkzJhTC7YjSCUy0tP4/slj2XOgji+2VkQsv/dAS4dw/3vru
eXksdTWBzevPP1pS4ieRg8axcrtlYzu253crBT828WbICayNTv307t7FpFfXVIHVX0beDtg3V2u3yd3VltSR3OKtznCElfuPdtbLqjvPP558+9/r93N1x/8mL8u3Bpxv8qDDdQcCh3xe291Hbe/soLfv7feUzsSlvr6tmsAiUBg2CZVbpu3nCufWhSf9lgikjrCyfm2FpTU5cIZ7ZuSsWybt0gBFzy6gDU7/cd76hubqKptoMbRvraUB8kp1M4Xp3dXlbJ5TxsdBJKAqtp6HplfxKHtOzrXNAfJGWm+i5I6wsn5/2eldadbZjcY1/a8QJbk5tLjRvDK9cfF9Bi3zVvO7Ac+4o//MgPrf3hvPRc/toD6RuNBt6MyhLtwiMH2ot3VNDUFF15//E8RNz2/tMNtjjuVlX5vjU9+XMxbK0r5YEmxcWooLvZeV3vSYcTLqmKtOWFJHeHkfIukmfQZ+flxbY8lPmRlxDjKtMO7a4xn1QcbjBdXVW2Lya+qNsANXSToYPe60ipu+dtSXlxkBtsv/ctn/GNlaVTaV9fQyAfryxLD3B0wv6fB1yZf08pdnpORTB9e80O1x4TSFk9BsMKng6SMcPI9KIIkxh/SEjdu/kqMI00H4aH3i5p/X/zYAi5+7DP+u2GPn9AKpMyJQrGpzGgD+2rq/erpCE99Usxv31nHUo8my5gS4EbfFrFRXl0X9hpakpeUEU52zMni45SJA7j+y6NifpzZD3zU/Ht7RWtz3q//uZaLH1vA8q0hciM5LCwu96srGAfqGpj9wEe849KsHv1gIz95baVfuUMNjby2dDuljnnxreU7qGto/2TPy59YyM/fWgVAU5O2T1CUlQVdHfQVsrHRL1fVZU98zuVPLGz7MTsD+xLcIVLOp1WQVllxLanHmVMGceaUQRyoa+DCuQvi2pYtX6xhwf5a3li2k4e+NY3hhd39tgeGYHp5cQlPfFLcvLxm5352VprIBg+8X8RpkwYA+OW5+vlbq1lbup8zJg/ieZf34Web93Luw5/y8LenkZGWTmFeZptMn3uqD7Gn+hCbyqr5vpPC/rmrjqFHbqYp4NUU5jU9RZB4gHUeYlT5pgPkZIY5t2i/uUYreWSKElPNyUNukPudiVxLRWS9iFTEqi3ul5iIZr2pU6G/nf2QCnTPzuAvl0Ul3Fi7efS/m3ljmel0b3huSbMDRKjH1C2YAP735eWsLQ0tBP72+VY+27zXcXcPriVd/9clXP3MIn7+VvsmE/sEE0C5a65YfWMT760q9XPqeG3pdmY/8BFrdnaeSfG8Rz7lormfeir7efFe/r48RMT2ouiYVYGWPFKWoMRMOHnJDaKqt6rqVFWdCjwAvBKr9vi0pTRJo0kjvGmlp8OQmEWCtyQYPVyZcKcOLYhfQxx+8ffVbRoXDeHM18wzC1o0pTQJX3jJtgpeW7o9rIv6na+u4L3VpSG9CD9Y3xLK55UvSpjznyIue2IhS7bu4+ChBv78380APPrBZj7bVM7OipZ4dpUH6z2NOX1evJfKg/5u38tLKth3IHT6kogKluOY8bM3V/PwByFc2GuCTAVoL1Y4hSWWZr3m3CAAIuLLDbI6RPmLgbtj1Rjff11Ia7tZr3//zohrZYkTuVkZvHXTCX7r5n64sVmbCUbP3AzuOH181PJCuVlYvI95i0sY0IYsv42RJJTDjooggU0D8AkP9zVpalLmfVHC6ZMGsqykkmUllfzt8+DhdA45Y1h1DY1U1BgBsq+mnp++vqpV2Z//fQ2ZacKj35nOd580E2J75ppuKZh83lFxkGueCT4J+MevrqRvXhZPXHF0xHNMCLwEsa2vh8zM2LclAYmlcAqWG+SYYAWdVL8jgf+E2H4NcA1AVlZWhxqVJu2IFuE1kvPgwSYGliXpueZLo7nmS6ODOiK8fsNM0tMkphNgn/p0C9+bNdpz+f0uLeIbD33M5TNHBC332WbvwVB/+tpK1u+q4pQJ/fls8152VtbyzIKWk
E2hUoW8sWxnWMHuwxdRo75J/QRd5UHf+iYemV/Et48dQb6j3X5cFDzA6uIt5rzKqls0pw27qjjU2MTEQT1b7+AhUGt1XQP/XFnKN6cNJi0tBp5UXvqh5cuhb9/oHzsJSBSHiIuAeaoa1CCuqnOBuQDdu3dvlzdDy3MQwayX4/1t1Y+hQ6FfPyucUoD0WHRUQfjT/I2ey366qcXjr75JeczRfjrCkm0VALy2tGX8JZoOaO4JySV7W5vLXlpUQvmBQyhw/Szj/v+UK56hm7vfaNHK/r58B4cN6MGtL5oU6a9ef3yr8n4mSeek5vxrPVOGtAiyh+cX8cH6PQwuyKF7dgZThhS0qmfNzv3s2l/LrMP6Na/7Yss+SvfXcsbkgUHbGpJQDhQhvBm7OrEUTp5ygzhcBNwQw7a0uJI7w2whtafDDotlMyxJxvNXH8PFj30WdFtBt45p8ZYWVu1sneaj3Bk/Wr2ziptfWMLGMm/RHwLHi77xcEu+rtkPfMQTl8/gYBDHkPfW7Oa9NS1ZIHxjZw+9v5GKg/Xcf8FUxvb3zwV32zwTyd4tnO5yBGWbhVNbJ/l2cWIpnJpzg2CE0kXAtwILicjhQC/AmytNO/EJIxVjMgiqPWVlQUaQS5KXZ+y+Nu5WypGfk8lhA/JYV9rahNerWxbPXXUMm/cc4OkFxUHLWDrO5j3tCEkUhiueXMRA13jeipIK1i4KnY6iwjGZVhwM7WwRijeXbmf97mp+eKp96W0rMfPW85gbBIzQekFjHLYhy8k650vH3aRNMMxR7Lp3D7WboUcPyM42v0PF5LMT7rosJ4zpE3Jbj9xMjhhawC1fHduJLbJ0lJ0uk+J5j3zKU58WR9znZ2+upra+kaYm5WB9A/Uu97/GJuXNZTv81oGZJvD+ujJeWLiVsx4MP5EajOt9WOeWFPLwi+mYU6TcIM7yPbFsg4+Tx/dneUklY/r0NsdFoXseHDYO+o6ElSvDVzBihJkAmBchxXufPtHPimmJK+dMHUxh92x+8846BvQIPiY5tHd3Zk8ewFsrShnZpzub9xzg2hNH8trSHeyqilkma0sUkDZ47573SIuBJyez5d3+3dWlPPrhJg6GyP317GfGnb+0spbi8gPUNyonjm390vONP33CxEE9+PW5U4I3oKEhuHWnC5IaZwlcMXMEm/cc4NLjBrC3dnvLmFNaureZ4dnZRkBFYvhw80nGnDeWoIgIkwb3AODI4QUhy103awzXzfKP23fUyEIe/XAjR43oHdbB4Zypg/wcD9zcerLRyu7/V5gU4pZOp7a+RUv6fLNxSHl5cYt58NpnW+eKWrG9gjn/NhN5Txx7QqvtAKt2tB5/S0VSJrZefk4m9184lZ65ZhA76Fwnry7jAwZEsWWWZKB392wevXQ6V5/Ytph8A3rmcPfXJ3LaxODPzEvXHcvT3z2Kq04c1WqulZuvHN6PORdNbbV+ZJ8IJmngmJG9KMj1nytz68lj+dO3pjXv/8NTUjeFTFZjx01lC4uNK/sBl6PF9n2tXe3b4oGZ6qSMcPIhzvxzvyGurCwzR2lswLjBkCGt10FwIRZp/lW3bsHXT3AFzZg4MXwdlrgyuCC3ecyyraSnCX+5bAZPf/coXrr2WLIzzHOYm5lB7+7ZzeWevOKoVvvuqDyIiDC6bx4zxxQyZXBPpjmRLLx4td922uGMH+ifImZ4YXeGFXbn/50/hb9edQwnHd4vxN7eCeaynQwMq4hOGhIv1DcGNyEuKt7L31dEnhuWSmPbKWPW8yGOCa+V5hRMG2pLfL1I86OCPVSjR0NurhGCIu2fY2VJCvq7xqvmXjqDfTWtvT/75GXTIyeD/a7o3rPGtUzCvOP08QBsKqtmyQtLycuJ/BfOyUxnUEEuAF+fMpCLjx7WHJg1KyM9YqDXh789jdL9dfzszVDBXeDKmSPIzGgtuEf3NZpZJDfws6cO4vUQZs1ATjqsL++v85/7c9NJY3ggSulEOov3VpdyqKGJg
T1zuSfEtd22r4ahvUK82HZxUk5zSpMI85w8VRLkskUSLMEcKQoKzHf//mYCbzBsjo8uSWFeNmP6BXeuefKKo5h33XHN2s7Q3q1NdyP7dOfKE0bwo1MOo19+dqvtgVx89DAuO244V54wsiVieBgmO2NsvuMfNaJ3qzI/O6tF0//GkSYW5T1fn8AfLpzavH7ORdNCnqeby48b0WrdKeNb/yfOnDyAH556GIXd/c/BF4k9mZjz7yIe/mBT87yoQD4u2sP1z37BpxtdDlZ1qeNck3LCqdms15G0GW5BMmwYTJ8eWYgMHdo+s13PIKFXQtG7dQdiST6yMtLJyUznF+dM4unvtjbzgbEAfGPaEHp1z+Khb03jz9+ZwXeOG87gghxuP/1wnvmuiS9X4MSpy8lM5/wZQ8kIY5Z87NKW6Oz/980pHDWiV9Byz1x5NE9cPoPpw1tvnzGidyth9J3jRnDCmD4cMzJ4fQCZGWnMvXS637qvOKbGfJd26NM+v33M8OZ13bPal924f342g3uZ+n58+uHtqiOW+OZ3bd5zgIvmLjChtFIoxmfKmfV8HKyPHAAzJCItLuOhhFJ2dstbztCh3s1248bB+vXmd34+FBZ6zwtjtawuRXZGOtkecivlZmWQm5XBBTOGcsGMlqAsv/rm5GazmhcGFvg/n3fNnhA04nmvNkbG6Jmbye2nH46q8vUHP261/dwjBwM0mx4BXv3e8WSmpzU7ifhiHJ491ZQ9deIATp04gNLKWrpnt1yjQJNoKPrmZfGXy1sEvzuc0Rs3zKSsuo4rn2rtbdeZvPC5CU3apCbOX6qRcsIpPc08yFsrt0Yo2UEmTYJVq0wY/vz81ttDCSr3hOBwwmbMGNi9G/Zbt1NLcCYNboPW7eK86cZEJyKkBzyCs8b5z825/fTDyfAYa1Bcz/ObN86ktqGR5z7bykVHD21VNtDx5OFLjqQgN6tVXEN35Pb/+8YkBhXkctkTn/uV+cU5k/hsUzmHD+zBl8cFD6Lq/qulpQn9e+TwyCVHsnnPAX79z3Wezi9WLNzsypTsJZJ5FyHlhJN4yhbjgb59jebUo0fkssEINcYUOJ6VHWY8YfhwKC0NHxgyIyOlZpVbOkY4d/Z51x3XSmiEip7xg1PGBtW67vvGJLplpiMi5GZmcOUJ/q75N5w0hgE9Wj/zXpwCJgcJzAomR1ekPF0+wTmsV4v2NqRXN4b06sbu/XWtEjwGct85kyitrI2aU8bGshYhtMkVvumJjzdzxfTpwXbpcqSccPJpTh2mWzcz1hSO7t2N5pTezmNmZRlvvkmTjJAKnNiblWXGvHzCKdCdPS/PaG07d5pth9oeGywokyebc1q6NDr1WZKCsCnOA/jK4cE9XY8IIUB8nB4Fx4Z7z5qIiLC9oobC7t5NkPedM4nhha2F4NDeLQLr1+dO5rD++byyZDtPOxHSf3v+FMYP6MERQ+FQYxOPfhgiUWEbuNmVWdjNy19s54oO154cpJxwykrvxEjSw4YZDSnYHCgv40NDHXOHT3uaPt0Iu23b/E2Fhx9uBI9b+BQWmogWTU3mM2gQLFnSUt/EifDFF20/p6yslvMRSal5F5bk4EjHUWPasII27XdECO3q6JGFzb99uaEumDGUCQN7MKx3Nz/vx68fMSgqwsmSgt56nUpaWuvJt31CBxH1o2fP0C7rY8f6b+veHXr1Mp9gbRgyxHz7XNeHD/cXjkccYYTosGFmOTvbCLNguLNyRgqYa7F0YSYN7unJLd/Nb86dHKPWdD2scOps+vUz40DhXMQnT4ZRbQuTAxiNZvjw0NtHjzbal0/r6t/fjJllZBgtLTe3Zb1PkGVlGecLdx0d4cgjO7a/xRIH5lw0lUcvbf9Yjy9I7PiBLWPUwaKBWFpIObNe3MnNNZpKODqSit43vuUlcvGQIf7LeXnG3JeT0+IGn5vboiGNHeuvObnJzTXa2ahRsGJF6GOKGAG1di3UtM5+GpLBg2OfZXjSpMjR6
S0pyei+kScSh2Peda1DO/XJy+ala49FgQseXeCpnt+fH6Hv6EJY4dTV6NXLmOcKCyOXDYbPxT0722hJ+flG4AVz/hgwAIqKzDGDaXq+Malhw2Dr1pYoGSJmPGx1QMiW0aON8MvJMWV8Y2RgNM0ePeDAAVNXe0hPN2bVUBMZs7PNeWxqw5hBR5JQ9uoF+/a1b19LlyA3y3TBz199LPWNjXzn8c/Dlu/ZrW1mxGTGmvW6In37eo+wHo6CgvCehj17GqEVygQ5caIxCRYWmo/bJJibC9OmtSzn5Jjjde9ujpmWZjSsiRONKTQ314zf9e3bYpbs2dMIualTzXr3+J57jM3HoEFGWwxsc9++LUkk3eN2ffoYbbFvXxg/Pvh4oZc0Kn2Dz61ppbkOGtTiBGNJWnIdE95bN51AL4/CJD/HPwCwm1PG92sOURVr/yMR+ZqIrBORIhG5Pcj2bBH5m7P9MxEZEau2WM0JWLzDuGgP6TGEPt36RM/dPFU57DBjVszObvE0DNaJp6WZ9cXFMHJk6+2+qBqBHfa4cbB3rxE+PiHsc+bYv9/s4zONlpUZTWvoUP+5ZT4tLienZV8fffoYbcg3fueby1ZQYOa29e9vPCB9+/XpY8ygPq2vXz/jZt+rlzkvEWhsNG32aZE9erR+gRg40Bx3m4kM4BdlxEdhIZSXm2j2JSUtk7DHjzdlfcLVN+1g8uSWa1FVZaKPBNP2CgrMJy/PzJ2LlDBz6NCWdkJLnXl55v6rmmW3iXfoUNPG3btb1g0fDlu2tCzn5BiP1EgUFpq6wk1K7dfP/1g+vM79GzQIdngLRuvmiZtOoqGyCoD7L5zKtr3+0Wju/voEDniM+HD/BVMZUdiN6/5q7mfdYbELsyQi6cBDwClACfC5iLyhqm4Tx5XAPlUdIyIXAb8GLoxJe2KcHT3qdO/eXQ8cCB/hOBJbKrawpyb8n09EyEzLRNHmYLGRaNIm0iTNb6KvotGb+NtVaWo0SR9jRWMDpEfpPUzV2zSAujojFIKV9f3nRIygOlRnNDnfNdhbbjrQHj1hV6kJXzV2HKD+16mpCQ4eBG2CvIAoJE1NpgMONX5ZexCyc4yZNCcbMgLe8GtqjHCprjYaaX4eHKgx0xUKCowwUjVCrLA3ILBtqzH1Zruin7jP1X3sJm3RdBsboLISertM0Y0NZp/KSmhohF4FRrD26AFDXC8rNTVQuhMGDzHXoke+aXOTmnb6jr+rFHK7GQ08K8ucW1MTbCwyLxd9+oCkmWexoQGyss3xS3ealwrfS8HYceZ+lO02+/Trb+YRpgn07dfywrFtq3kZ6Flgrl1JSxJCBg1ucYhqaDDPgDbB6DHM/t2/mlN47Mov5OWLJ0BuLnf/dQGfNfXg1RvPYMbQdjhLASJSo6ohXWxF5DjgHlU9zVm+A0BV/89V5h2nzKcikgGUAn01BoIkJYUTGAHV0NRA3+59Kdpb1CpKeWZ6Jj2yeyAITdoUohZ/Kusqyc/KbyXMQgmoDgWftVgsHaexsf2T5NtS96FDnhydVmyvpHe3TLpJE3vqlLH9zUvHjoqDvLu6lJtPmkpht/aNJ4vIIcDtrTRXVee6tp8HfE1Vr3KWLwWOUdUbXWVWOmVKnOWNTpkIqnbbSVmz3vCCFpfrIwda92aLxRJ/RoUI3D6qF5wwssPJSBtUdUbkYomBdYiwWCwWC8B2wD3AO8RZF7SMY9brCZTHojFWOFksFosF4HNgrIiMFJEs4CLgjYAybwCXOb/PA/4Ti/EmSGGznsVisVhaUNUGEbkReAdIBx5X1VUici+wSFXfAP4CPCMiRcBejACLCTF1iBCRrwFzMCf6Z1X9VZAyFwD3AAosU9VvhaszWg4RFovFkkpE8tZLNGKmOXnxmReRscAdwExV3SciIZIcWSwWiyWViOWY09FAkapuUtVDwAvA2QFlrgYeUtV9AKoaZMacxWKxWFKNWAqnwYBrCjklzjo344BxIvKxiCxwzICtEJFrR
GSRiCxqsFldLRaLpcsTb4eIDGAsMAvjtvihiExW1Qp3IWei2FwwY06d3EaLxWKxdDKxFE5efOZLgM9UtR7YLCLrMcIqZGjempoaFZGDobZHIANINdXLnnNqYM85NejIOedGLpI4xFI4NfvMY4TSRUCgJ95rwMXAEyLSB2PmC5uvQFXbbYoUkUXJNEM6GthzTg3sOacGqXTOMRtzUtUGwOczvwZ40eczLyJnOcXeAcpFZDXwPnCbqsZktrHFYrFYkoeYjjmp6tvA2wHr7nL9VuAHzsdisVgsFiD1whfNjVyky2HPOTWw55wapMw5J13KDIvFYrF0fVJNc7JYLBZLEmCFk8VisVgSjpQRTiLyNRFZJyJFInJ7vNvTXkRkqIi8LyKrRWSViNzsrO8tIu+JyAbnu5ezXkTkj855LxeRI111XeaU3yAil4U6ZqIgIukiskRE3nKWR4rIZ865/c0J84+IZDvLRc72Ea467nDWrxOR0+J0Kp4QkQIRmScia0VkjYgc19Xvs4jc6jzXK0XkeRHJ6Wr3WUQeF5HdTlZZ37qo3VcRmS4iK5x9/igirdNwJwOq2uU/mKjoG4FRQBawDJgQ73a181wGAkc6v/OB9cAE4DfA7c7624FfO7/PAP4BCHAsZtIzQG/MnLLeQC/nd694n1+Ec/8B8BzwlrP8InCR8/sR4Hrn9/eAR5zfFwF/c35PcO59NjDSeSbS431eYc73KeAq53cWUNCV7zMmvNlmINd1fy/vavcZ+BJwJLDStS5q9xVY6JQVZ9/T433O7bpO8W5AJz0MxwHvuJbvAO6Id7uidG6vYyK/rwMGOusGAuuc348CF7vKr3O2Xww86lrvVy7RPpgII/8GvgK85fzx9gAZgfcYM3/uOOd3hlNOAu+7u1yifTAZRjfjOC0F3r+ueJ9picfZ27lvbwGndcX7DIwIEE5Rua/OtrWu9X7lkumTKmY9L0Fokw7HjDEN+Azor6o7nU2lQH/nd6hzT7Zr8gfgf4AmZ7kQqFAz2Rv82998bs72Sqd8Mp3zSKAMEz1liYj8WUS604Xvs6puB34HbAV2Yu7bYrr2ffYRrfs62PkduD7pSBXh1OUQkTzgZeAWVd3v3qbmlanLzBEQkdnAblVdHO+2dCIZGNPPw6o6DTiAMfc00wXvcy9MWp2RwCCgOxA0U0FXpqvd1/aSKsLJSxDapEFEMjGC6a+q+oqzepeIDHS2DwR8ubFCnXsyXZOZwFkiUozJC/YVTIblAhHxRTlxt7/53JztPYFykuucS4ASVf3MWZ6HEVZd+T6fDGxW1TI1waBfwdz7rnyffUTrvm53fgeuTzpSRTg1B6F1PH0uAt6Ic5vaheN58xdgjar+3rXpDcDnsXMZZizKt/47jtfPsUClYz54BzhVRHo5b6ynOusSDlW9Q1WHqOoIzL37j6p+GxOP8TynWOA5+67FeU55ddZf5Hh5jcREwF/YSafRJlS1FNgmIoc5q74KrKYL32eMOe9YEenmPOe+c+6y99lFVO6rs22/iBzrXMPvuOpKLuI96NVZH4zXy3qM586d8W5PB87jBIzKvxxY6nzOwNja/w1sAP4F9HbKC/CQc94rgBmuur4LFDmfK+J9bh7PfxYt3nqjMJ1OEfASkO2sz3GWi5zto1z73+lci3UkuBcTMBVY5Nzr1zBeWV36PgM/A9YCK4FnMB53Xeo+A89jxtTqMRryldG8r8AM5/ptBB4kwKkmWT42fJHFYrFYEo5UMetZLBaLJYmwwslisVgsCYcVThaLxWJJOKxwslgsFkvCYYWTxWKxWBIOK5wslk5ERGaJE1XdYrGExgoni8VisSQcVjhZLEEQkUtEZKGILBWRR8XkkqoWkfudfEP/FpG+TtmpIrLAybfzqisXzxgR+ZeILBORL0RktFN9nrTkafpr0ubbsVhiiBVOFksAIjIeuBCYqapTgUbg25hApItUdSLwAXC3s8vTwP+q6hTMLH7f+r8CD6nqEcDxm
KgAYCLJ34LJOzQKEz/OYrG4yIhcxGJJOb4KTAc+d5SaXEwgzibgb06ZZ4FXRKQnUKCqHzjrnwJeEpF8YLCqvgqgqrUATn0LVbXEWV6Kye3zUczPymJJIqxwslhaI8BTqnqH30qRnwaUa2/srzrX70bs/9BiaYU161ksrfk3cJ6I9AMQkd4iMhzzf/FFx/4W8JGqVgL7ROREZ/2lwAeqWgWUiMg5Th3ZItKtM0/CYklm7BubxRKAqq4WkZ8A74pIGiZ69A2YhH9HO9t2Y8alwKQ4eMQRPpuAK5z1lwKPisi9Th3nd+JpWCxJjY1KbrF4RESqVTUv3u2wWFIBa9azWCwWS8JhNSeLxWKxJBxWc7JYLBZLwmGFk8VisVgSDiucLBaLxZJwWOFksVgsloTDCieLxWKxJBz/H9dP8gDJF91+AAAAAElFTkSuQmCC", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# vrex\n", + "algorithm = \"vrex\"\n", + "penalty_weight = 1e-1\n", + "model = MLP().to(device)\n", + "train_OOD()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "J9-5hNDn7367" + }, + "source": [ + "# PAIR" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "3uED5QFc_wxx", + "outputId": "05b5a0a9-1aa2-465e-fbf4-3721b41e051a" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[150|10000] train_loss=3.15613e+00, valid_loss=9.29974e-01: 1%|▏ | 149/10000 [00:14<17:05, 9.61it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Restricted license - for non-production use only - expires 2023-10-25\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[1204|10000] train_loss=2.06298e+00, valid_loss=7.60535e-01: 12%|█▏ | 1204/10000 [05:15<41:19, 3.55it/s]/home/yqchen/miniconda3/envs/gnn/lib/python3.8/site-packages/cvxpy/problems/problem.py:1339: UserWarning: \n", + " The problem is either infeasible or unbounded, but the solver\n", + " cannot tell which. Disable any solver-specific presolve methods\n", + " and re-solve to determine the precise problem status.\n", + "\n", + " For GUROBI and CPLEX you can automatically perform this re-solve\n", + " with the keyword argument prob.solve(reoptimize=True, ...).\n", + " \n", + " warnings.warn(INF_OR_UNB_MESSAGE)\n", + "[10000|10000] train_loss=9.97147e-02, valid_loss=1.59599e+00: 100%|██████████| 10000/10000 [46:17<00:00, 3.60it/s]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# PAIR\n", + "# define the preferences for the epo solver\n", + "algorithm = \"PAIR\"\n", + "r2, r = 1e4, 1e-8\n", + "preference = np.array([r]*1+[(1-r)/2,(1-r)/2])\n", + "model = MLP().to(device)\n", + "model.update_preference(preference)\n", + "train_OOD()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "rQmF7csfJP9p", + "outputId": "591578a5-2696-4eef-d9c3-643749f75415" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[10000|10000] train_loss=2.41655e-01, valid_loss=6.84301e-01: 100%|██████████| 10000/10000 [41:39<00:00, 4.00it/s]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUQAAAEICAYAAAAncI3RAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8qNh9FAAAACXBIWXMAAAsTAAALEwEAmpwYAAB1s0lEQVR4nO2dd7gkRbn/P3Wm54TN7Fl22QS7sktYySDJhCyKIooBFRNmMOdwkZ/hmq6i1+s1XBUVkWtArgnBCChJXQSWIDkum1n2LBtP7Dn1+6O6Zqqrq8PM9JyZc858n6ef6TTd1dVV335TvSWklLTRRhtttAEdzS5AG2200UaroE2IbbTRRhsB2oTYRhtttBGgTYhttNFGGwHahNhGG220EaBNiG200UYbAdqEmBOEEN8RQnyi2eUwIYRYI4Q4pdnlaKON8YJUQgw61YAQYrcQ4nEhxMVCiGnG8TcKIaQQ4lXW/04SQqw3tq8VQgwG19kqhPiVEGJ+vo/TPEgp3y6l/Gyj7yOEmBbU4R8afa+8IRS+JIToC5YvCSFEzLnzhRC/FUJsDNrXkpRrzxZC/FoIsUcI8ZgQ4jUp5x8lhLjeaNfvM44dIYS4QQixQwixPu5DJ4T4ZFC2U4x9Fwgh1gkhdgbl+Lj1nyOEELcKIfqD3yOMY88RQvw1uO8ax/0+K4T4lxDCF0J82jomhBDnCyHWBve+VAgxwzh+d/CsevGFEFcYx2VQd/r4941jHxFC3CWE2CWEeFQI8RHHM8XWlxDilUKIe4P/3yOEeIlx7CwhxP3Bf7cIIX5klftgIcRfguMPCSFeahzrFEL8IuAoKYQ4yfWeqoKUMnEB1gCnBOsLgbuALxrH/wr0Ab+z/ncSsN7YvhZ4a7A+C/gz8JO0+7eXyPt4Q1DfPrBP1nfXCgtwLnA/sChoS/cAb485dx7wTuAEQAJLUq79M+DnwDTgGcAO4Kkx584BtgCvBbqA6cDBxvF7gM8DBWB/YBPwYusa+wP/AjaadQwcCEwN1hcCdwMvC7Y7gceADwT3fW+w3RkcPxZ4PXAOsCbm3b8AuBz4tOPYfcDioA4uB34U8/wCeBQ429gngWUx538UOArwgud7DDgrS30FdTAclFsALwT6gbnB8cXAnGB9GvAT4OvBtgc8AHwwuPbJwB7gAKM+3x+8703ASXW30QyNONSpgC8DVwbr+wGjwMuxOigJhBhsvxO4O8P9jw
VuAXYCjwNfNY6dHbycPuAThMn7WOAfwPagsr5pNLwlQQPwXOUDlgHXoTrVVuDnRkP6L1Rn2onqEIcExy4GPhes7wVcCTwBPBmsL7Lu9Vngb8Au1MdhTkZS+UvQ+FYDH7aOvd6oj/Oz1ofRId4JPBiU6bOoxv334FkvM8+PKdtJwHrg40G9rQFeaxz/O3COsf0WYFXKNT1SCBGYiup0Bxj7/hfjw22d/wXgfxOu1w+sMLb/DzjPOuePwGkkfHRQZPAv4KPB9vOADYAwzlkLPN/63yk4CNE4/mOihPgL4CPG9onAIDDF8f9nB+94qvX+nYTo+P/XgW9kqS/gOGCL9f8ngBMc150GXAL8Ptg+BNht1defgc86/rueHAixKhuiEGJx0AhuC3adDdwipfwlcC/qi5vlOr3Ay4CHMpz+38B/SylnoDroZcE1VgD/E9xzPjAT1QA1Sqgv8RyUlLES1eGz4LOoit8LJc18I9j/POBZwAHB/V6JIh8bHcAPUR+MfYEBFAGZeA3wJmAu6kv34bRCCSH2Q5HOT4LlbOPYCuDbKFJcAPQGZdfIUh+nAkcDx6OkgguB16G+4ocAr04rI7BPcI+FKKnlQiHEgcGxpwJ3GOfeEeyrFwcAvpTygYzXPh7YJoT4e6CmXSGE2Nc4/jXgbCFEMSj7CcDV+qAQ4hXAkJTy966LCyH+TQixG9VJpwI/DQ49FbhTBj04wJ0J5awWwlrvApY7znsD8Esp5R5r//VCiM1CmbOWOG+gTBzPREm+Gl8jvr5uAe4VQrxYCFEI1OUh1HPraz5DCLEDRdIvD66X9IyHJByvC1kJ8TdCiO3AjSjJ6QvB/rOpvOyfYnTQGHw9ePCtqE7zngz3HgGWCSHmSCl3SylXBfvPBK6QUt4opRwGPon6ygEgpbxVSrlKSulLKdcA30V9GbNgBEVmC6SUg1LKG43904GDUF+te6WUm+w/Syn7pJS/lFL2Syl3oSQ6+94/lFI+IKUcQJH8ERnK9XpUh7oHuBR4qhDiyODYmSjJ/Xop5RBKYh41ypSlPi6QUu6UUt6NMo38WUr5iJRyB/AH4Eiy4RNSyiEp5XXA71AfDlASwA7jvB3AtKCT1YNpKCnWxA7Uu3JhEYoU3of6YD2KUrk1rkTV5wBKDf2BlPJmACHEdFT7fx8xkFJ+Mbj3UShJVT+z/fxp5awGfwTeKoRYIoSYCXws2D/FPEkIMQX1bBdb/382SnM6CGUGuFII4Tnu82kqH3yN2PqSUpZQUt9PUUT4U+Bck4yDPjwT9V6+jJK6QZlXtgAfCcj2eUE5Q8+UJ7IS4kuklLOklPtJKd8ppRwQQjwdWIrqmKAe9FDTSOzAe4MHP4yK9JWGt6AkgPuEEDcLIU4P9i8A1umTpJT9GNKaEOIAIcSVwRdvJ6oRz8n0tEo6EsA/A2P0m4N7/AUl6X0L2CKEuNA0ABv3niKE+G5gVN8JXA/MEkIUjNM2G+v9qM6ShrNRkiFSyg2oj9MbgmN2feyh+vp43FgfcGxnKeOTluTxWFA2UOqPWV8zgN2WxFQL7Ovqa++KOX8A+LWU8mYp5SDw78CJQoiZQojZKHL5DNCNko5PFUJoafrTKHV7TVKBpMJtwb3+vcZyVoOLUKR+LUp6+2uwf7113suAbai2Y5b3einlsJRyO4rslwIHm+cIId6NaoMvDD66pNWXUA6nC1CaTSeK0L7v4omgTf+RgFOklCPAS1B2x83Ah1DCg/1MuaGesJs3oEjjdiHEZuAmY38ipJT/Aj4HfCtNOpBSPiilfDVKtfwS8AshxFSUHaxMqEKIHpSaqPFt1NdqeaBuf5yKSqE7rPml2ce452Yp5duklAtQjoD/EUIsC459XUp5NLACRdQhj1uAD6GMz8cF936WLmbSsyZBCHEiSv05LyC1zSj7zGuCL/kmVGPU508he33kib2C96OxL0riANVRDzeOHU5Y9aoVDwCeEMJUD5OufSeGNmGtPwUoSSkvCaTp9a
gOelpwfCXwXuMdLAYuE0J8DDc8lKmHoDyHWW3+sIRyZoaUclRK+Skp5RIp5aLgmhuCxcQbgEsyfIQkRvsIhIJ/A1YGdaKRVl9HANdLKW8JyngziiviwsHM+kJKeaeU8tlSyl4p5anB/f6ZUvaaURMhCiG6UWrQOagH1st7qHTQNPwI5Ul8ccq9XieE2FtKOYpyCIBSBX8BvEgIcaIQohP15TYb2nSUGrVbCHEQ8A59QEr5BKqhvC6wa7wZ4yUIIV4hhNBk+ySqcYwKIZ4mhDhOCFFEkeoghlpq3XsA2B58QT+VWhvpeANwFYqIjwiWQ4AelAfvF8DpgT2mE/XFNt9vbH00AP8ehEQ8EzgdZWQHpTp9UAixUAixAPXhuDjuIkE76wo2u4LtCAKJ9FfAZ4QQUwPt5QyUuurCD4GXChUuUkSZF24MTAMPqFuL1wghOoQQ+wCvomLzWomq9yOCZSPqo/mt4PxzhRB7CYVjgXcB1wT/vRZly32vEKIrkLhAOcoI/t8NFIMydAfvUtdHMTjegfoAdGutQ6iwo/2D+64Avgp8Jug3+v+LgOeg+p5Zz08N6qIgVEjdf6L6x73B8deiNIrnSikfseoyrb5uBp6pJUKhTDzP1MeFEK8Vgf1WKBv55436QghxWPCcU4QQH0b5Cy42jpvtojM4t/YPfZrXBYcXDTgLJZEUrf09KDXtdFK8zMG+j6GcMkn3/zHKjrAb9dV7iXHsjSgvnfYybwCeGRx7Fkoi2g3cgCKIG43/vgBlO9qOagDXUfEyXxBcazfwMIFnFNUZ7gz2b0Wpr9OCYxdT8TIvCJ53N6rBnIvh1bbrIniOGxPqoBtFzC9yHPsf4BfB+huM+rC9zGn1EfIyouzFbzS2Pwd8P+VdnYRSZ84P6mct8HrjuAjqdluwXEDYg7hbvz+jTKEl4d6zgd+gPlRrgdcYx56JUs3N898RvOMngSuAxcaxk1EdeQdKVfseDm+t3T9QRPXH4Nn0u/+49YxHAreiPpirgSOt+rOf+Vrj+MWO428Mjh2Asrn1o8wUH3SU9TzgBsf+k4P/7kH1td+gNAl9/FGU/Xy3sXwna30B70Y5UHcBjwAfMo59Pmgze4LfC4Fe4/iXg3e0G2XHXuaof7tOlqTxWtwigouOewRftu2oF/lok4szKSFUYOyPpVLZ2mhj3GFcD90TQrwoEKWnAl9BxXytaW6p2mijjfGKliBEIcQfRHhYkV7OF0LcJoS4MuavZ6BsOBtRDoez5DgWeQN7iqse8nA85AIhxMdjyjjuhhK20YaNllaZhRAfBI4BZkgpT087v4022mijHrSEhOhC4BF7IfD9tHPbaKONNvJAlvCYZuFrqADp2Ch+IcQ5qNAfinB01qjrVkYHKnq1GPx2FkCUN4KlC+V37oQhz2OYTkYoBr+d4e2RThgUKkBoCDXidwQVLDSKCgLRv2VI4wTTeadhjxDzVMk7gtVCsBSDbbP8XRKvc9guJcXyorYLw1KVcwQ1Sl4vZnntZ7CLKoKlQKVsnrWuyxuUb8QrhEpTKVURX2/7ReRQQdXlcFCvI1S2fYJC6R1mYXX9SmPdrnsXuoyKLKpC6/J3GrvNU4qjeEUfD7UU8CkG60VGyvvL66USwjfqXBfdeAe37mKrlHLvmEKOe7QkIQo1GmWLlPJWkZDSR0p5IcpNzwIh5DljU7yGQfPHQlTE+X7AomlQnI8Kb94XNX5gKbAc5FNg7ey9WMNSNjKfdSxmEwtYx2I2Br/rNyyGh7pV0MN9qMCGzSh/vA6gGAx+fVAd0kdFhQygojgG9EHCTaaIirSaDUxRYe5zULmMZgXr+wQPswhYAh3L9rDPvI0sZh2LWccCNjKfjSxgEwtQ+xcMbWTqWqksw1uCpQ8VzNKHCtCwl4DwpQ++Qe7FLtSo86lBMXuDbb2ufxer+t08dxobWcAW5rKRBeVFbz/OPNb1LWZkzQzlvlsT1Kn+1fXLzmDjsWB9Z1CHI8YvRr1qjO
DGElTL0L9TKvVr1vGiynbHPnuYM+9x5tBHL33MYSu9wfo8HmcujzOPLczjcbV/xw6Km1Djk7YYy+NB3W8BcR2PxRRwQqAlCRF4OvBiIcRpKFlohhDix1LK1zW5XA1HkUoXCXUVu98YKFgHC2FxLx2Ra48QFs9cJ3vBcc992AUPCl6JAuHFo0QBP9j2KfijSjIpoaSvQevXXAYJkeHAEPg+eEGxvAJK6tHSpb1u/eoymNDbmerVs1fiulgRN/nF7dfHYu7pWduourZRMOo6chl7lx+zPoHRkjZEKeV5UspFUsolqCDwv0wGMjQR1yXKcPRNz9oZafR2xzHhmys+0d5gkiSO4ykI7l3wVGfsZCiWFLtMovNxE6BBhJTCZDjiq19fS4tD1mPZS8wj6LLp9fJ+zzh5zEQKk2CLYZXfPMXiYa9cx1ES9KwPExD+SMCkIUKNVpUQ28gJHV4pOrZQv3Wb94AwFdsk6DrHgNmJTEHSWDyvZEhcfqhDdjGMR6ki0Wky9HFLh8EiBytkODBkPaoPRVsSNNfNcpdUmTyTIKh8aELSYxoZZSYS18lxUmJMd40pS4i4A1SeL0yQBXxV7y7Y9TSB0fKEKKW8FjXUbcLD9dEHwtJgLY0y7S2Xr6kN+rY3w75xjc0mUJm9QCIETYphYiyTlakm26RoSJAmGWoaGSkF5gdfSY/Clnwizx6FSYBVmyGAsIprGkPqQWBpTpMQA+hnsEm+ctwPS4cu88IkQssT4mTHiG90K0OSSUNqB45t6K5eYDoBauzYRksz1eTKPkWMnQyFO6OpMuttgxhHLDK05SovUJuLSeqycT+vVKJQcJNH+Zq2sS2xF8XY/SIXSKrTlGs4iVFGy0lFTY4jSCcmETG2pA1xMiJLt7FhqjimWpcZEXOh2eqT1OUEC6dLkDR+tVNFl9Wzl1IpLKVosjLV5OC49GFgUH00dKkGKoeVjFtSx51OFLu8FmzJSm+HYJNhebuYcFIWuFqE3ieqkhBNW6ENrT57pRgpcZKhLSG2EHQYmRMxjTPSQeOQ+U271GXT2+wZ+2IkmwRSLBTCxG17nAt+KWo7NJ0rhnQ4MKTIbqCkiNDlEx9BSZCxviL9q73MvltCTHVQVdWTXPWW1fDoaCFVe5ljpENXU6rXXDPO0CbEFkaqpzlAZlJ0IUQSdoxc5ASixKj/h5sIzXVP2Shtu1yYEEfDUqGWEi2P88iQkg5tMtR0btJGj7YjmlJhBvODLVmVpXAH0YSfkwaQh0WEqV5mP5DGTftsuFChYzp40zZXwKSSFtuE2CJIiohxdq4UZ0DsTWKhA7I1bJXZJEltR0zp9QkSlMuZUu6wLluf4UyRg4oM/VJYnnVJiKDO8wMnS6hz252/BAV/lEKXX10sYmy91tu9XN5mxzWrkFbjQnAK/uikkADT0CbEFkM9L8Sl5hW8EqNJsYeRTmCPlbOpRgdk26pbKON8BZbUEi2zHyZHl5fTVJ9LSlUeGKqMoRkwHsclVZf32bbDFAKIUy8LXikhJECjljdZhWiZKiEq1d+MQ7Q1CdPBAkRtt3qfxiSQEtuE2AJItB3aaFijNBVO23HiUplBDdtzXMYFp2ATjvVTThXc4R9BGI6WDk0yNInQvr2WsUb8wNMMUWI07lnwCZOEgUSHVdrAlEiJ7PUsSJAOLWLsiFHr45wshSQtxBWuNEHRJsQWQpEoOfo1NsRMarPTfuhSmyGqMuPYb90nQYrShONZHVTYJGgJqwNDFbuhSYYu6dCkG98kQ3M9BqYab0pX5Xr1HBKxszdV28VsKVGPcHec5vDim/tNk0TYW26HO6XYDycR2oTYAkjUvEzYjdRorK7GXjtcKrOpJtsk6Yd+nDA6qQnbyxyRDC21eWAwngxtd09EanQNSbPuo7TheOkKKp7y8s0yvby8uprDueL4tR0/tp3WtiOGTo9zOE0CcmwTYgvB7FsRFTqmMWYmwCQ7YnklTmW2CbCWqEmzKAkhLQ6p0JQY/V
IyGbqkRAhiEe2OntE+ZsciGgdyQFa12YtuxpCh/vWI8y6HTRWRccx63cQkIENoB2Y3HS5qqbaf1SwVlhu5OWTPhJ3MIY52QheLV5cTs68Enl3bC2xJdXYeHhcZ2vt0gHaomDEqtPDD0lRsSJPtIMrk5c36ZqtoAbESotuBZZoqzHbjHMdsq8+TAG1CbAEUcdsPE5Fi/6r+/zbxmepynFMlY7ky9O/IKBV7GVSxhGYp7djDOLlWF0m61OWMsNXN6lCfRB2LFAkxbhyzWX5nuNMk8CbHoa0ytyCyqss2omE3WTuufV5i8ErCObWhrNr5JffQukBS9EtRmna5fEbK13XItHbYjUM1t4nElfmmfIPMNsQkZFWbi+H7Ev+bNI4ZHHkfs6jIk0BSbEuITUaKM7ZuZM/dZ1INRBXTpP9YSHiYcEKHhOBgh3NlYDCZqu3S2GTo2yRbD+JGq+SGTGJ1zK8eEeQOsbHfgXMcs962P1ATHG0JscWQ+kIa+pW2qSUuFhHCMYjGeXEPELPfKcU6yBA/cIwQpWyzdKnBQTHSp0kEtnRoS1N1DZWsCQmJIjJ6mSGa8Db0HPbpk4QAbbQlxCYizmaYmuAhwQbm9IbGoXxKEvHZlro4pKt8Olg4EhRcVplxEqGtMsc9RpxLqHzcln5cyOhtzh9x7rWET2SshFixdboIPRKH6BrH7EJbZW5jLFCTU8VC/R3VlrvMQO2kEJza4SxzHCn67jtG1GJHycpP4nKq1OBgyQfVKmfZ4g/jJEQ7BCcyntlFdq5A7QmOtsrcRMR9/6t5KVUTobNhp4lN5rGU8bZ24VOFnARSNIhRJtzSldRBb9uh5BEp25YafUJJYl3hNwVKamqGVuk9EVIMh9YkxSGWMwxB1JwwCdGyEqIQolsI8U8hxB1CiLuFEP/e7DI1CqZTpWjsS4TBI/mQost+6HKspPSUOG41QkE0bI9nxOxlSCjlbDUJsFVku0gjZtnsUSsZCCA2j2BuSHtCL/QTWreH7RmjaUx7qB2HGErsoBH3iJPArtiyhIgarHWylPJw4Ajg+UKI45tbpPyQpB7XrjZX8VkPnRqnBtsKaT6hN7ZxPwRXWExCR4zzg8fSuEMiDN3bD5LEWtJhY4mwRsQklIiz1WpEvfvBSlLdTBK0itAfgZRSoqZPhwp/yPh/jE/kEYidD0zyM+Usl8hne5izEWN0fG3MQyWEx5gNNo6+484zJ7B3EeHYI8X8EDpPI2VuZseY8XCiirBUHkrsoOFb602rn7FHK0uICCEKQojbgS3AVVLKm5pcpNyRWcbSfaKhDdO8uCuoxTyWoSAxEowJM7RF2CSVEUlqsnkOBLbIOMINtgv+aHkoYX4hNlk/fTXoDZadtuCFSc+Ea1/5OzWJiC8OLSshAkgpS8ARQohZwK+FEIdIKe/Sx4UQ5wDnAMxsThEbD4kyHmwGdgD3AncA+wLLofMAn6m9/Uybs4fiwhGYXuuN4vy0trqcmiUifJpjOJn+jXWo2NuOThpnqrQzDXpU5Fk9prno+mMCxo26bBCj7QRyOYY880OURSqcBGTZ0oSoIaXcLoT4K/B84C5j/4XAhQALhJgw6rTuyEXUfCA8GRzoBPYBpqDI8Z/QfcMIC/u2sPCJJ3jWhpvo36uHdQcv4sGD9ueOgw9l1aJjuaNwHJvlPjgzWgPuSD4/Zn/9CId/pCVRCG/a/Vb/umyIplAd6st22I1FAM5kqXWhSCWvt952DTA0j1eBiIRYCa2JZhZKqOuk0JsW/CY0Ai1LiEKIvYGRgAx7gOcCX2pyscYc3iiwANgv+F0KPAVYrtZ3Lu1hHYvVMrqI/rVTEffC1Pv6Ofj2+znlf6/lwPsepDg8wn17HcTd3U9ldcdR3Dp4NHeMHM5gWX4yyc9FO3HbdoHj99kZWGqVvFwEGHeeidB/4m6tSXHMVeUqrhVnigjZEN3SoLkvktgBJg3xxaFlCRGYD/xICFFA2Tovk1Je2eQyjTlGitC5BS
UZpr2tDsH2JbNYt2QxG18wn3UsZk3fUkYemkHvbVs5+B/3csh9d3HUutW8ZdsPOGjoPh5iGbeygpvZn1Xsx7+YE/QN21niZyhAAiIjziojaiKhHzEhO0nJyey/2SbXsuLvW1MJmEvktq3CDmbWbEfojbXtefGT0Ts95pNAFc6KliVEKeWdwJHNLkez4RehsxtlN5wFLM72v5L1avt65nDjvs/kxtFnKtV7KnSuG+LQoTs5mms5lr/zLn7CfjzObezLKpZwEwtZxT5sYKouDW6JJ0ZmcwRll/MeUpsU5vJpxzlV7KBsp9rs2Fd13gaXDS5XaAOKR3mSer3bXLqB7hHn9KN2ZnK9zyuV3El545YJjpYlxDYCCJTHaB5wI8q58rocruvDsOjiVo7hVvblQk4FtjGDtTyNeziOe3gDN/FtHmOYDlaxmFUs5ib251YOpt+8kAuO4GEzJVU0VVkCCxWiu7KEiY8QnQbLOUdNXMhJFowJSWgytNJ/ORc106JHiU6G6WKILobL65oM1b4hNY7ZtKGa667fCY42ITYRpuE/FbOBFcB64LPAmcBHkv9iS4lZSrSTqVzDQVzDfihHwDBLeYLjWcfxrONMruRQvse9LOIvHM01PJ0beA79zK5cJuW2TnU5DgYZmqOrcfy6Mt1oUgzJsK6wm1qIrWYyrNZJpckwnRA7vBKdhSGHZFhZ72K4rFJ3Do6W57tmiMQJvtoSYhtjjkSS7ASeDbwIuBo4Bbq+OAKvl9kjSmMbdVwirRKPMptHmc3POByYThcex/AEJ/MQ53EJv+B8VnME1/A8rhleyU2jx+FTDKvMjnAQjYh06KGI0LSLFazZCK1SJnmZQ0FEcXZK9aj1IVUnrxVaQjTsh91ECbEbOruHAulwqGwvrEiJQ2UpUf96NgFqYnRIi0njyScK2oQ4HmDb4uYAnwB2Qs8FI6z44iOID3ew4XULoKueG8XRjBndB0MU+RsH8TeO5bNMZwrdPIO1rGQV/73zfSy//UFu3PAMrulbyTXeSu58ymFIOqyA4Yzs47BDxpXO9TRO37kr7Kaezu4SW6v6cxoMCTFOOgwI0vOUqqxV5s6ymhxOFlsISFMMUiFCV9o149dvq8xtNApZ1GUPKCaddDRs/1sPT17by4ILHucNn7iM6997Ar95+4uUA6bmkmU9T53bzxT+zCn8mZfBHJi9Tx8ndV/Lyh3XcM6PLmT2hdv46wnP4e+nH8N9zz8Q9pfxIZFVljKJh5zunxpU5JKht+v1Ud9h2GwYDAkxhgjV+gid3cMU8OlkODQOW6vMWkrsZKjiUNESoUmKppQYLCNtCbGNsYTpGc0cvSYEu58zlXufcxD9d07lgK88zFf2P5/fv+lUvv7+d7KuJ6Nb2lkaSB4x7MY2r5dfLXo5v1ryclgCi6au4+TNf+EF/7iS8/7jKwhP8tDKpWxbOQt5slQBVnWUMA6m/VBvg1L9RIYL+C5vDgZBZpYs6wlsN+2HMRJimRRV2jJtI9TOlE7DsWJKiZ2DI2Fp0GU/NPY5HVITDG1CHE+IEUp0B+07bDaXXXI4g+u6OfFr/+R3h72MK573Qi540XncxaEZb5JF/0vpGVarWr/3Yi454Q1c/oFTWdy5lhPuX8XJ11zL0b+8nf3es57S/A5GV0oVZLXAuo6lMrsabFanSuicKtW/UlzlmwWI2878RxdM+6FDQrQcKhXnSeXXJSV6BBnKh6g4VWwHiiUlDgxlfa7xi5ZO7jBZkdhNYj5hZofdtng23/7Pt3HSI3/kvgMO4E/vP5W//sdJvPLun1MsDeda1qzlK0MI1h+0iOvf9Qyu+NXzufKJlTzyo0WMLuyAi4GXAp9COY02UCYu4UWlZjMmsZrEZLXawhJJ0X2n2m4UgikhikRCLHhaOvRDnmSTGE37ohcjCbqkRHMK2ImMNiGOB9j9MGO/3DVrBl9+3wfZ7zeP8c1T3s25q7/LY1fsx2ce+wSL/HUp/3Y1/+ypvpIQGj
JWEPQf08Oej3XBz4DLgbNQOu2tKGJ8CNhT4VoXMbpgS4W5any2DbFhYSkZYhCD/QXPD4XVKFthWGU2JUfhkgrtfYZDpU2IbYwZ4pwCnu53ddjwfa/IL592Jitf/xdWPucaZvo7uGPj4fx2+EW8mMsdeQlrGSlsoRZjTAkVWnQwsBIlLZ4M7A2Mwgzg8GBZQrxDPY7K447ZyOwvcZFgeTtOZq2FVoyhezFkqD3MUZU5KiWGRqjYITbglBJH/EkRhtgmxImIWNXOh3tnruB9+3+dxYvW8auOl/FRvsxajubzfIunsMn6Q1znjRtBnL47M3Rn70aN1JkO24B7gD5Uros3Aa8Bno7iTNft6y2G7VjRVOO8mRNp8msaNPOlB2XrESqmlGgTY6c9QkXbCku4Q28Mh8pkkBDbTpXxhFrflqPv9XdM5WLvTVzsv5EVXM9buJBVfIA7Wcj3OJbfsIzsNnQz+UBtSIxLNC49gsoWfAfQD+yFIsfXolJHPhAsG+oqTStB2w+9WMnQrB+XQ6USalMJxSmPUDFVZZdNcZJJiG1CbAGY4TYjhK1Gwg5Mtt5YJkN/ylu+h4P4EB/kPF7FS/gDb+UavsH/8WMO5fscxT3MtUqrUZ3MUPDqn5vEtB+WgEeAu4HfA3OBA1CDeeYCa4GNKImyWpj1Ghd+E4tcmcOQEBOkQzxZrl877MaeZCo0QsUcrmcTYxCKIyeRDbFNiOMF9bypjGEhwxS5jOO4jOUsZQNv5mb+zI95jJl8n6P4JQezM9f8fimwPgRegcRwmS3BciPK3rgcOAg4ATUqezdB/KGkap3at1Vl36vq/zWcHCBGQowsfjB1qnasuMcza5VZONRil2TIoAq3yThpxLhH24bYZFQlb5kE4RBaqg8LicejTOcTnMx+vJ//4Bm8mPtZy9f4JT/lFdxCT1mhtpi6Bm+rU2pM+QCkBbDvRNkbrwYuRUmRHShbo7cL2IrKRF6PwOpb67kzhtYTPNJCbvAUCXoWCZqZbnRIjlcqVeIPzXjDhIDsAdoSYhtNhqcTHIzJW3JTS4kOruRAruRAZjHAS7iPt3ATF/ILfsfRXMrp/InT3Z0lgSASp/ZM4XXP+jUT8rsgUWqznp5m0dRgSORa4EGUy/ppqBRrKQiyCNo73esJV8mGhJAbK7lDsXs4oha7pURrhIpjiJ69b8RvS4htNAGGchSlJ90HrSwwmVA1obrpZTs9XMyRPJ9zOYDP8jcO4iP8jE08g+/xdlZyNQU7JYqjFyXOH+0qa06Cb7lOC6ix3k8FTkFN2HUbarqyt0PHhVBYW4oS35jDQ42zyeJhDpOeSYZ2bsTQCJUkldmwH/q0JcQ2GgzTmdLwG8WitiwLTzCDb/N8vs1rWUQHr+QG/oPz2PeJtfzf6Cv42bRX84+FJyC9bN/cWJLMoYXalxjxjXrvRnliXoAaU70JxC2w6Etbmb3vHmadsYv+l0xj3WGL3FXVULHJ/ETiJkMjy01EEnQGZJfcI1RibIramdKWENtoOooeYYmwxhEr2ZGVfcJjP9Yzn6/yHo7lZp4++29sLu7Dhbeew5qLlnDBrz/CUY/eClIGRQ53q8xqcyFcwiwlrboD9wArofQdeGzzXB7+r33p3DHCy156BZ9e+iXe8r4fcfS1t9FRKoGf9iHJgz6yepjNGMTkgOyu0lB4hIrtVda/wTmmujwZJMSWJEQhxGIhxF+FEPcIIe4WQryv2WUaa0Q6fL3kl9g/PeOkrM3e7Q562FvG5xf8Pw459W5OO+P3DHldXPaFV3L/Cw7k4/9+AUvveTT90jGsZ6ZCa7i04gl2PHsGt371cL7z8Jv49hVvYfves/jgB77B7YtP4L++8H6Ou3tVmeiTUSuVZPUwV0KabFI0pcRyQLYeoZKWITtwqOhWMRkIsVVVZh/4kJRytRBiOnCrEOIqKeU9zS5Yo2EkincjjzeWeo186ObuOYfwieM/xyfe/1mOGbiZN9/4Pb71vA/R39vD3W
cdxBOvmg1PUVJj2V6XULZa8vGFU9vWCCHYeOh8bjr0WL72/96Nd88op37nWn74jTfRvWeQny99FT+b9mruHDmMdBNENQ9huI3iyBBUDGLBjjX0Q1mzy1KiP+pWkWNsiaaEOBlU5pYkRCnlJlDjyKSUu4QQ9wILUZEUExb2y/B0o7elw6a/Nd01YmQGu3xCcMthT+Oxl+7DD778el5w45859dKrOfn46xla2kn/WZ10vHK0ASaAcGnzQAmPNQcu5p/vehafPfkTHHbDnZx17aX85uaXMDjazaXFs7iUs3ig1iSPISQkhsXcVk9oq8uugOwu25nicqgEv7ZDpS0htgCEEEtQmfJuchw7B+UbZObYFit32C8i1suc543qumaGhFuhjqs6rezo4N5nHcTuZ01l9dcP5Zi/rOagSx9i1mf7VRT1M1FODhvW1M1Nhw8IwZ1zD+fOEw/n4/O/wLEP/JNXr/sZf+U5PM4cfsbz+TnHsrYmx5WONUhP+4VXyR4UN+VoJCDbHqHiWGyHSsvUfQPRkjZEDSHENOCXwPullDvt41LKC6WUx0gpj5ky9sXLDZl5qV5SjP2/KQNU2+xT5AbjnvYse9Lr4InnzWHdRfPo2zQV3oYKf3kv8D2UPpBT+sZqpRtTvsoEIfjn1OP4wLSvsZh1fID/4Cms4xbeyt/4OO/hD8xjRxUlcEwsFbN0eBWJMBqYbQVkW6NQ4qRDSlGHSltCbCKEEEUUGf5ESvmrZpenEYiLuS5bjmwVyTxYbyxi4v/Tmn6GrhH3cAFM73IJD7oEnAocigqW/h2wCngMCL52o02a5MgmxZLOD6a/HdY3ZJQC1/FMruMA3sO5nMJVnMW1/Du/ZjULuZQj+RWHso2pCXc1JUQyOVRsG2JEUjQdKnGxh6aE6Nf3qRyPaElCFEII4AfAvVLKrza7PGOJTHw3lvMbxSJj3GACMZqJYksUKud1o0aOzEFlb7gH2Ai9wDRgKmrUyU7j8i6KboXG7ePxR47ijyyni52cxmrO4na+whXcwFIu5Ugu56nsptv6Z/bJ6e2g7CgpGhludMiNa4SKNeWo7VAZaGhNtQZaoc248HTg9cC/hBC3B/s+LqX8ffOKVD/MjDZJx0No5PC93D751oVcDoBa4aFyfO2CJwZgV7DrBOBE4GHgXlTKr9E6b6VRcpTZpJtYpNTnEEV+zWH8msOYxiAv4h7O4ja+xa+4igO4lCP4HSsYDJFhesiNKyg7PIRvuDJlgGuqUT0Xs6FC+yUYsBwqbZW5SZBS3kjdE1W2BkySS6rsomPda5R32cPovJJcLUVxNgDPre+mDo+zSKYfletwJ8oAvhB4FnAGStO+HzXjgJ3LMXe1r6YLVep2N938jKP4GUexF/28lH9xLv/ge/wfV7KCS1nJVSxjpOqgbJsYjSkDtN3QdqgkBGSbEmJbZW6jZpRJzdrW664Z4lz/L0NfqGBt14KGsIJR4iolRC11Sc/4CmYo43ZUbNYNwW0OAI4AXgysQ0mPDU0U6zvWa6jbJ5nCRRzHRRzHXHbxCu7gPH7Lj7iQX/MSfsrZXFd4thoGqeu0nNxB5UFUY5WHI8Sot7uGhpMTwsYEZA+gPkLtbDdt1AWbCOMqOilguJwctl5JsVZHTCJiwmxi7qE9oXFwqqIZzZSgVOlbg6UTWIbK3fAsVIfeQsVhXe/8wqUGTlK/hel8i2fwLd7IYqbwSm7ja8PvZ9qDu/mB/xYu3ueNbPIWGB+ccB7EcAyiMemUDsjWThV7QvqUgOxWI0MhxGLgElSOIglcKKX8b+scAfw3cBqqGbxRSrk66botHXYzXmFMK15ut8WY/a7/Gab0MOohNNd/c5MUraexCTgUelND1uwqyzmM8sNcCXwfFcDqoUhyKdA9gtK5U67rshnmmXMyHqoFrGM//pOPcMRet3PW/pey39Bj3PWzQ/jN/57B6bdfoRxSnjlWOawum4HZ5YBs10x74zMg20eNZlsBHA+8Swixwj
rnBag8wctR8crfTrtomxAbgB6iRGhrkfYQvZq4LuVPBXwKXkyv982VHMNvE8jQhVTJsM5QG4mSDu9FhTg+CgxLYE+w40ZgNfBEuLxmuepxpNQG65NZFNwy82m8/Zjvsu+5a7n80DP4+K+/wGMv3Y/PfeWTLH10TUK4jdoWNvklhNvoDNmtHGojpdykpT0p5S7UK15onXYGcIlUWAXMEkIkDiFqE2ID4CJCW0I0STPxQrbanIE56523pGYkSYaO+VQ00YyN1KUwAgx1ALOBo4D9UZmzvw28G/gydFwHjISTNkSkRZf9MHKnWmF9Mo12sGfKNH74jDdz4pf+wanf/hPTB3ZzxbFn8rXnfYSnX/YPuoaGHBlvfLdEGBOQrWfYsy3EYzh5RFVIGM22EGVO1lhPlDRDaNsQGwBzDro0G2LasRBMUkzhEGd+wToM/27YVG8d0r+WyhwHkaNkmBnzUQliFwXbj4L3SVj+2vXMem4/3acOs3PlDNYtWZx+rdzqNVsM4t3LDuFjp3yOb3/trbzq17/gzO/8hiXvXst9Zy9j09vnUFhmzcGsnSlJDhU/OsNenkR4oBByT4bzNqhZHwaNXRdKKS+0z0sbzVYt2oSYM2wbYKw9MAWRkJsYVC1dNVIHShF5NRk2TYJNggCeAjwfhi+AdbsWMPyHHhb+eTNHfvx/6Z82hVtWHsnVK0/m/455GbsjM0EHqLt+TZ3CC++KiUEsdRe49tXP4oFXL+OIh27n5Auv47kn3sDAkV0MvLOTwil+xYFixiI6fu2E53kTxCBwfobz3gmDUspjks7JMJptA2B+yRaREnjQJsScYTXlTDMWm8pVhECrCF+xiSbJqwsEnddUjqoxoScoUU6VOSNT2EVulDad5lDZp8DGN81j3ZsWs04uYvQuj/nXPM7pP/4DnzrnC6yZt5Rrlq/kmt6VXC+exW6m51g447OaGIMoI4lhty+bxaoLnsb6z+zDIf93L4u+tJmud0s4EziJVIeKn9Bk8iCLAmpGxHqRcTTbb4F3CyEuBY4DdgSZtGLRJsScYc4G54pBTELkZaSE20jH28ssfUUIoVbRxhRXrN3m4YTyJZa5kcKk6UNK6glCsOHQBaw69Dgeff8SNvbN59DfPszKP17Dh2/8Cj9f9yrumHY41xRWck1pJas4PgjxqdWOaHxWEwnRdySGVYvohidfP5Pi6wdZcMMTFP8beB1qiM9xKO+BS2W2SuFarwd5ESIxo9lQBhCklN9BTdd9GipWvx94U9pF24SYM8wKTZsqUwdoV6VSa/thomoazioDZHAC5ISEXlRJURVDin5oR+U3o4RcM2Lqw3ak6HmZS57HqhUnsGrKCXz+sP9Hz0P9PP2Ov7HysWv4ytCHOYj7+DtP4xqO5BpWcDvTGa2KHLNKiKVIYlg7BtGjhPdU4NPAy4GfAF8F9gMORrnhtZRIxVTj0nDysCV2QC6ydJbRbFJKCbyrmuu2CbEBGCE0V5qTJNOyCZbT5Reobaa9hsN+qoS2Gdi6TNQUj9hI+JRJuGRNjOWn6O0D3hSu7n0uVw88F4Zh1u4nOYkrWMkf+TG/YC7buJblXMP+XMP+PMDeJPflbFMHdHimVBif3KEccjMTeAlKQvwr8CfUiJd54eIUPaAU/qDn5VjJUUJsCFqum0009FjbWSrci6yko2CoTIm2w0RVudo4RMvi6VKTY6TEynaDjVZZ4Hhk21llkmLJTy7YdvbiN7yQ33AMsIUFPMDJrGIld3MeVwHwF5ZzTbBsDKU3NoO1khPDFjx3DkQ7MDtiMyygQo7mo3y5D6IyY0xHSYxEP+Z5CekF8pEQG4U2ITYQWlKMU6NHgmM+GdVn621lHkHmCrcJkUCDxiN49qab/MqkmMEckDtKxm8GgTUxGUXMt2QjvfyYE/gxRwAjLGcrK3mQF3M3/8XlPMHUMjn+lSPYnjkPYjTe0BWYXX42M9RmEKW/7ouSHNcDG0CUoNNwbkN+0iHBLdsS4i
SGJr0s8Yix0ESh1zN+rn0KqdJMPoiREo19BafKnCCNjoeWmfqhcUHwIHvzIHvzHU5EMMoRbGQlD3ION3ExP+cB9uNvPJObeA6rOJ5HC0vBs6TFboK0X+bk9GZyByMG0eVZDhdJBap3AFuhc4dKRTlA/mpzoQNm2GqTC1mCFRuA8dDsJgQs5TL9/BiVM/ketXqYa4E5FidSEPd6AJMIE0mxytLULeOmjm12PEydxZd0cBuLuI1FfIXnUmQGxzHE8WziZfyKC574KJ19w9z0yHGsWn48q4aP5+a9n8YuZgQe5iT7oa+yZNteZNOzbn9sp0BpFPydKixzKrC7vkcMoaMDepIShWu0CXFywOYKU11OPBkSY/JSiSX2cI6qsl1eM9zGCA8Z72hktpsReriRQ7mR5eDNgCWwcNF6juu8ieP7V/Gpy/+do76+mjWLl3D7Mw7hgROWsfH4fSitEI6pSINpR21zgEvDML36AoYKsKUEs4BjUSnW6h4GAkoKnZbhvC153Kx6tAkxZyTRUpqUGGtHNBtwXhPW+/ZOM8lDDggxf3Qyd5fzR3pByrMWQTir4FgVTLeSiod5w7RF/GrJIn616OWwBLx9Rjh0922csvYqnn7j33nbV25n7w19PH743uw4ehr+0QU6jx7CO8in4GqQceYXbTsMjgtUMoyZwHOAP+TxeAVInEqmyWih5jd50OhK1x157OCFV2MeUDtPNBGObRkdqEHdjR0qmbsZIj4G0e8uctuyw9j0wjn8+R0rWcBG9t/+EEfedgcrbr2PxX9cT+/nt9O5cRixDDVgbS6K2aaipLQ4aT74NUNv1qOkw+fk8YgC6MrjQo1ByxKiEOIi4HRgi5TykGaXp1poD7MLVQdiu9YdcBFMyS80IKlDhsLoU4LTCgX3OGYtecUSTRNaaMkb6w+KiewxiKYDZXhWF+ueswieI9nOdOazkX12Ps7cm3bCP4G7gH+hcr/0oWJfZqBmNDRHrQyBEIBUpdBxgx7ZNN1UdNCWEGvExcA3UVlxxw3skcEm+fmEvXZx8DACs82d9i6DW1wqXayal0iM1bgnEp4iROR+oo1zLNN/tT4cMYjmbiMGMS4HokbHjFE4GpXe4BnANhQZbkHNr7AO2Iyai2EAFYozDPgwfVRx5vxg904IIijrRJsQa4OU8vogz9m4hCvc2aYPr4Az9q18Xh7zp6RC2w5zgKucHmXmzsuj3CzE5kPMDdp+WGsMYkUKjzivzHfTBewT3GYesAPFeHuC9T0wsB129Sv+3IXi0m15PGLbhtg4CCHOQaUGD8X6NxOaXnqsdRdCUqDRfsckEWejuMlhpHdh7JwUY4i66tRwpmQgRHvqUVtSBEXgfgGKcePCzVdQCP8WgyLUm8YugraE2DgECSMvBFggRNSV2UTogOwea91GT7fKUFz0VGJOSkpy9KwGWi0ihJOqJutfj3SJsWitJ6vO9gRT5SQJac3P9Kw3Uqt2SOm2Z3ks51KpuHsdC6DTfkH80MeQDdR5Dcc+q81pDSZXksgadtMktKcQyBn2zCQj1jFXYIvnqaXoqUYYsR/mibHSWq3O1nSPchxSHE5a9qpsZ58dsHrEeJi7g8PlfX4kD6KpJptTM5RMsvOsXxOO0JuiV9mdJa9nJmiVOW1JgRDiIiHEFiHEXTHHTxJC7BBC3B4sn8xSvHEtIbYqTJUZElPOARXpUI9O8TxjCtIMMNNU2VLNqMvL7OzAOTQFl/QRKuf4am5xUuGoDszO9eNiqstetA6tbR2EDdHMQZoefQoqc483mqh+u+6l40FN+b/FVOaLSXe63iClPL2ai7ashCiE+BnwD+BAIcR6IcRbml2maqClQJfLohwC7VOeI7ioJcTgNwt32AMmKiEs1RCPfW6dzT7l1vbEUnmoomklrifcPPPse7kgJQYRvR51npiyYqSsLunQlBLNd2aslz/QZGqO2aAJsU4JUUp5PTn5eUy07CdbSvnqZpehVpjqsm1H9Ix9fq
kiHWo1WavNWbO+5KOKuqzsNV5Cr8dIiRAlbN2Zy4gLSYzxyjcarvmZy8gtxlNLh+lzqeg8iHHwMVXmQlhCjCNFs70Z7822FteN7BLiHCHELca2c5KpFJwghLgD2Ah8WEp5d9ofWpYQxztsO6LdmPS+kKqsh0zl+kl2oJHXt689viNtkskwdxiVl+T8cMD0NettAL9QQHojKv9r3LVsUtSeZofaXDeyh91sTZtkKgWrgf2klLuFEKcBv0FNWJ+IllWZxzNMFc1OMmIu5XP8wMNswtX4qyCXSCdO/W+1CcrMLuL4j1MyrIFYmvDJ9gtRtXNsSFHXqXATYLCvYHnto2aIiumkLI2b6rEZfhOnHJieZnL8huakMqdBSrlTSrk7WP89UBRCzEn7X1tCbCDMGERbZdZLUedG9QIVuouqyTDOdlSikCGLrEspqma0ihdedZR91C+UO6Y/5iRTP3xH3eYPy3URZ0OMKZtdRrNN+DpywbYZ2nGJ5j6IzK+SC8YoDlEIsQ/wuJRSCiGODe7cl/a/NiE2CLYdUVuIXFKiRyUWUfqOjC8ZJcNMjopGq+P6HgloOSJMmELAJJzGJc0IfSLduwmvuz6CtmNF7yt5hqPOpYpb1y5vU0nyAGNuQ0xE4HQ9CWVrXA98iqCIwYx7ZwLvEEL4qIGJZwWTTiWiTYgNRNYZSrRzJfRHDZcTIWi8Ja8Kcolp8ArVms7HoNlUcYtcJz+IcQLVHjJUjQnCcqi4LmM1qErmQ8/YF8qIqEJvCqNup4p9L6OduEJv6kaBXAKz05yuUspvosJyqkKbEBsEU122pcGI2kzgXAkauwdqpjSsCzQUdq/IQjPZbIcutIyUWIWXOJYY6343Dvttgrpc8sOEVylGhRgjnnvzmiYZFqxj5rkEDr+hBtgQWxRtQmwgXASYpDZDxblSrJEEU9XmxDdeZ3NIUPVLpQKlQiFCKHkRYwOmyIqHGewe2l/PRXW9FONfQ9AmRi27sK0mm0MjfTP0xnamxIXeWJJjrmPr24Q4eWETny0l2mObfaNDFc0Z0lKgG35l20EysfaiON2sDtjqf8aU+35hjBJb1ICy+llKsc/WjIQYRIiqyqUCfsG2F0bDblSCByP0xlaZTXIk+pt36I3sAL9NiJMTSSqyfbxo2ApHfOjpAmE3nJSgZDcRlsBLa8p2Myii7NBjg5Yd5xwgGsJUa7fR+oEN3Soc9kNn3FZBJf4t6PLZ460N+2Gg5pc9zfbisiVapJhnQLzsgKGuLNF+o/ncsEq0CTEHxEXw2clhk8yCIxBqdH7JUJvjEg94HbHGfnt8azrqlAHinELjHHZGb5WBXGSU3mupTyPsxoTjfiYRmili7e2Qpxncc6rYpKi5Oec4xFHRwUDXlAxn7s7hbtWjTYh1oEhF7U0iQ/O4qTbrBBCuzDjT/ZTuFCNUpdrkIoZ6+y45kmJZsvEo+YVywLNGallzEhwboobbZJjb0L1gNS4jiE+5Pktd0fHLfhwpmp7mOOkwhhTN0Jt6MUoH/bQJccKhJ1hcaZFcYzhM0nMFa9vw7Y6QoUG6VM8Or8RojLcyHmknVkExEalmLIfC5YtyuW0VNhPi1GV9TMNzk22CthAfdmPsM8c0Q9iWaL4OiwzzDr1RhJhlpvrmoE2INaAHNfFOD9kkQ5/K2A/XVKNmO/esY1mI0O4QmaEL50zwlGW0SkwHr0daalCLTOzMsY+RQNy5mwX0pzVoLTYBhqREge8XEogwLCX6pmPFdqYk2RK12uwBQ/k85SgdDGSSEJuDNiHWgNm4pUOX8mmSoI04ASPSefPM8BIrKdq9oUZUOapmvCE0SX2ihFhLnhg/9BO5h0GSJcOxY4fdgOEVN+2M5hA+07NskqI9vM+QEHOxIWZWmZuDNiFWiRlUpEOIT3GgSbAshBEmRlNd1qTpGcciyR4gVl3KDC9mPYIclCNbVfYLlLqiN211D3MsbNLKDUFrcNkOI0Vw2w9D6b8M9Tl2CJ8pHU
LYy6xJNCe0VeYJhNmoScpmE+80sfcl9RU9yDJWlihZJ9eD2Dcdx5KaqquArd6VJRpR3XVyREMdKi4JseYeZeRIyiQhmqNSKr+hbNnWEipjnOpse6L1Xwrh0LBa0VaZJwCKwFwqZDgj2J9EF3Y/saVD139NqTL2QuZczF5G6cp+y6Z0ULYh1poo3k/czIpIh3XdqTTGI1KyItehlYbanBSnFSAuy1HUtmh5miFMjEk2xByF+LbKPM4xg4pkqEnRDKweIRrCbDpR4mCTopMMzQvmhUxvPIdmMYFiEU34mW2I1cCemsy8IU5iHPWDUTOFaLYbc93M1lP2NEMqCeZlUrbRJsQaIYR4PvDfqNf2fSnlF8e6DFoiLJNhAWbPrIQijAwpW9/AoJJe+nGToa0625qlSybzM6gnBb9UXaxeqg3RVt6rHK3i8ohaGK/OlEiQu0t6cz53tRK3oTZnkBCj+RDN4XteSH2O9TSnjGUWge3Ry01lrt+GKIS4CDgd2CKlPMRxXKD44zRU13yjlHJ12nVbkhCFEAXgW8BzgfXAzUKI30op7xmL+/egSFAT4lxg3lSYMhuYGZzkK5tKcRCmDCly9HbDQCmeDD3CFBNK6oDjZSQZ1attnJkdKjl6mQ2iMCUrO9dgJjRyLpVqr51kQ6xZYtQ6g6UyxxFjMHyvVAiPXzadK6ZTRR33op5miNoPYTxIiBeTPOveC1BTBiwHjgO+HfwmoiUJETgWeEhK+QiAEOJS4Ayg4YQ4A1iIQyrsRbGjHl9cAgaD391Q3Kn+6+9Qh82oMgiToVnpPgmqsj7B5azIiIJXShkVqh0eSQFESYgz+FVxibTLZiCsanJ8R65fL+pWm3UrSFCZI0s0/tCWEm1SdCaLtW2IrtAbVCxiMYdYxLwIUUp5vRBiScIpZwCXBElhVwkhZgkh5kspNyVdt1UJcSGwzthej4PdhRDnAOdARXCrBzOA/YAlBNKhKRVqQtSThvsoQtwTbJdUp5xuqNAamhjNJWncQmL/qqXzxX35PdfONJhkOVLZFxdEbMCVoqplkEC6iWPCE99HLd0rQWW2oIfvAU4p0eVpdjpWXIutRucEiciqMtc7656LQxYC45IQMyGooAsBFgiRmh48CZoMlwMLC9DbiyJCmxB1jVlkqDHFp5LGqxSWXGytBMLjmu390q/IbxH4UPBHKXRVoe9V9barIUdj7E0N8Xn1Jl1N+rjUA4/odJ+eNqRlsO9VBy0lAkjKbz4m7EaPVlG73fZDfcyMTwwN4XNJh654RJRUmd/QvUwSYr2z7tWEViXEDcBiY3tRsK8hMMlwSRfMmEuYBM11UA1yZ7BeokKIQWOd7sOuPZTnXMboQ1ptjnNZaOkxtEPfx95XKxIdKjU2+ziJJtHBEi5IKadO1ygkOlZc24luMxMj0fUawm7s5A7lgGxDnY44VuJ+Y4ixXoyhl7kmDmlVQrwZWC6EWIp6iLOA1zTiRj1U1OQlXTBjAcp4qElwBmFiDGyG5b5hE5Wv0v/3lCpzLvs+MFQJwja1FXu6UmfXcd2rWrjCKVLfftWWOYUqHA21z1XSIsgsHZr6qOsPZuiND75jPDPhU0q+9iCbgdjhoGwIj3EG4nMjmkV1reeAvLzMGfBb4N2B/+E4YEea/RBalBCllL4Q4t3An1Dfp4uklHfnfZ8iARECSwoGGc5FkaCtMk9FqcoQkQqZSqixFkvgmbbEoZzamdE5CrnpifXKZb5ztSF6bJUwswtlQkyZM+eWrPmZzT9a6T/iPM4BouOYw04VfU5kThhbRbZDbox1r6CSFteb5CEvCTHDrHu/R4XcPIQKu3lTluu2JCFCeXLp3zfyHloyXAT0ajK0JUSTFLVKodP7Q7ihalIM+o4gSPIK5azDHhV1Oc325ZcSqCr4o1aSqkLsW49LdZsRSZKgr0M/GjWVZ6uhljE11vC9OEdVWUIMh9QkJogNOVYMCRHcarOG5WnOgxCH6KzvImSadU8C76r2ui
1LiI3GQhQRLgTmLUC5lRcEi0tVnkpYdXVJhhBuuCWU4yVoRKbdUJNhXLcp7zclUXufAdv4n4q63rxLz7KQ6nHOp+npejTND3p/XojUrV30up0quiU4Ki3WqVI5xR1/6IUcKuZ5zmSxtnRoSY+iO5C091AX2mOZWxCzUZLhfsAirSLPp6Iuz0MR4kwqZDgVRWxxJOgS0gxPdNFTQdsQpZLUvmR3iBh4WSTFhtiHDCqK+zg0AfWQopa6E+s0F++y66LRVdd9zOF7YMyyZ4h4JjGWCdIcwmcL664IrGA9jxQd7aF7LYYZKDJcCCyZTUUqNBetKs9ATardFSxmQGEXFakxjhC7gQKhwfGmlKjVZy0fVGvr8kpV2LZSYTeFKnP4ZTglbdhexCaa8mhZwm1cTxFJZxVTLFfdFvIYv+aEqS8kfGAcEmJ8PkQP23aoh/DBSDTu0JYM9XpXTo8IjCLa6b9aBUUCIgSWTAWhJUKtMuvhKZoQA9KT3VSmyOxGSYpDqA6rpUazn2hVOUDSuGQ74MKUHkf8GFqyWMBU6TIRZOJbz0qEjvPycqiMlWTZkq3fIkPXenlfePienTE7oiobDyw9VOgNJIfeFKgMRsjB9CvbKnPrQJPhwgJM0aqx6UyZH+zTanIXjHTBwLQiBb+EVxpFDFGRGLX6bGPQWA/4yZnwNUBccHYINahnEXIcq7c9BoTmepQag4Si1y65k2bEfmxyNz1UZ0NUhyokaI9SgWiy2PIHPi7sRi/dwb0K5DKNQFtlbhHMo6Iq95rq8VzrNwivGemC4e4Ohro6GWAKnYVhCv5uuruoSImmt1lDW/jrrFk/shLA6JPCr9i4TFtX9Wq0y3CUhAQp0lUfWKn3yTfrje1QSUKk5KZ6mICqnVaJJUiMLXCuuiVELzR8LzxixZ0kNuJptj3KZn3oj36Ow/dGZQf9w22VuamIOFFcdsOADOVM6J+qiHCYLoZQvwDD3UW6hkaUlKglwwJhz5spHVbZ52O7ietAaKx0KbhdChFW7exIoxpNL9WRhe1hLrhIoE7+yeGblH6D3KEfOqMN0THZlJ3QIZIT0bQ3JtkQu6m0MW0WygGjpQ4GdrclxKahByUV7gcsmklFGtSeZb09W5HhrplF+gtTykQ4HMRMFfDpLHTR2TVCt0tNhljp0FSX4yabcl4uTkJw/MEmQ71d8Pwa1MgkO2KCNFlFpymXL/hTNVJt7sP7UnpBrKc5195jO1QkoakXEkwm5igVe7oAe+heeb8OvYFo6I2WDE3kJNDL0Q6G24TYHJgjURZ1QVGH1miJ0HCiyF5FhrsK08tk2M8Uhumkk2EKlBhmmP6p3cBg2c5chuVIyQIXxyVyit0pfMpdoaabhtCAUcR5SHvWf11z2ehxHa7bNERSbFivSfEym9vBYga8uyeYsmyHZugNRpIHDS0d2tAf+3oxKmB369JO65YsByykEoA9w3SiaCI0grA1Ge6iQoj616cQNKPgyzYVIqSY0GCKXnzuaadpCKNLxIqP7t35heEkIaXZxEgzpvoGUbtcOaIl4yPEKfTOUJtslywjTnKtPewmawZy4yviO4bvWQiPVolRjYlOS+oXCuCNhAV+7Sx0FSkvO2IJ2JXDdRqECUuI86gQ4jxNhnONZQHl4OvBmZTJsJ8p9NNTVpd1Bx6we6lJiklfzxhVQ5OeXnpIiEW0+6BzpEpUUiwTTsPf8ghlFS+GqEf9/JwoUNMEB8mk6Kij2Pp0nZ9LHTtE6hQJUQdnq9hCopKgM0ZRFVZ6QbC1aTucatxHnayQlx1xlLpHuzQSE5IQe1Cct5DAbmiPRNFjlWfCyAzYNXUa/UwpE+IwnQyhvMtmpzC+r2q9u0TBH0mdtzYp5Cbz/3SaMbPPlOKIMEGKqeqNZyl4ioUyxxAcPR489rixnioRxozXjb+248YN6T2G2pwmIRoe/HAsYkVqjJuOtOzYcjlR4uo4j2+azhbVopiQhKjV5EVdUDSH4+mgax14PV
PFGO5iWpkMB+ihn4pTpUCJEgU8StH4qQIUundR9GKS9MfUrikdmn7FRCnRvkAtSHzbtSYjyACnE0jt1Bn7ks614RpT4/qbixQ9UEkK6kEm6bCW6VyrtCECGJK3a9iefcxcQvOrdOEmXbuI9WKUNiGOJTQZLiGwG5oTKpsTK8+EPTM7yqqyWqYxwBRDSuyiK4hGdQ7josQUbZy2YdhmtA3Rpg/fsQ+s2TUymqzMcNz6US056hIXE73hSQkdMo3DpvoGa844XSvyk7izwLQhWrudxBiORQR36I2JspSoYxG1dGhrIOYvhEPKakWOEmLazJxCiDcCX6aSGPabUsrvJ11zQhHiDIwMNi67oc5co1XlrjAZ7mJ6mRD76SkryCbMvCKdDFfGhWaEJkEXb8RKh67GGTReTYQtATteLnQs6kyxSVC0yGOYiJgjCo10WlVhQzROM2MR45PDhmMSQ1lvskiHeSEnCbGKmTl/LqV8d9brTihC1E6UhVMJE6Emx5mVZdfM7jIZ7g5+tzOrTIhDdEY6rOlpDh2zbSvW9ogPfilMhrbK7EJoDLRp5M694boYN49BcFFUlfYrT/sjjsZu7qzWjlhTz8mSjsKhNoObGAEdq2hnzIbwaKA4YpTdIMxpMBqN/FTmhszMOWEIsQtjnLI5FG8elfyGARnumdlRJsPtzOJJZrGdWeV92sPcyXC54djOi06Gw9KjXnWNCQ3gIkM7hq4FhaSMsLI8a2R4oJAUZuaatFD0KHdcrQrrU80QnEyN2jzJ+oAVfDWWOTVpRuYgx2q7WfY4RIxYREKnuRwpYYfLUFcnBX+QLjKm9solDpGsqnfarHuZZuYEXi6EeBbwAPABKeU6xzllTBhC7EQRYmScsmEzZCaMzIbtXYoAw8tegcqsnColCnQyzJQyIVZaZSZ7ly0lWodt1bnHcay84VJFg/W4RATQhPlK4tQuxz57uGHe8ZNxT17Ekf7L+f+EseGeMSterjAyZtuV5lCVK4RoZ7lJD70ZpkuFkgXhYxFStCXnvJpSdgkxj1n3rgB+JqUcEkKcC/wIODnpDy1HiEKIVwCfBg4GjpVS3pL8D4VOgszX+1IhxNko6VBLibNh+8xpZQI0yVBLiDoGEQgZo80OMj0tstSs1UHw/bBEaP/q0xOFqRRJKxNJZxI/HTPAVQvbEwrK+G8kM439X52wBfS8EWuvzfVmKTZE1/4ArqF7Lg/zsJnG3yZF1zvK6/nyc6qkzqonpewzNr8PXJB20ZYjROAu4GXAd6v5U1eBaJLXXipS4mzYNrc7pCJvZxZbmVNeNwmxk6FQw9JSQpehRkdg7/aBocCGSDwpgjvsJpaOIlm98jYAJTFTXEKH6tksb6nQlf4rUast4H5noVNi1GZTcmoIGcbYEO2PjR8OzradKZW/V6RFM3ys/HwuSdErn5SvhJjPSJWbSZmZUwgx35hp78XAvWkXbTlClFLeCyBElSpJJ5Xg69nGEkwWtae3o2wzjFt0LKJSlzspGWMhMoe06BoNksfKwKFi2w/THCoaiZPVR25dD8HUKp4ZvbPGS1Q9HrsGmOE3nt3BM/SCMBn66moNHdOcMPonojoXysHZ4B7HrPZ7DAWxtU5MhZI3zBRGK/M2Q76EWMOYfxfiZuYUQnwGuEVK+VvgvUKIF6NqahvwxrTrthwhVgMhxDnAOQD7dlNRjc1ltgqxCdsN9ypLiiZJ6nAbICIdDlBiikGQqfn8Ap4YGIpKhDEyQBkRsnSpTDFwqnQ1EVWN7Fal4JgnEWZR8GNjEjOOwshfGs+AGMnQVptdyWHtMcymquxsK10Aw3QWRqNzN+eBHAOzXTNzSik/aayfB5xXzTWbQohCiKuBfRyHzpdSXp71OoHX6UKAY3qFLEuFMyjbD2Wvshsq4turHHPokg53MT1ChGqESo+z40ZIUdemHvdp2A8hPtzGVptjkUdfHAs3dkZSzJS/MaG8pguiZrjGMPsqNi9TnkmbLDxyrGOf0P
A9Y3eYGFVwNl2Vscwaps1wOGb6T/v5Sl0eJW+IkjdSSWCS11D09tC9KKSUp+R+UY/KkDz9OwOenN0d61E2nSlaQixRcHaA6cFbdA6UNzuFbqxBRm3Tfghu+2EmVEGGmQK1a+q0SWM+UuLmssAMuWmCIJaGVALPpTfZekSAOCnRgNk27cSwvmU7hLAZyBzj7FOgVChAdz8wQpdnzL9SL/KzITYE41plDsEj1m6oCa8yKkVvTwsd6y9NoeQX6OnqD126Mxi+l1VNLi+DlIOr40ixCorJH3VLMlVcwBdKinFUYR4jbUyOqLpRZ8h0k+U/DUGSl7l8XGXOtudlrniYvcjETubz2URa3i6AP3WoYlfM43na2W6qgxDipcA3gL2B3wkhbpdSnpr6xwIV6XB2ZWieqRLbKnJIQtwzi6HBSjK4UpdXblqZ4/n0LHzmrHyE1WS769tqcqbkDub9YqBJxhW0G4+GU3AElbyD1ZFiLSV1jlTJiLBTpUTZqZJ3nF4IxifTJRkav3ai2Iqq3MUww6FniCN7e6hq+TpdHtBPZyEmiUk1aKvM1UFK+Wvg11X/0VCZtd3QlApdpGiS4e7t08EvgFdCy4f9XT3l5A5xWUTK+7Xd0FjkYDS0xpQOTcT5TfxStqlIUzEWtsNWuq8DqUVx9Ibo8M2xSgMWIIuEGMAVczhEJwXHLHdmeI7+b3TqASN0pwtK3hB1fzTb2W7GCAXKKrPLbmgPzyuT4dAsdm+dBdsD2ukuMuIXyqQ4vctt8NAevQKlsN3QWLSH2STFrPbDULOz+2ALkUwEQdmKpWFm79pG74Y+Zu/aRvdjg0x/cCvzC5vZZ3QT80c3M7+0kXmjW5g5uovO0RI8AWwBthrLk6gONAwdo+o15zpGpOV6gOl6s3a7PMzBUgpiEUuFyggVRYaV9GquGEV7Aip9XG+HjhcKwI76H1HWf4lGoeWaQ80IVGbbbmg7TkIEWZrFDk2G24PrdAPTiowApe6hUEp1s4GEoJ0Bg+HF9ytt1h6U5frO1sNz+cbxhUtXYJSF7GQfBumlxGxG6WWYXkr0Msxs+ullgF520rupj9mbttEzOsC2KbPpm9HLtr1mMzCrh44pI3R5Q3R3DNLTMUhPRz89hQGmdPRT7Cipehswln5gJ4oQB6DHh/2BA6hYJ/RfBoPT9bIn2J/nt2NsArMTkOBUGQ1iEXXzLOExTFeZDH3CqrCdAMKdVNazjudBiKNUn+t87DBxCNELj1Puozc0IqUvGJFS2bcX27fOgq3dShLZTThkwkvPoudRoqs0pHqjthkadkQtIUKYGG1UZTe0oBMRQBWkaLGEYJSFbGJ/7mAJ97OEx1jCZpbwBPuxgwXsYgtT2cx0+phGHzODZS73sYg+FtDHUrXs3cu2JbPZuf8MWCqCGb6AJTBzyWaWdK1hCY+ylDXsz8Pl9flspGvdIKwFHgU2otbXApvU9p4+WD+kImx3oqp5FCUxelSCr2cBTwH2AqZRIU2Jit/Hx5nCMlK3ZZtbuMI6vJL6e829p8rYnBT7oVoK5byIigw7jb+HCc6WBtWvTX6VsdF+ORI3r5TZO3O4TmMwoQixb+bMEPm5lj7msJVetvfNYnTzVNiMIkTtnpxGqL26p3ZUqnInw3QOjlREFb0EjTRudEqSFSZfiSZ8tbm7HufAjfezfP2DLN/8IMv3PMgBPMD+PMwOpvEQ83mUWaxhOtfzFC5hBWuYxXpmMEIBRTc9wPTgtzdY19l3eysSUy16bUJIietULQVqSVDXrxeUrghMQQW8zg1+pxH8cReKaDcAK4A5NZQXGuBUMVuHI/RG/0YWr5wX0R5c4Jb2KqNXShToob9MfsN0hrZt8qwPo7QJcQzgeyIyTnk7e9FHb3m88lZ61f4dsxjZPKNip9qOalQ6CtU1DWOAEhXvcwGfQoz9kCrth3VJibZkKCVzNmzlgFsfYtENP+Xwm+7m6LtupWtoiPtmHMSDheU80HEAPy++igdHlv
MQc9lNP/B4sGxDMYbdcOP0w/qnMPVK9av8tu+hiJIKtwMdKEG6A+ieSUV8FMDdwDXAMuAUEK8BniZDpO6chS/3jDAJXwObDM31IPSm5IeH5sXZBM3fKQHxqfM8wtLhQGS7frT2tHsThhBLFCKqsa02b2cvtg/NYnDrXmHDvfZ6WUFsvh/+MurpSLsYCpZhipoALRuiHIx6l/Utkobs9SQcd0JKimt85q7eytxbt/HM1auYd+sTlCjwr6MP4e+HHMdFL3wz73rVt1i7fV/YIGANKnvcMMrgxk69khFmhy3GH3Jt54BcLtmBIkRzju49wENQfCXsP7qBnpf5bD9nFmtXLI7+P3MexFq7WAwppjhWBvb0qDHJmWyCbsdKnIMlH7RV5jFBCS8yGiWiMmsnylYRJsRBUmuihJpoqiuYtL4zGAxVVpctG6I5htnlWDFRVQffCawGrgU2wpT7JfsVN/Pk0f2sP3oht779CG46+hjuXHgo68S+rNuxmMH7ZisSzMFBWA+qntM4Q8WYHx3z1yWzliV2VyhTJ3AQcCYML4ONa+fCTwu8cOXVHHPAHfzp3FNYf+YiIqPfGhqHGMD1kXGQ4+hgZ/mzVuo2pyYN2wTVvig5TgmdFyXTfNDaQ1UmECEWyska7BAbvW/X9ukVJ4q5gKqJbqKNLYAZ3dXJEOW5TDQZmnZEv+JhtkenVA3dfh4Abg229wcOA14HAy8QrF2yiHUsLi9Pshchfc98y85CxMmkdkItXTEpU8FntANmQk4ttKqiCBhe0cm6zy3muk+dSM/lwzzrf/7Giz/xB774hQ/xnZVvCztVxqIXueyG5v5BYHeRUd+jn0C7mQYUEsJoHLZBl73cNbdQ7RilOm0kHhkmmeoCLgGOBvqAV0kp1yRdc0IRYlJqr+07ZjGydUbFZridyrpJho7rAmWpUKvKXQzTNTQcDrmxYhAhbDfMQoodKPv+YpQjwNuFusc+wOHAocB+wFLgKSDnZ/RepKqyOY9ScdVlzNC9WJjpp2pErbZZbYeTxQ7uP3M5fzzzufT+ZRtv++gPef2XfsZ7zv0WN3UfH/1jI3pUXKMxybGcll8wyhQGtbmnW400saU8F8HZqrFNovkgHxtixkmm3gI8KaVcJoQ4C/gS8Kqk604oQrTHLIfshtunVwjQth/GmHr0cCgtFXYFszV36qnsB0crRGjZENWw+HT7oUBx3UEonpuHai67Ua6NOTOgOA9l57IEvzYUbLUZY7uWBl6IIaDbTz6cM/95Gsd9bzWXn3cGFzzvo3x12QcBkWM8onlzx9wq9qLbnPlXXwBdDKLacMkvUOo2YxSjo67iyDCNRKtHbjbELJNMnYHKvg/wC+CbQgghpYwNDZ9QhFjJYmOpzNunw9ZilAy3oxpTglc5bI4OR2YVTVXZsiHaWo3ZWbW56kCUoNePmi3nHuCfVPJULARFgHl+oJuBoCLibIi1SiBJ0rZJhnnIvjrAWXZ0cOmZr+C6eadw2YdeydPv+BuvOfKnDCU1oqrhGKVib5vLbsc5CPC7GTGmf6UbK6u2m+wiw/aIZuCuHbl5mbNMMlU+J0gouwMVK7aVGEwoQtR5DkOz6fXNYnTr1DAJmovpWbbthp5f9iprqdBUmcsN0rIhSj8aciNR5HcAygT4GCqf+Z9QcXQ6um8GY/BSYpmkAS7hMUIuJc/IywXPZ+38/XjmJ2/gxxe8jh/94w28er+fIelI+Fe1bzXDFAKmdGieZ55PMOrKL6g4xe5CefSVe7pSd9xtfjbETX+CT2eJ+uxOmXWvIZhQhBjxLNt2Q9uZsp2KdBhpSOB5pWBC0ortsKwuM1QhQsuGODBUudwsFAkuQsWA3wFcSXjwUv1RfBlRFWtU0zTq1xeVRzRBlquiP9aqKkdvGZZo7UQPI14nZ7/0Eq763nP53GP/j/O9L+RwVxN+5SduMaVDP/pXta6cLYOGtKgy2ATrMfGK6jL52hGllM/P5UIZJpkyzlkvhPBQ+b
D6SMAEI0Qj5GaPI97QXrT90AUPOruHLduhISEOWSE3xsDakQFl8jsIFdr2MHAjihBbdxRnTrDrM4WEI2pYxsgcO8xmzBF8A4Z6unnJ837DfT8/iIsWv5mHvWXGCbXAoVvoeVVs6XA3oQ+466Ne+Y9SoU27YjXxivlJiLkhdZIp4LfAG4B/AGcCf0myH8IEIsSRIA7xSWbRN9SrMtjY8Yb2EqMqK6+zpLMQlgr1eoFS2KEyiBrgcTewFqaOqsuvRiVvqTbsuaGIkyJqQtPoKISxGgoJ7jlVtnX38q1F7+Lf+r7I28T3Hf+qVQcIiNEvuglRIxMh6n3KrlgyvNBZ4hVbjRAzTjL1A+B/hRAPoXyUZ6Vdd8IQYgmPPnrZzl5B8HUxTH6bjd/NVJwpTjIEuqNkqL3NnQwph8oA8DfgN6hxsYE3eMd62DAaNe+0EUZsJ2shU6YrYYannUOGpeAbS9/D2jX78vYZ30lxPnikf0i0G04vxXgyhHizj4+KRbRJcZpSoXf7Bbqn9Zftiq6hfa4xza2CDJNMDQKvqOaaE4gQC2UnSiT4ejNu6TCWDKHYHVaRK6E3w3TvGoLvAN9DffyPQTW8DahMLY0MjWnYXCNmZRSpTrnXRtfks6qOQ7QwYr2vkZh1s1Spspm+plX22PmYNbzw+rbOXjZ7+7C/eJgHODDtrhkQEuvCNsO401wkGHtOVIUuTbXHLkfHNE90TJgnLFEIO1Fci5nZRsNcNwixKyIhDjF9624WfW0Te397JxwPfBzV+NahiFBfsgUnSAohIoGNjUgWF3aTqaPVMOqvaKzXi9jUasZolT6vl9lym3GwHlVZ/wbrWjK0Rx25FmLWXefEqNBxY5onOlqOEIUQXwZehEo98DDwJinl9rT/lWSB4Tgnynbj1/zCmj2lO7yYDpW9N27l5K9cz4qL72fXK6aw8+ZOZhWH4RHCkVA1op5MNzWjERyYD/OkwjWYsCkwh+5Jyf7DD/Noz9JgZ9H6zQr9ROZI+MCx4rLBuIgOkgnRPAfKXmg95M8c3VI5tfVU5kag5QgRuAo4LzCafgk10fTH0v5U8r10J8pW4w8xkqFalENl3hNbOOdTP+QZl/6dh85eyvV3Hsvei55gMevcRNjU3tkIZLF35YPydK4ZUW011zyW3EBSAt6Fgxvwhcfjna7pxmuBaUfErS6De1BBHDkmSpCVIX/26JYW9TI3BC1HiFLKPxubq1Du8nT4IiwJuqREIDL5Z3d0EZ0DnH3hT/i3//ef/OO1x/KV+97N3nMDInTeO1MJM8GWFr2mt0M7wUNzUW9VV/MkmTKQBz3odZt+zJ+mp08OmQ12OuEBYEqUFONU5mqkw9B6ZXSLVqHt0S0THS1HiBbeDPw87qAQ4hzgHADm7RtVj/X6VgJ1Qzc0j9D4U4MMD3/0dr79oXPp7BzmA1d9keHDO1nMOvbmifDNJ5QkWAt0BaRQTEzMtmuURCMQa46o8f25SLJH9vP+h7/GymXXqLBfl8OOItklblM69Kk4uBykaP4lSXVOUqcj+4vlnDT26JaJjqYQohDialROAxvnSykvD845H/V6fhJ3nWAoz4UAYv9jZKz9cDdUvrQGDCKcVtrFZy79JK+96Sd88hP/jz+97xQWdWyIlwo1Jj0x1o6xIMN6G7iLAAteKXThf7vli/xt9tO5p+epxn6PiiZSy7A93bAGCCdDMEjRJLssqrNrn0uFBtJGt0xUNOUJpZSnJB0XQrwROB1YmRZZXoZPvMpcJkNLXQ6kw5Oe+CuX/PZsrj7yFJ7607vZ8bQuFnRsjL2VM919zsSYOVwkNzSgKTT4Y1H35eu9gAen3H8Vb7nrBxxz0i0qOD8XmKqylixN9FScLFnUZmL2J6nQgV0xMrplgqPlKD9I+vhR4NlSyuwDPEq4PctIKlMRGepKN4iuUf7tsS/y3k1f5+xXXcJVL3ge7A1FdpYlAz2YScM1UiErqo3uyx0t97ZjUEU547KQ55ntJs
6WuHjLWi75ydm85vk/ZbOYH1PuWivdVJXt/VAmxSRpL86OaBNi6rmWXXECoxW7yDdRFourhBAAq6SUb0/9l09UXfZBpRoyybAICGbJJ/nfTa9nr+KTHPPSW9iwfFHkkmmD+xsXJN0qyEHE05fw6pidPKjnMYnvzNgjlq57hGvOW8kXnvdxrl30HJV8yrxGXVWnad7OG2izXzC3oJYU4yZJS3OspEmSUFahJzpa7gmllMvSz3LAJsTdoBrUTsITVMLebOHq4VO4dsZJfOjo/8SfWgwZ/wtetDXXIxm2BFxvehzbP+uR+kb8dJOE/TE0t/d/8GF+/eZX8x8vP4/vrHiHmq9Go1zP9Q5XypIzULuMAmnRtismLdOISoXTUs71J3524pYjxJrhUxmJsh0qZKilQx+Yznw2cQ0v4bKOV/LpfT4NHTW+ZJNMGsSVY5YWrJXhIO1m8XiBEsdd808+97rPcN57vsjFh7+pQoZOb7reUST8NovEP4UZkB13HCoSom+sF7PZFatWmVOKNIEwcQhxBMOJshMV/7ALUzrci+1cy5n8kLfyxeJ57o94g2qkJSs610Kl9JZAuiiMca+q525mWcXoKK/8wi859X+u4a0//DZXLD0zLBmaCNVrvcP34vYZiR/K61VIixjr02LW4yTKCYyW7Kc1QavM5QnXtXSo3qKgi0v4OFfyPL7IeZHB+XXZvh2w5YEJnwexDlQTfmNTQlIf9clHyp66aQ/veMsPELvgtbdcxOp5R8FDwUG73eTSo9KyPZpMpWfyHrH2ZZAWXWpzHClCmxDHFUYA2Y/6bG/DJEPw+Aj/x2x28DE+VflPAhHGSTIFShRM635KDY4fMqy2tdfHBOPCJluSHPQ/D7DiMw9x3blP57ufejNPFPeOmkhiH92m47TW4KL4JGupKRm64MVLixpxKnGbEMc5pESR4RZUo9MRO0UWsJ2P8WOO4Lf4STKDIwVUef7lJMQIOC1buU0sWM0ppPzEzVTEhefEofOfIxzx9vvYM2MKP77uTG5fcUS47J4kdu7r3GGX3JQEzW29WCo0jphFcBOg63I4zp+gaNk+Wz2GUWTYR+UrrMjvQ1zGxZzGOha4/5qxFjKNbR0LTIKG6XontXqWq+rTO6Djv2Dv3+/k4QsWs/p1h7NdzMxWxth2ZKoiSQ4ViH9Kc3+PtX+K41yHCu3jzvaUJB3a505wTDBCfBx7RMo0RngjV3MoPyXUqDzr14DnyNuXFxm2TKqECfTmc8EwcAUq+/mLYdM9e/HEXr1oKbAqZ5AXu5GALPl4XE6VHmvbVqEd0qJWoTWykGGbEMcbfJTtMDw873Ae4wEWsJG9on+xHSttBGg8ZdeTbNQMPEkraWquyRHgFuAiYAXwLRg9FUZnJk0pamDM241Nerb6TMx+h8NF77b/4rpMmxDHG3yUZzlsIzycR7iTJcHWGET2eUHKrhbRrpsCu1VV05nGqkUOoUaX3AYsAd4JnAg8RR2uKvFE7Ic16WFcw1mq+RBpp6FLGjT3uyTJHmBKxa4I2ZwqkwATjBBtD57HHHbxOLOip8c9ea01MoFqsulI4KJq+qYr200n0DmEmhRsDvBslGQ4P0uxAidboQReLUE9cY2k1vS19lQDtpSYJkk6YhbNUyehHXECdWM9VrYSagM+G5nJiaw1zrMacUzojQ4LyTM8ZHxVtk8Og3KdqMrTHHNqtUr9Pijemw74Erz5qCnMZ1V5obiypcayNko7Mf3nNpPZI1lMO6PePz1Miq7LQJQcJyjGVx/NBLNxwC0s4Tz+QAclRu1TU7yE4yJWbiKjztY5DzgCeCrKgqFjEOZ60N3IxC2RkSpZyLBeu60Zd2te01aVIRyeA067YpsQJxqUwnQ7+7KR2byCv/JzXls53Ignn8C1malT1/H8eSWL3Qs4FDgEpR7fB/wRpT/MAxa6/lTPe6tnlBOQz2wv5rVsb3MWaXEGZS/0oGgT4sRDWIX4PC/lv/ghv+MMdjPP/R
ejJuzpMm1JseQVICpvJqIlRqy0YIMOEWEKJ9oylN6ejpICDwVmAvegImjWo17rdIh767XDK+H8SOQyjrleaGnRJj973USMCq2XabRk+8kbE5QQwXzpf+RIXsJdXMb7eBG/qVoRdoaIpNRcsyJ6YsNZxqox5/WwKeTYBRwOHIiyD94H/AU1PXZ1n6pkJJpNqn7WOIJsRJhTkqRo2hDNxZQYrbHQMCkIMWOw1XiFbmiCd/EBAC7hXHowEnFXOSR33M8+lrkT55VvOh5VxSJKRSfzgKcBZ6GiZW4C/hO4HDWJdxoZ1vI0Nik6g7RTVWfXJzJPdTkOA1SSJJsp8fqN7Z3WsZ3BcVkZ3bId9+RWEwwTnBArKOFxJt9AIvgHJ7DMf7BysIovfbUjVqoRIpr/Aa6V/LI/pf6gZPI07wbWosS+PkWGU4Jdl6BI8D7Gpt68YHbiiMSY+hiuOVHGGjr7tp5KYxcVkrQXTYz6nKBNDDIpCLHZb2pM0c8UXsf3eDs/5e99J/LF7n/jm4vfzXDGORZLFKqyeVVrN2zYy7AZIzJs0byzSz+q0RaW8YHKdboL+AdK910N7ABmo+xXEtZvgw2oAZrVUnfecm6HZ0Qt1JX4ZywHcqYFc8ep0IbDZYKjJSVEIcRnhRB3CiFuF0L8WQgRk5WhpqvzHd7Bs3qv59l7ruO+Px3Eq+/9KWLUrWzVM8SsjWR0DJeYc0Mfiz61mdlP3w1HA5ehXMWvBt4BHAfsjZNkskiGY6GUOpGJFJtRMi0tmiryAG712ZQWB6jE+k5ctKqE+GUp5ScAhBDvBT4JpE80FYuohHNf98Gcsei3PGvKdXxp9cf47E2f4MLTz+GHb3sTw0uiV4i1HVYZduExUZM7JEuRXUODHHHLHTz79hs49tabOXj1/Sy4ZzO7Vkxl4Lld7PzMFOYs2aWS/K4FNgLrgnKO8TeplOowK+kV18G4DcJ11GwDidZd4sZCu9DsMjceLUmIUkpzurGpNPDTdP3cZ3PCkf/gmI5bePvt3+H+FxzIX059Nje88kQee+4iFccxRmhEcytlninNdV5tJeouDXBY350cve1Wjr7uVo56fDUHbryfR5Yv5cFj9ueRo5bwrzetoHD4MIunbGAx61jMOhU5vbWmW2ZCQz9EVTmrWuKTSDiY245hjFOhJzZakhABhBCfB85GWZKe0+Cbcct+T+OtK5/Gh2b/J29e9V1edtFvOfrNq1l31CIeO20Ru06bCk+V9U+mNsEwk12s4FaO5nKOLt3OURtWs2ztQ9w38yBWLz6KWw4+hu+eeS7/WnkoXQftYHGXIr8lrGEJj7ovmip55YQ8rpk5vsolQbcCwbhCc2wborlMbDSNEIUQV6NCyGycL6W8XEp5PnC+EOI84N1g5v4vX+Mc4By1lY8ot2PGLC465w388T3PZVn/Q5x07XWc+PtVnPiif+KVfAae30nxqFHl8syiZRgYy+Bs7RMNIZNHNIqpDLGcvmDZyXKe5AC2spwtdDPCfSxjNcfxd3Ei35jzHu5aegjD+3apuJhFqN9u6GJH5NoRG62tHuc0ljlvFJKcKiF4ROu1FaUtLS26SHCSDFOhiYQopTwl46k/AX6PgxCllBcCFwIIsaB2tVrXgvXO90yZwp2nHcKTp83kX984mIPuf4Blf36EaTftVjn0HkQ5AOagxokNq0U0yfZcVUhLAMEoM9nNQh5kCfewhAdYwgaW8DhL2Mp+bGcqIzzEbB6klweZxw3sz0Ws5EGWsZn9Uay3CDqKKmI6L1edJkfrcartmmNOPS2rd6VhcqvL0KKvTgixXEqpAwXPQIWb1YEMLzPNOSIEew6awraDZjK1tJPOtSPwCHAzKqfeQyj715Ow16iK9df+u35UCNcelP6vfXaDxm8eHFLo95nRt5N9t61jTl8fh/Tdg3xMMPPhAXof66N3cx+9W/vo3dXH7P5t9I70sRdPsoepbGQOa5jDGmawhmnczGGsYSprmMYWZlGxFe
jMy9NRT5liQ8grnVqheQPhCkHAlb1ee+9pmu87I1xSYpsQm4kvCiEORA08eIy6PMwZ4cr/lgYPJRyNoiTFR4G1sOsR2DQcvuRocPocVOq9HqCbCr10ozKyDBuX9gCeRLHowyju0fwT/E6RoxxYeowDRx+jVOygv3cKO2bPoK+3l82981g/dSFPiPlsmjWfu7ufSt+cXvp29tLX10vftl6e3LYXPgOokb9bUFF+etZCPR1DEullpClHvdY84VSAarpoHqPPEoPyI7GdSSeNF5hjoieH2tySb0hK+fJml6Em6NoUitx2E4722mWs25gRLLNRBLmPXmaiWHQfYC6KTRej7HIHQP9TBI/1LGFdx2LWFhbzqHhK4LMNlr7FjDw0Q0mwa4JlPUpcLYulDbBsVjEEUtsRpRdQr6kqFyrXyppIq6by1SPJxjmBfGHtgPGngk4uKbElA7PzQ21ftJKfHPjmF6oLjMuaDW8EpV73oyTFEqg35KFslF0oUXIaSmPtAboFslMgPQGi0gEjo2pqRr6KalyZys6VlvxER1HwYsYzOxFjpB5X0J/0iY0JTog2HF+4mEZc7wiVovU7FvDLTpVK2SPkHttpaylpA58uQYKrljMbRkOZCmLX0XiWtMZrubNjEhFi7d3CaevK0BnGQthRuRmbjaDTVzt2eZyiZebnbiN3TAJCTCBCL/lwrfDyZsIM13NJh+MWzUom2VCMJ7vh5MUkIMQxRhM6cNXqfZNIxmWbTfQ060PavFhocNEzXFxLh55XT+jNeLYlTmxMEkLMrwGGyKcBvbOqkraUMFibWSH2Pw14toaF3IxZKdpoNCYJIeYDpyRmhIU0AiLl2qbHu964PoXm6aiJzn2jWM0M0I5FarWZzpS26tyqmOCEOBKzHiDDx1qTTKpaGpOmyhwMNVao3o5YbVCea3xujsjAyQ27eya12XqbruoL7WtLheMFE5wQoZ4wh3wkrpzhGN9rJnKoz6nSoOd1XNaMk0y1I+rA7BZ8HdWjLR22MiYBIZrwnav23LN+SmB2Nch1dEXGTDCpyCywjB0DOYncsSt3ybCeR8z038mXMWY8YxIQoj0OM9sX2pRaQoHOWhqzh5elIFdqibnYeAi5yWR6MH81jL81y4YIGZwrqS+6TYqtjElAiLXDVuuqzy/YeFSl1rdAeaHKMKGEmRtaDi1ZqDaqwSQiREsytD/U9X64c+gMtuyaKAlZ98tfOsyxdwd16zJFhKRvUzo0HRPBuueFS9VUSdFLkRR9c8VMvNpGK2MSEGJCuIPLrOOreUhqJhivwcZ/xyiOzGVNJP1qCp1x1IXLTlveNKXvQkXSTbCT6nqtx8ftEYwkMgk3j2+Js/pcmYLbpNjKmASEaCOBFWIOtYRtLi7NVD3Xi8DlAsqX3Z2mB33MC5pjgXB8p5UCrNaSVXO+Lmd7GtrJhUlAiBnyzyUQYdihkjJeeKxsSHn10dTMNzmn/nLMABipy6QsN4VKqXKvauODUybmAFWnUgu1pxFjZ9uh0uqYBIRoop6MN4XmZTlx9f6gf1YtvdbVJ5NoyI+ELyXdz5S+NOGUM/docjI9+Y5bZ6Vr2+4YMmnU+HFJy5npJsU2Wh2TgBBjGmMcMfgZGju0hkcxN22uXknQqOMUwk2VthJGfQivdstB5D9aLbcPGh+akG0zK/zYjYR9bbQKJgEhQuIQvhhJxu601UqHufFlnOfVQKqUWFMfzGKtyyj5xGiLyk7nKLsrFtE6Lfd0tglV6JfJsV7nVVtSbHW0NCEKIT4khJBCiDn1X83RSqsMvammUzQ0iqch0mkaxdRwU/sj40ftsXq9REElqrClNtO54oXdPllLlEjtLq99QtLd6kmxndBhPKFlCVEIsRh4HrA2v6tWqT6noQrNrxZ4LpXOukGqo6e+EpBOlNV1dPujEptBSN8ex3p0MxUhUtQk65JEjXqtWm2ObUdtNXm8oGUJEfgv4KO4g7mqhN0gc5xJvgHSWh6+3aaECr
mcKjEwidGnEKilXlQyNJ0rhYqHuZox4nGjAatBaviN89ldH4u2pNjKEFLmSA45QQhxBnCylPJ9Qog1wDFSyq2O884Bzgk2DwHuGrtSjhnmAJFnnyCYqM82UZ8L4EAp5fRmF6JRaBohCiGuRs02bON84OPA86SUO5II0breLVLKY/IvaXMxUZ8LJu6zTdTngon9bNDE4BEp5Smu/UKIQ4GlwB1CzTO8CFgthDhWSrl5DIvYRhttTDK0QjRdCFLKfwFz9XZWCbGNNtpoo160slOlWlzY7AI0CBP1uWDiPttEfS6Y2M/Wmk6VNtpoo41mYCJJiG200UYbdaFNiG200UYbASYkIeY75K/5EEJ8WQhxnxDiTiHEr4UQs5pdpnoghHi+EOJ+IcRDQoh/a3Z58oIQYrEQ4q9CiHuEEHcLId7X7DLlCSFEQQhxmxDiymaXpVGYcITYmCF/TcdVwCFSysOAB4DzmlyemiGEKADfAl4ArABeLYRY0dxS5QYf+JCUcgVwPPCuCfRsAO8D7m12IRqJCUeI5DrkrzUgpfyzlFIPCFuFis0crzgWeEhK+YiUchi4FDijyWXKBVLKTVLK1cH6LhR5LGxuqfKBEGIR8ELg+80uSyMxoQgxGPK3QUp5R7PL0kC8GfhDswtRBxYC64zt9UwQ0jAhhFgCHAnc1OSi5IWvoQSN0SaXo6FoucDsNGQZ8je2JcoHSc8lpbw8OOd8lFr2k7EsWxvVQQgxDfgl8H4p5c5ml6deCCFOB7ZIKW8VQpzU5OI0FOOOECfqkL+459IQQrwROB1YKcd38OgGYLGxvSjYNyEghCiiyPAnUspfNbs8OeHpwIuFEKcB3cAMIcSPpZSva3K5cseEDcyeSEP+hBDPB74KPFtK+USzy1MPhBAeyjG0EkWENwOvkVLe3dSC5QChvsQ/ArZJKd/f5OI0BIGE+GEp5elNLkpDMKFsiBMY3wSmA1cJIW4XQnyn2QWqFYFz6N3An1BOh8smAhkGeDrweuDk4D3dHkhVbYwTTFgJsY022mijWrQlxDbaaKONAG1CbKONNtoI0CbENtpoo40AbUJso4022gjQJsQ22mijjQBtQmyjjTbaCNAmxDbaaKONAP8fIEIxNYz/ncwAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "algorithm = \"PAIR\"\n", + "r2, r = 1e4, 1e-12\n", + "preference = np.array([r]*1+[(1-1*r-r2*r),r2*r])\n", + "model = MLP().to(device)\n", + "model.update_preference(preference)\n", + "train_OOD()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "u7z9miU74Mio" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "provenance": [], + "toc_visible": true + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3.8.11 ('gnn')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.11" + }, + "vscode": { + "interpreter": { + "hash": "ddf6e3d325859c53183a922f8f4a4c99a98233dd957a6c2162fb2135209fdc50" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..31617e1 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Yongqiang Chen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. 
+ +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/PAIR/pair.py b/PAIR/pair.py new file mode 100644 index 0000000..947f69d --- /dev/null +++ b/PAIR/pair.py @@ -0,0 +1,452 @@ +import copy +import imp +from pickletools import optimize +import torch +from torch.optim.optimizer import Optimizer, required +from torch.autograd import Variable +import traceback +import torch.nn.functional as F +from torch.optim import SGD + +class PAIR(Optimizer): + r""" + Implements Pareto Invariant Risk Minimization (PAIR) algorithm. + It is proposed in the ICLR 2023 paper + `Pareto Invariant Risk Minimization: Towards Mitigating the Optimization Dilemma in Out-of-Distribution Generalization` + https://arxiv.org/abs/2206.07766 . 
+ + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining parameter groups + optimizer (pytorch optim): inner optimizer + balancer (str, optional): indicates which MOO solver to use + preference (list[float], optional): preference of the objectives + eps (float, optional): precision up to the preference (default: 1e-04) + coe (float, optional): L2 regularization weight onto the yielded objective weights (default: 0) + """ + + def __init__(self, params, optimizer=required, balancer="EPO",preference=[1e-8,1-1e-8], eps=1e-4, coe=0, verbose=False): + # TODO: parameter validty checking + if eps < 0.0: + raise ValueError("Invalid epsilon value: {}".format(eps)) + for _pp in preference: + if _pp < 0.0: + raise ValueError("Invalid preference: {}".format(preference)) + + self.optimizer = optimizer + if type(preference) == list: + preference = np.array(preference) + self.preference = preference + + self.descent = 0 + self.losses = [] + self.params = params + if balancer.lower() == "epo": + self.balancer = EPO(len(self.preference),self.preference,eps=eps,coe=coe,verbose=verbose) + elif balancer.lower() == "sepo": + self.balancer = SEPO(len(self.preference),self.preference,eps=eps,coe=coe,verbose=verbose) + else: + raise NotImplementedError("Nrot supported balancer") + defaults = dict(balancer=balancer, preference=self.preference, eps=eps) + super(PAIR, self).__init__(params, defaults) + + + def __setstate__(self, state): + super(PAIR, self).__setstate__(state) + + def set_losses(self,losses): + self.losses = losses + + def step(self, closure=None): + r"""Performs a single optimization step. + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + if len(self.losses) == 0: + self.optimizer.step() + alphas = np.zeros(len(self.preference)) + alphas[0] = 1 + return -1, 233, alphas + else: + losses = self.losses + if closure is not None: + losses = closure() + + pair_loss = 0 + mu_rl = 0 + alphas = 0 + + grads = [] + for cur_loss in losses: + self.optimizer.zero_grad() + cur_loss.backward(retain_graph=True) + cur_grad = [] + for group in self.param_groups: + for param in group['params']: + if param.grad is not None: + cur_grad.append(Variable(param.grad.data.clone().flatten(), requires_grad=False)) + grads.append(torch.cat(cur_grad)) + + G = torch.stack(grads) + if self.get_grad_sim: + grad_sim = get_grad_sim(G,losses,preference=self.preference,is_G=True) + GG = G @ G.T + moo_losses = np.stack([l.item() for l in losses]) + reset_optimizer = False + try: + # Calculate the alphas from the LP solver + alpha, mu_rl, reset_optimizer = self.balancer.get_alpha(moo_losses, G=GG.cpu().numpy(), C=True,get_mu=True) + if self.balancer.last_move == "dom": + self.descent += 1 + print("dom") + except Exception as e: + print(traceback.format_exc()) + alpha = None + if alpha is None: # A patch for the issue in cvxpy + alpha = self.preference / np.sum(self.preference) + + scales = torch.from_numpy(alpha).float().to(losses[-1].device) + pair_loss = scales.dot(losses) + if reset_optimizer: + self.optimizer.param_groups[0]["lr"]/=5 + # self.optimizer = torch.optim.Adam(self.params,lr=self.optimizer.param_groups[0]["lr"]/5) + self.optimizer.zero_grad() + pair_loss.backward() + self.optimizer.step() + + return pair_loss, moo_losses, mu_rl, alpha + + + +import numpy as np +import cvxpy as cp +import cvxopt + +class EPO(object): + r""" + The original EPO solver proposed in ICML2020 + https://proceedings.mlr.press/v119/mahapatra20a.html + """ + def __init__(self, m, r, eps=1e-4, coe=0, verbose=False): + # self.solver = cp.GLPK + self.solver = cp.GUROBI + # cvxopt.glpk.options["msg_lev"] = "GLP_MSG_OFF" + self.m = m + self.r = 
r/np.sum(r) + self.eps = eps + self.last_move = None + self.a = cp.Parameter(m) # Adjustments + self.C = cp.Parameter((m, m)) # C: Gradient inner products, G^T G + self.Ca = cp.Parameter(m) # d_bal^TG + self.rhs = cp.Parameter(m) # RHS of constraints for balancing + + self.alpha = cp.Variable(m) # Variable to optimize + self.last_alpha = np.zeros_like(r)-1 + self.coe = coe + + obj_bal = cp.Maximize(self.alpha @ self.Ca) # objective for balance + constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1, # Simplex + self.C @ self.alpha >= self.rhs] + self.prob_bal = cp.Problem(obj_bal, constraints_bal) # LP balance + + obj_dom = cp.Maximize(cp.sum(self.alpha @ self.C)-self.coe*cp.sum_squares(self.alpha-self.last_alpha)) # obj for descent + constraints_dom = [self.alpha >= 0, cp.sum(self.alpha) == 1, # Restrict + self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)), + self.C @ self.alpha >= 0] + self.prob_dom = cp.Problem(obj_dom, constraints_dom) # LP dominance + + + self.gamma = 0 # Stores the latest Optimum value of the LP problem + self.mu_rl = 0 # Stores the latest non-uniformity + + self.verbose = verbose + + + def get_alpha(self, l, G, r=None, C=False, get_mu=False): + """calculate weights for all objectives given the gradient information + + Args: + l (ndarray): the values of objective losses + G (ndarray): inner products of the gradients of each objective loss w.r.t. params + r (ndarray, optional): adopt this preference if specified + C (bool, optional): True if the input gradients are inner products + get_mu (bool, optional): return detailed information if True. 
+ + Returns: + alpha: the objective weights + mu_rl (optional): the optimal value to the LP + reset_optimizer (optional): whether to reset the inner optimizer + """ + r = self.r if r is None else r + assert len(l) == len(G) == len(r) == self.m, "length != m" + l_hat, rl, self.mu_rl, self.a.value = self.adjustments(l, r) + reset_optimizer = False + self.C.value = G if C else G @ G.T + self.Ca.value = self.C.value @ self.a.value + + if self.last_alpha.sum() is None: + self.last_alpha = np.array(r) + if self.mu_rl > self.eps: + J = self.Ca.value > 0 + J_star_idx = np.where(rl == np.max(rl))[0] + self.rhs.value = self.Ca.value.copy() + # it's equivalent to setting no constraints to objectives in J + # as maximize alpha^TCa would trivially satisfy the non-negativity + self.rhs.value[J] = -np.inf + self.rhs.value[J_star_idx] = 0 + + self.gamma = self.prob_bal.solve(solver=self.solver, verbose=False) + self.last_move = "bal" + + if self.verbose: + test_alpha = np.ones_like(self.a.value)/self.m + print(self.last_alpha,self.C.value,self.Ca.value,self.rhs.value) + print(self.gamma,test_alpha@self.Ca.value, self.alpha.value @ self.C.value) + print(self.gamma,self.coe*np.linalg.norm(self.alpha.value-self.last_alpha)**2) + + else: + self.gamma = self.prob_dom.solve(solver=self.solver, verbose=False) + self.last_move = "dom" + self.last_alpha = np.array(self.alpha.value) + + if get_mu: + return self.alpha.value, self.mu_rl, reset_optimizer + + return self.alpha.value + + + def mu(self, rl, normed=False): + if len(np.where(rl < 0)[0]): + raise ValueError(f"rl<0 \n rl={rl}") + return None + m = len(rl) + l_hat = rl if normed else rl / rl.sum() + eps = np.finfo(rl.dtype).eps + l_hat = l_hat[l_hat > eps] + return np.sum(l_hat * np.log(l_hat * m)) + + + def adjustments(self, l, r=1): + m = len(l) + rl = r * l + + l_hat = rl / rl.sum() + mu_rl = self.mu(l_hat, normed=True) + uniformity_div = np.log(l_hat * m) - mu_rl + div_r = np.array(r) + a = div_r * uniformity_div + + if 
self.verbose: + print(a, rl, div_r, uniformity_div, l_hat, a.dot(l)) + return l_hat, rl, mu_rl, a + + +class SEPO(object): + r""" + A smoothed variant of EPO, with two adjustments for unrobust OOD objectives: + a) normalization: unrobust OOD objective can yield large loss values that dominate the solutions of the LP, + hence we adopt the normalized OOD losses in the LP to resolve the issue + b) regularization: solutions yielded by the LP can change sharply among steps, especially when switching descending phases + hence we incorporate a L2 regularization in the LP to resolve the issue + """ + def __init__(self, m, r, eps=1e-4, coe=0, verbose=False): + # self.solver = cp.GLPK + self.solver = cp.GUROBI + # cvxopt.glpk.options["msg_lev"] = "GLP_MSG_OFF" + self.m = m + self.r = r/np.sum(r) + self.eps = eps + self.last_move = None + self.a = cp.Parameter(m) # Adjustments + self.C = cp.Parameter((m, m)) # C: Gradient inner products, G^T G + self.Ca = cp.Parameter(m) # d_bal^TG + self.rhs = cp.Parameter(m) # RHS of constraints for balancing + + self.alpha = cp.Variable(m) # Variable to optimize + self.last_alpha = np.zeros_like(r)-1 + self.coe = coe + + obj_bal = cp.Maximize(self.alpha @ self.Ca-self.coe*cp.sum_squares(self.alpha-self.last_alpha)) # objective for balance + obj_bal_orig = cp.Maximize(self.alpha @ self.Ca) # objective for balance + constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1, # Simplex + self.C @ self.alpha >= self.rhs] + self.prob_bal = cp.Problem(obj_bal, constraints_bal) # LP balance + self.prob_bal_orig = cp.Problem(obj_bal_orig, constraints_bal) # LP balance + + obj_dom = cp.Maximize(cp.sum(self.alpha @ self.C)-self.coe*cp.sum_squares(self.alpha-self.last_alpha)) # obj for descent + constraints_res = [self.alpha >= 0, cp.sum(self.alpha) == 1, # Restrict + self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)), + self.C @ self.alpha >= 0] + constraints_rel = [self.alpha >= 0, cp.sum(self.alpha) == 1, # Relaxed + self.C @ self.alpha >= 0] + 
self.prob_dom = cp.Problem(obj_dom, constraints_res) # LP dominance + self.prob_rel = cp.Problem(obj_dom, constraints_rel) # LP dominance + + self.gamma = 0 # Stores the latest Optimum value of the LP problem + self.mu_rl = 0 # Stores the latest non-uniformity + + self.verbose = verbose + + + def get_alpha(self, l, G, r=None, C=False, get_mu=False): + """calculate weights for all objectives given the gradient information + + Args: + l (ndarray): the values of objective losses + G (ndarray): inner products of the gradients of each objective loss w.r.t. params + r (ndarray, optional): adopt this preference if specified + C (bool, optional): True if the input gradients are inner products + get_mu (bool, optional): return detailed information if True. + + Returns: + alpha: the objective weights + mu_rl (optional): the optimal value to the LP + reset_optimizer (optional): whether to reset the inner optimizer + """ + r = self.r if r is None else r + assert len(l) == len(G) == len(r) == self.m, "length != m" + l_hat, rl, self.mu_rl, self.a.value = self.adjustments(l, r) + reset_optimizer = False + if self.mu_rl <= 0.1: + self.r[0]=max(1e-15,self.r[0]/10000) + self.r = self.r/self.r.sum() + print(f"pua preference {self.r}") + l_hat, rl, self.mu_rl, self.a.value = self.adjustments(l, r) + + + a_norm = np.linalg.norm(self.a.value) + G_norm = np.linalg.norm(G,axis=1) + Ga = G.T @ self.a.value + self.C.value = G if C else G/np.expand_dims(G_norm,axis=1) @ G.T/a_norm + self.Ca.value = G/np.expand_dims(G_norm,axis=1) @ Ga.T/a_norm + + if self.last_alpha.sum() is None: + self.last_alpha = np.array(r) + if self.mu_rl > self.eps: + J = self.Ca.value > 0 + + J_star_idx = np.where(rl == np.max(rl))[0] + self.rhs.value = self.Ca.value.copy() + # it's equivalent to setting no constraints to objectives in J + # as maximize alpha^TCa would trivially satisfy the non-negativity + self.rhs.value[J] = -np.inf # Not efficient; but works. 
+ self.rhs.value[J_star_idx] = max(0,self.Ca.value[J_star_idx]/2) + + if self.last_alpha.sum()<0: + self.gamma = self.prob_bal_orig.solve(solver=self.solver, verbose=False) + else: + self.gamma = self.prob_bal.solve(solver=self.solver, verbose=False) + + self.last_move = "bal" + + if self.verbose: + test_alpha = np.ones_like(self.a.value)/self.m + print(self.last_alpha,self.C.value,self.Ca.value,self.rhs.value) + print(self.gamma,test_alpha@self.Ca.value, self.alpha.value @ self.C.value) + print(self.gamma,self.coe*np.linalg.norm(self.alpha.value-self.last_alpha)**2) + else: + self.gamma = self.prob_dom.solve(solver=self.solver, verbose=False) + self.last_move = "dom" + self.last_alpha = np.array(self.alpha.value) + + if get_mu: + return self.alpha.value, self.mu_rl, reset_optimizer + + return self.alpha.value + + + def mu(self, rl, normed=False): + if len(np.where(rl < 0)[0]): + raise ValueError(f"rl<0 \n rl={rl}") + return None + m = len(rl) + l_hat = rl if normed else rl / rl.sum() + eps = np.finfo(rl.dtype).eps + l_hat = l_hat[l_hat > eps] + return np.sum(l_hat * np.log(l_hat * m)) + + + def adjustments(self, l, r=1): + m = len(l) + rl = r * l + + l_hat = rl / rl.sum() + mu_rl = self.mu(l_hat, normed=True) + uniformity_div = np.log(l_hat * m) - mu_rl + div_r = np.array(r) + a = div_r * uniformity_div + + if self.verbose: + print(a, rl, div_r, uniformity_div, l_hat, a.dot(l)) + return l_hat, rl, mu_rl, a + + +def getNumParams(params): + numParams, numTrainable = 0, 0 + for param in params: + npParamCount = np.prod(param.data.shape) + numParams += npParamCount + if param.requires_grad: + numTrainable += npParamCount + return numParams, numTrainable + +def get_kl_div(losses, preference): + pair_score = losses.dot(preference) + return pair_score + +def pair_selection(losses,val_accs,test_accs,anneal_iter=0,val_acc_bar=-1,pood=None): + + losses = losses[anneal_iter:] + val_accs = val_accs[anneal_iter:] + test_accs = test_accs[anneal_iter:] + if val_acc_bar < 0: + 
val_acc_bar = (np.max(val_accs)-np.min(val_accs))*0.05+np.min(val_accs) + + try: + preference_base = 10**max(-12,int(np.log10(np.mean(losses[:,-1]))-2)) + except Exception as e: + print(e) + preference_base = 1e-12 + if len(losses[0])==2: + preference = np.array([preference_base,1]) + elif len(losses[0])==4: + preference = np.array([1e-12,1e-4,1e-2,1]) + elif len(losses[0])==5: + preference = np.array([1e-12,1e-6,1e-4,1e-2,1]) + else: + preference = np.array([1e-12,1e-2,1]) + + if pood is not None: + preference = pood + print(f"Use preference: {preference}, validation acc bar: {val_acc_bar}") + + pair_score = np.array([get_kl_div(l,preference) if a>=val_acc_bar else 1e9 for (a,l) in zip(val_accs,losses)]) + sel_idx = np.argmin(pair_score) + return sel_idx+anneal_iter, val_accs[sel_idx], test_accs[sel_idx] + +def get_grad_sim(params,losses,preference=None,is_G=False,cosine=True): + num_ood_losses = len(losses)-1 + if is_G: + G = params + else: + pesudo_opt = SGD(params,lr=1e-6) + grads = [] + for cur_loss in losses: + pesudo_opt.zero_grad() + cur_loss.backward(retain_graph=True) + cur_grad = [] + for param in params: + if param.grad is not None: + cur_grad.append(Variable(param.grad.data.clone().flatten(), requires_grad=False)) + # print(torch.cat(cur_grad).sum()) + grads.append(torch.cat(cur_grad)) + G = torch.stack(grads) + if cosine: + G = F.normalize(G,dim=1) + GG = (G @ G.T).cpu() + if preference is not None: + G_weights = preference[1:]/np.sum(preference[1:]) + else: + G_weights = np.ones(num_ood_losses)/num_ood_losses + grad_sim =G_weights.dot(GG[0,1:]) + return grad_sim.item() diff --git a/README.md b/README.md new file mode 100644 index 0000000..0f0f5f2 --- /dev/null +++ b/README.md @@ -0,0 +1,140 @@ +

PAIR: Pareto Invariant Risk Minimization

+

+ Paper + Github + + License + License + + + +

+ +This repo contains the sample code for reproducing the results of our ICLR 2023 paper: *[Pareto Invariant Risk Minimization](https://arxiv.org/abs/2206.07766)*, which has also been presented at [ICML PODS](https://sites.google.com/view/scis-workshop/home) Workshop. 😆😆😆 + +TODO items: +- [] Camera ready version of the paper. + +- [] Full instructions to reproduce results. + +## Introduction +Recently, there has been a growing surge of interest in enabling machine learning systems to generalize well to Out-of-Distribution (OOD) data. Most efforts are devoted to advancing *optimization objectives* that regularize Empirical Risk Minimization (ERM) to capture the underlying invariance; however, little attention is paid to the *optimization process* of the objectives. +In fact, the optimization process of the OOD objectives turns out to be substantially more challenging than ERM. +When optimizing the ERM and OOD objectives, +$$\min_f (\mathcal{L}_\text{ERM},\mathcal{L}_\text{OOD})^T$$ +there often exists an **optimization dilemma** in the training of the OOD objectives: + + + +1. The original OOD objectives are often hard to optimize directly (e.g., IRM), hence they are **relaxed as regularization terms** of ERM (e.g., IRMv1), i.e., $\min_f \mathcal{L}_\text{ERM}+\lambda \widehat{\mathcal{L}}_\text{OOD}$, which can behave very differently and introduce huge gaps with the original one. +As shown in figure *(a)*, the ellipsoids denote solutions that satisfy the invariance constraints of the practical IRM variant IRMv1. When optimized with ERM, IRMv1 prefers $f_1$ instead of $f_\text{IRM}$ (the predictor produced by IRM). + +2. The **intrinsic conflicts** between ERM and OOD objectives bring conflicts in gradients that further increase the optimization difficulty, as shown in figure *(b)*. Consequently, it often requires careful tuning of the penalty weights (the $\lambda$). 
Figure (d) shows an example in which IRMv1 usually requires exhaustive tuning of hyperparameters ($y$-axis: penalty weights; $x$-axis: ERM pre-training epochs before applying IRMv1 penalty). +In particular, according to Multi-Objective Optimization (MOO) theory, the typically used linear weighting scheme, i.e., $\min_f \mathcal{L}_\text{ERM}+\lambda \widehat{\mathcal{L}}_\text{OOD}$, cannot reach any solutions in the non-convex part of the Pareto front, as shown in figure *(c)*, and leads to suboptimal OOD generalization performance. + +3. Along with the optimization dilemma is another challenge, i.e., **model selection** during the training with the OOD objectives. As we lack access to a validation set that has a distribution similar to that of the test data, DomainBed provides 3 options to choose and construct a validation set from: training domain data; leave-one-out validation data; test domain data. However, all three validation set construction approaches have their own limitations, as they essentially posit **different assumptions on the test distribution**. + +This work provides understandings and solutions to the aforementioned challenges from the MOO perspective, which leads to a new optimization scheme for OOD generalization, called PAreto Invariant Risk Minimization (`PAIR`), including an optimizer `PAIR-o` and a new model selection criterion `PAIR-s`. + +1. Owing to the MOO formulation, `PAIR-o` allows for **cooperative optimization** with other OOD objectives to improve the robustness of practical OOD objectives. Despite the huge gaps between IRMv1 and IRM, we show that incorporating VREx into IRMv1 (i.e., `IRMX` objective) provably recovers the causal invariance for some group of problem instances. + +2. When given robust OOD objectives, `PAIR-o` finds a descent path with **adaptive penalty weights**, which leads to a Pareto optimal solution that trades off ERM and OOD performance properly, as shown in figure *(c)*. 
Therefore, `PAIR-o` robustly yields top performances and relieves the need for exhaustive hyperparameter tuning, as shown in figure *(d)*. + +3. `PAIR-s` addresses the challenge of finding a proper validation set for model selection in OOD generalization, by leveraging **the prior assumed by the OOD objective**. Essentially, different lines of OOD algorithms adopt different priors and assumptions on the causes of the distribution shifts. The main purpose of the OOD evaluation is to validate the correctness of the posed assumptions. To this end, the selected models should properly reflect the preferences implied by the assumptions, i.e., the OOD loss values. When considering the loss values during the model selection, it is natural to leverage the MOO perspective and explicitly consider the trade-offs between ERM and OOD performance. + +We conducted extensive experiments on challenging OOD benchmarks. Empirical results show that `PAIR-o` successfully alleviates the objective conflicts and empowers IRMv1 to achieve high performance on $6$ datasets from WILDS. `PAIR-s` effectively improves the performance of selected OOD models by up to $10\%$ across $3$ datasets from DomainBed. + +## Structure of Codebase + +The whole code base contains four parts, corresponding to experiments presented in the paper: +- `Extrapolation`: Recovery of Causal Invariance +- `ColoredMNIST`: Proof of Concept on ColoredMNIST +- `WILDS`: Verification of PAIR-o in WILDS +- `DomainBed`: Verification of PAIR-s in DomainBed + +## Recovery of Causal Invariance +We provide minimal demo code for the experiments on the recovery of causal invariance, in [pair_extrapolation.ipynb](./pair_extrapolation.ipynb). + + +## ColoredMNIST +The corresponding code is in the folder [ColoredMNIST](./ColoredMNIST). +The code is modified from [RFC](https://github.com/TjuJianyu/RFC/). 
+To reproduce results of PAIR, simply run the following commands under the directory: + +For the original ColoredMNIST data (CMNIST-25): +``` +python run_exp.py --methods pair --verbose True --penalty_anneal_iters 150 --dataset coloredmnist025 --n_restarts 10 --lr 0.1 --opt 'test' +``` + +For the modified ColoredMNIST data (CMNIST-01): +``` +python run_exp.py --methods pair --verbose True --penalty_anneal_iters 150 --dataset coloredmnist01 --n_restarts 10 --lr 0.01 --opt 'test' +``` + +## WILDS +The corresponding code is in the folder [WILDS](./WILDS). +The code is modified from [Fish](https://github.com/YugeTen/fish). +The dependencies and running commands are the same as for [Fish](https://github.com/YugeTen/fish). +For example, +``` +python main.py --need_pretrain --data_dir ./data --dataset civil --algorithm pair -pc 3 --seed 0 -ac 1e-4 +``` +We add additional command-line options to control `PAIR-p`: +- `-pc`: specify preferences; +- `--use_old`: to avoid repeated pretraining of ERM and directly use the pretrained weights; + +To avoid negative loss inputs, we use the following options to adjust IRMv1 loss values: +- `-al` and `-ac`: adjust negative IRM penalties in PAIR by multiplying by a negative number; +- `-ai`: adjust negative IRM penalties in PAIR by adding a sufficiently large number; + +We also provide an accelerated mode that freezes the featurizer, enabled by specifying `--frozen`. + +Note that we use `wilds 2.0` following the latest official recommendations. + + + +## DomainBed +The corresponding code is in the folder [DomainBed](./DomainBed). +The code is based on [DomainBed](https://github.com/facebookresearch/DomainBed). + +We provide new [PAIR model selection criteria](./DomainBed/model_selection.py). +Based on three options of validation set choice, we implement corresponding `PAIR-s` variants. + +- `PAIRIIDAccuracySelectionMethod`: `PAIR-s` based on a random subset from the data of the training domains. 
+- `PAIRLeaveOneOutSelectionMethod`: `PAIR-s` based on a random subset from the data of a held-out (not training, not testing) domain. +- `PAIROracleSelectionMethod`: `PAIR-s` based on a random subset from the data of the test domain. + +To use `PAIR-s`, simply add the corresponding functions or replace the original `model_selection.py` with ours, +and then run the corresponding commands in DomainBed. + + +## Misc + +If you find our paper and repo useful, please cite our paper: + +```bibtex +@inproceedings{pair, +title={Pareto Invariant Risk Minimization}, +author={Yongqiang Chen and Kaiwen Zhou and Yatao Bian and Binghui Xie and Bingzhe Wu and Yonggang Zhang and Kaili Ma and Han Yang and Peilin Zhao and Bo Han and James Cheng}, +booktitle={International Conference on Learning Representations}, +year={2023}, +url={https://openreview.net/forum?id=esFxSb_0pSL} +} +``` diff --git a/WILDS/README.md b/WILDS/README.md new file mode 100644 index 0000000..121c098 --- /dev/null +++ b/WILDS/README.md @@ -0,0 +1,3 @@ +# PAIR for WILDS +The dependencies and running commands are the same as for [Fish](https://github.com/YugeTen/fish). +The only difference is that we use `wilds 2.0` following the latest official recommendations. 
diff --git a/WILDS/src/config.py b/WILDS/src/config.py new file mode 100644 index 0000000..ac58366 --- /dev/null +++ b/WILDS/src/config.py @@ -0,0 +1,106 @@ + +dataset_defaults = { + 'fmow': { + 'epochs': 12, + 'batch_size': 32, + 'optimiser': 'Adam', + 'optimiser_args': { + 'lr': 1e-4, + + 'amsgrad': True, + + }, + 'pretrain_iters': 24000, + 'meta_lr': 0.01, + 'meta_steps': 5, + 'selection_metric': 'acc_worst_region', + 'reload_inner_optim': True, + 'eval_iters': 500 + }, + 'camelyon': { + 'epochs': 20, + 'batch_size': 32, + 'optimiser': 'SGD', + 'optimiser_args': { + 'momentum': 0.9, + 'lr': 1e-4, + + }, + 'pretrain_iters': 10000, + 'meta_lr': 0.01, + 'meta_steps': 3, + 'selection_metric': 'acc_avg', + 'reload_inner_optim': True, + 'eval_iters': -1 + }, + 'poverty': { + 'epochs': 200, + 'batch_size': 64, + 'optimiser': 'Adam', + 'optimiser_args': { + 'lr': 1e-3, + + 'amsgrad': True, + + }, + 'pretrain_iters': 5000, + 'meta_lr': 0.1, + 'meta_steps': 5, + 'selection_metric': 'r_wg', + 'reload_inner_optim': True, + 'eval_iters': -1, + 'scheduler': 'StepLR', + 'scheduler_kwargs': {'gamma': 0.96,'step_size': 1,}, + }, + 'iwildcam': { + 'epochs': 9, + 'batch_size': 16, + 'optimiser': 'Adam', + 'optimiser_args': { + 'lr': 1e-4, + 'weight_decay': 0.0, + 'amsgrad': True, + + }, + 'pretrain_iters': 24000, + 'meta_lr': 0.01, + 'meta_steps': 10, + 'selection_metric': 'F1-macro_all', + 'reload_inner_optim': True, + 'eval_iters': 1000 + }, + 'civil': { + 'epochs': 5, + 'batch_size': 16, + 'optimiser': 'Adam', + 'optimiser_args': { + 'lr': 1e-5, + 'amsgrad': True, + }, + 'pretrain_iters': 20000, + 'meta_lr': 0.05, + 'meta_steps': 5, + 'selection_metric': 'acc_wg', + 'reload_inner_optim': True, + 'eval_iters': 500 + }, + 'rxrx': { + 'epochs': 90, + 'batch_size': 72, + 'optimiser': 'Adam', + 'optimiser_args': { + 'lr': 1e-3, + 'weight_decay': 1e-5, + 'amsgrad': True, + 'betas': (0.9, 0.999), + }, + 'pretrain_iters': 15000, + 'meta_lr': 0.01, + 'meta_steps': 10, + 
'selection_metric': 'acc_avg', + 'reload_inner_optim': True, + 'eval_iters': 2000, + 'scheduler': 'cosine_schedule_with_warmup', + 'scheduler_kwargs': {'num_warmup_steps': 5415}, + }, +} diff --git a/WILDS/src/main.py b/WILDS/src/main.py new file mode 100644 index 0000000..6d1c8c2 --- /dev/null +++ b/WILDS/src/main.py @@ -0,0 +1,896 @@ +import copy +import argparse +import datetime +import json +import os +from statistics import mode +import sys +import csv +from tokenize import group +import tqdm +from collections import defaultdict +from tempfile import mkdtemp + +import numpy as np +import torch +import torch.optim as optim +from scheduler import initialize_scheduler + +import models +from config import dataset_defaults +from utils import get_preference, set_seed, unpack_data, sample_domains, save_best_model, \ + Logger, return_predict_fn, return_criterion, fish_step + +# This is secret and shouldn't be checked into version control +os.environ["WANDB_API_KEY"]=None +# Name and notes optional +# WANDB_NAME="My first run" +# WANDB_NOTES="Smaller learning rate, more regularization." +import wandb +import traceback +from pair import PAIR +from torch.autograd import Variable + + +runId = datetime.datetime.now().isoformat().replace(':', '_') + +parser = argparse.ArgumentParser(description='Pareto Invariant Risk Minimization') +# General +parser.add_argument('--dataset', type=str, + help="Name of dataset, choose from amazon, camelyon, " + "rxrx, civil, fmow, iwildcam, poverty") +parser.add_argument('--algorithm', type=str, + help='training scheme, choose between fish or erm.') +parser.add_argument('--experiment', type=str, default='.', + help='experiment name, set as . 
for automatic naming.') +parser.add_argument('--data_dir', type=str, + help='path to data dir') +parser.add_argument('--exp_dir', type=str, default="", + help='path to save results of experiments') +parser.add_argument('--stratified', action='store_true', default=False, + help='whether to use stratified sampling for classes') + +parser.add_argument('--sample_domains', type=int, default=-1) +parser.add_argument('--epochs', type=int, default=-1) +parser.add_argument('--batch_size', type=int, default=-1) +parser.add_argument('--print_iters', type=int, default=1) +parser.add_argument('--eval_iters', type=int, default=-1) +parser.add_argument('--lr', type=float, default=-1) +parser.add_argument('--momentum', type=float, default=-1) +parser.add_argument('--penalty_weight','-p', type=float, default=1) +parser.add_argument('--penalty_weight2','-p2', type=float, default=-1) # if there is another penalty weight to be tuned +parser.add_argument('--eps', type=float, default=1e-4) # if there is another penalty weight to be tuned +parser.add_argument('--preference_choice','-pc',type=int,default=0) +parser.add_argument('--num_workers','-nw',type=int,default=4) +parser.add_argument('--frozen', action='store_true', default=False) # whether to frozen the featurizer +parser.add_argument('--adjust_irm', '-ai',action='store_true', default=False) # whether to adjust some negative irm penalties in pair by adding up a positive number +parser.add_argument('--adjust_loss', '-al',action='store_true', default=False) # whether to adjust some negative irm penalties in pair by multiplying a negative number + +parser.add_argument('--need_pretrain', action='store_true') +parser.add_argument('--adjust_lr', '-alr',action='store_true', default=False) # whether to adjust lr as scheduled after pretraining +parser.add_argument('--pretrain_iters', type=int,default=-1) +parser.add_argument('--use_old', action='store_true') +parser.add_argument('--no_plot', action='store_true') 
+parser.add_argument('--no_test', action='store_true') +parser.add_argument('--no_sch', action='store_true') # not to use any scheduler +parser.add_argument('--opt', type=str, default='') +parser.add_argument('--exp_name', type=str, default='') +parser.add_argument('--scheduler', type=str, default='') +# Computation +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA use') +parser.add_argument('--seed', type=int, default=-1, + help='random seed, set as -1 for random.') +parser.add_argument('--no_wandb', action='store_true', default=False) # whether not to use wandb +parser.add_argument('--no_drop_last', action='store_true', default=True) # whether not to drop last batch + + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() +device = torch.device("cuda" if args.cuda else "cpu") +# overwrite some arguments +batch_size = args.batch_size +print_iters = args.print_iters +pretrain_iters = args.pretrain_iters +epochs = args.epochs +optimiser = args.opt +args_dict = args.__dict__ +args_dict.update(dataset_defaults[args.dataset]) +args = argparse.Namespace(**args_dict) + +if len(args.exp_dir) == 0: + args.exp_dir = args.data_dir +os.environ["WANDB_DIR"] = args.exp_dir + +# experiment directory setup +args.experiment = f"{args.dataset}_{args.algorithm}" \ + if args.experiment == '.' 
else args.experiment +directory_name = os.path.join(args.exp_dir,'experiments/{}'.format(args.experiment)) +if not os.path.exists(directory_name): + os.makedirs(directory_name) +runPath = mkdtemp(prefix=runId, dir=directory_name) + + +if batch_size>0: + args.batch_size=batch_size +if print_iters>0: + args.print_iters = print_iters +if pretrain_iters>0: + args.pretrain_iters = pretrain_iters +if len(optimiser)>0: + args.optimiser = optimiser + + + +exp_name = f"{args.experiment}" +if len(args.exp_name)>0: + args.exp_name ="_"+args.exp_name + +if args.algorithm.lower() in ['pair']: + exp_name += f"_pc{args.preference_choice}" +else: + exp_name += f"_p{args.penalty_weight}" + +if args.sample_domains>0: + exp_name += f"_meta{args.sample_domains}" + args.meta_steps = args.sample_domains +if args.frozen: + exp_name += "_frozen" +if epochs>0: + exp_name += f"_ep{epochs}" + args.epochs = epochs +exp_name += f"_{args.optimiser}_lr{args.lr}{args.exp_name}_seed{args.seed}" + +os.environ["WANDB_NAME"]=exp_name.replace("_","/") +group_name = "/".join(exp_name.split("_")[:-1]) # seed don't participate in the grouping + +if args.dataset.lower() in ["poverty"]: + dataset_name_wandb = args.dataset+"_avg" +else: + dataset_name_wandb = args.dataset +if not args.no_wandb: + # raise Exception("Please specify your own parameters if you wish to use wandb") + wandb_run = wandb.init(project=dataset_name_wandb, entity="entity_name",group=group_name,id=wandb.util.generate_id()) + wandb.config = args + + +# Choosing and saving a random seed for reproducibility +if args.seed == -1: + args.seed = int(torch.randint(0, 2 ** 32 - 1, (1,)).item()) +set_seed(args.seed) + +# logging setup +sys.stdout = Logger('{}/run.log'.format(runPath)) +print('RunID:' + runPath) +with open('{}/args.json'.format(runPath), 'w') as fp: + json.dump(args.__dict__, fp) +torch.save(args, '{}/args.rar'.format(runPath)) + +# load model +modelC = getattr(models, args.dataset) +train_loader, tv_loaders,dataset = 
modelC.getDataLoaders(args, device=device) +val_loader, test_loader = tv_loaders['val'], tv_loaders['test'] +model = modelC(args, weights=None).to(device) + +# assert args.optimiser in ['SGD', 'Adam'], "Invalid choice of optimiser, choose between 'Adam' and 'SGD'" +if args.optimiser.lower() in ['sgd','adam']: + opt = getattr(optim, args.optimiser) +else: + raise Exception("Invalid choice of optimiser") +if args.lr>0: + args.optimiser_args['lr'] = args.lr +# pop up unnecessary configs +if args.optimiser.lower() not in ['adam'] and 'amsgrad' in args.optimiser_args.keys(): + args.optimiser_args.pop('amsgrad') +if args.momentum > 0: + args.optimiser_args['momentum'] = args.momentum + +if args.dataset.lower() in ["poverty"]: + classifier = model.enc.fc +elif args.dataset.lower() in ["iwildcam","rxrx"]: + classifier = model.fc +else: + classifier = model.classifier +trainable_params = classifier.parameters() if args.frozen else model.parameters() +optimiserC = opt(trainable_params, **args.optimiser_args) +predict_fn, criterion = return_predict_fn(args.dataset), return_criterion(args.dataset) + + + +if args.algorithm not in ['erm'] and not args.adjust_lr: + n_train_steps = train_loader.dataset.training_steps*args.epochs +else: + n_train_steps = len(train_loader) * args.epochs +n_train_steps += (args.need_pretrain and args.pretrain_iters>0 and not args.use_old)*args.pretrain_iters + +if args.no_sch: + args.scheduler = None + +if args.scheduler is not None and len(args.scheduler)>0: + scheduler = initialize_scheduler(args, optimiserC, n_train_steps) +else: + scheduler = None + +if args.adjust_lr: + print("Adjusting learning rate as scheduled after pretraining...") + n_iters = 0 + pretrain_iters = args.pretrain_iters + pretrain_epochs = int(np.ceil(pretrain_iters/len(train_loader))) + pbar = tqdm.tqdm(total = pretrain_iters) + for epoch in range(pretrain_epochs): + for i in range(len(train_loader)): + if scheduler is not None and scheduler.step_every_batch: + 
scheduler.step() + # display progress + pbar.set_description(f"Pretrain {n_iters}/{pretrain_iters} iters") + pbar.update(1) + if scheduler is not None and not scheduler.step_every_batch: + scheduler.step() +elif args.need_pretrain and args.pretrain_iters>0 and args.use_old: + if args.scheduler is not None and len(args.scheduler)>0: + try: + if 'num_warmup_steps' in args.scheduler_kwargs.keys(): + args.scheduler_kwargs['num_warmup_steps'] = 0 + except Exception as e: + print(e) + scheduler = initialize_scheduler(args, optimiserC, n_train_steps) + else: + scheduler = None +print(optimiserC,scheduler) + +def pretrain(train_loader, pretrain_iters, save_path=None): + aggP = defaultdict(list) + aggP['val_stat'] = [0.] + + n_iters = 0 + pretrain_epochs = int(np.ceil(pretrain_iters/len(train_loader))) + pbar = tqdm.tqdm(total = pretrain_iters) + for epoch in range(pretrain_epochs): + for i, data in enumerate(train_loader): + model.train() + # get the inputs + x, y = unpack_data(data, device) + optimiserC.zero_grad() + y_hat = model(x,frozen_mode=args.frozen) + loss = criterion(y_hat, y) + loss.backward() + optimiserC.step() + if scheduler is not None and scheduler.step_every_batch: + scheduler.step() + n_iters += 1 + # display progress + pbar.set_description(f"Pretrain {n_iters}/{pretrain_iters} iters") + pbar.update(1) + if (i + 1) % args.eval_iters == 0 and args.eval_iters != -1: + test(val_loader, aggP, loader_type='val', verbose=False) + test(test_loader, aggP, loader_type='test', verbose=False) + if save_path is None: + save_path = runPath + save_best_model(model, save_path, aggP, args) + + if n_iters == pretrain_iters: + print("Pretrain is done!") + test(val_loader, aggP, loader_type='val', verbose=False) + test(test_loader, aggP, loader_type='test', verbose=False) + if save_path is None: + save_path = runPath + # save the model at last pretrain epoch no matter whatever + save_best_model(model, save_path, aggP, args, pretrain=True) + break + if scheduler is not None 
and not scheduler.step_every_batch: + scheduler.step() + pbar.close() + + model.load_state_dict(torch.load(save_path + '/model.rar')) + print('Finished ERM pre-training!') + +def train_erm(train_loader, epoch, agg): + running_loss = 0 + total_iters = len(train_loader) + print('\n====> Epoch: {:03d} '.format(epoch)) + for i, data in enumerate(train_loader): + model.train() + # get the inputs + x, y = unpack_data(data, device) + optimiserC.zero_grad() + y_hat = model(x,frozen_mode=args.frozen) + loss = criterion(y_hat, y) + loss.backward() + optimiserC.step() + if scheduler is not None and scheduler.step_every_batch: + scheduler.step() + running_loss += loss.item() + # print statistics + if (i + 1) % args.print_iters == 0 and args.print_iters != -1 and args.algorithm != 'fish': + if not args.no_wandb: + wandb.log({ "loss": loss.item()}) + agg['train_loss'].append(running_loss / args.print_iters) + agg['losses'].append([running_loss / args.print_iters]) + agg['train_iters'].append(i+1+epoch*total_iters) + print('iteration {:05d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / args.print_iters)) + if i % args.eval_iters == 0 and args.eval_iters != -1: + test(val_loader, agg, loader_type='val') + test(test_loader, agg, loader_type='test') + if not args.no_wandb: + wandb.log({"val_acc":agg['val_stat'][-1]}) + wandb.log({"test_acc":agg['test_stat'][-1]}) + running_loss=0 + model.train() + save_best_model(model, runPath, agg, args) + + + +from wilds.common.utils import split_into_groups +import torch.autograd as autograd +import torch.nn.functional as F +scale = torch.tensor(1.).to(device).requires_grad_() +def irm_penalty(losses, pos=-1, adjust=False): + grad_1 = autograd.grad(losses[0::2].mean(), [scale], create_graph=True)[0] + grad_2 = autograd.grad(losses[1::2].mean(), [scale], create_graph=True)[0] + result = torch.sum(grad_1 * grad_2) + if pos>0 and not adjust: + # grad = autograd.grad(losses.mean(), [scale], create_graph=True)[0] + # result = 
torch.sum(grad.pow(2)) + result += pos + if result<0 and adjust: + grad = autograd.grad(losses.mean(), [scale], create_graph=True)[0] + result = torch.sum(grad.pow(2)) + return result + +def train_irmx(train_loader, epoch, agg): + model.train() + train_loader.dataset.reset_batch() + i = 0 + print('\n====> Epoch: {:03d} '.format(epoch)) + running_loss = 0 + total_iters = len(train_loader) + running_losses = [] + while sum([l > 1 for l in train_loader.dataset.batches_left.values()]) >= args.meta_steps: + model.train() + i += 1 + # sample `meta_steps` number of domains to use for the inner loop + domains = sample_domains(train_loader, args.meta_steps, args.stratified).tolist() + # print(domains) + avg_loss = 0. + penalty = 0. + # overall_losses = F.cross_entropy(scale * results['y_pred'],results['y_true'],reduction="none") + losses_bygroup = [] + + # inner loop update + for domain in domains: + data = train_loader.dataset.get_batch(domain) + x, y = unpack_data(data, device) + y_hat = model(x,frozen_mode=args.frozen) + # loss = criterion(y_hat, y) + if 'poverty'in args.dataset.lower(): + loss = F.mse_loss(scale*y_hat,y,reduction="none") + else: + loss = F.cross_entropy(scale * y_hat,y,reduction="none") + losses_bygroup.append(loss.mean()) + penalty += irm_penalty(loss) + avg_loss += loss.mean() + avg_loss /= args.meta_steps + penalty /= args.meta_steps + # losses = losses_bygroup+[ penalty, torch.stack(losses_bygroup).var()] + losses = [avg_loss, penalty, torch.stack(losses_bygroup).var()] + # agg['losses'].append([l.item() for l in losses]) + if len(running_losses)==0: + running_losses = [0]*len(losses) + for (j,loss) in enumerate(running_losses): + running_losses[j]+=losses[j].item() + # print([l.item() for l in losses],sol) + optimiserC.zero_grad() + # loss = scales.dot(torch.stack(losses)) + if args.penalty_weight2 > 0: + loss = avg_loss+args.penalty_weight*penalty+args.penalty_weight2*torch.stack(losses_bygroup).var() + else: + loss = 
avg_loss+args.penalty_weight*(penalty+torch.stack(losses_bygroup).var()) + # print(loss) + loss.backward() + optimiserC.step() + if scheduler is not None and scheduler.step_every_batch: + scheduler.step() + running_loss += loss.item() + + # log the number of batches left for each domain + for domain in domains: + train_loader.dataset.batches_left[domain] = \ + train_loader.dataset.batches_left[domain] - 1 \ + if train_loader.dataset.batches_left[domain] > 1 else 1 + + if i % args.print_iters == 0 and args.print_iters != -1: + print(avg_loss,penalty) + agg['losses'].append([l / args.print_iters for l in running_losses]) + if not args.no_wandb: + wandb.log({ "loss": loss.item(), + "erm_loss": agg['losses'][-1][0], + "irm_loss": agg['losses'][-1][1], + "vrex_loss": agg['losses'][-1][2], + }) + running_losses = [0]*len(losses) + # agg['losses'].append([l.item() for l in losses]) + agg['train_loss'].append(running_loss / args.print_iters) + agg['train_iters'].append(i+1+epoch*total_iters) + print('iteration {:05d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / args.print_iters)) + if i % args.eval_iters == 0 and args.eval_iters != -1: + test(val_loader, agg, loader_type='val') + test(test_loader, agg, loader_type='test') + if not args.no_wandb: + wandb.log({"val_acc":agg['val_stat'][-1]}) + wandb.log({"test_acc":agg['test_stat'][-1]}) + model.train() + save_best_model(model, runPath, agg, args) + +def train_irm(train_loader, epoch, agg): + model.train() + train_loader.dataset.reset_batch() + i = 0 + print('\n====> Epoch: {:03d} '.format(epoch)) + running_loss = 0 + total_iters = len(train_loader) + running_losses = [] + while sum([l > 1 for l in train_loader.dataset.batches_left.values()]) >= args.meta_steps: + model.train() + i += 1 + # sample `meta_steps` number of domains to use for the inner loop + domains = sample_domains(train_loader, args.meta_steps, args.stratified).tolist() + # print(domains) + avg_loss = 0. + penalty = 0. 
+ # overall_losses = F.cross_entropy(scale * results['y_pred'],results['y_true'],reduction="none") + losses_bygroup = [] + + # inner loop update + for domain in domains: + data = train_loader.dataset.get_batch(domain) + x, y = unpack_data(data, device) + y_hat = model(x,frozen_mode=args.frozen) + # loss = criterion(y_hat, y) + if 'poverty'in args.dataset.lower(): + loss = F.mse_loss(scale*y_hat,y,reduction="none") + else: + loss = F.cross_entropy(scale * y_hat,y,reduction="none") + losses_bygroup.append(loss.mean()) + penalty += irm_penalty(loss) + avg_loss += loss.mean() + avg_loss /= args.meta_steps + penalty /= args.meta_steps + # losses = losses_bygroup+[ penalty, torch.stack(losses_bygroup).var()] + losses = [avg_loss, penalty, torch.stack(losses_bygroup).var()] + # agg['losses'].append([l.item() for l in losses]) + if len(running_losses)==0: + running_losses = [0]*len(losses) + for (j,loss) in enumerate(running_losses): + running_losses[j]+=losses[j].item() + # print([l.item() for l in losses],sol) + optimiserC.zero_grad() + + # loss = scales.dot(torch.stack(losses)) + loss = avg_loss+args.penalty_weight*penalty + # print(loss) + loss.backward() + optimiserC.step() + running_loss += loss.item() + if scheduler is not None and scheduler.step_every_batch: + scheduler.step() + # log the number of batches left for each domain + for domain in domains: + train_loader.dataset.batches_left[domain] = \ + train_loader.dataset.batches_left[domain] - 1 \ + if train_loader.dataset.batches_left[domain] > 1 else 1 + + if i % args.print_iters == 0 and args.print_iters != -1: + print(avg_loss,penalty) + agg['losses'].append([l / args.print_iters for l in running_losses]) + if not args.no_wandb: + wandb.log({ "loss": loss.item(), + "erm_loss": agg['losses'][-1][0], + "irm_loss": agg['losses'][-1][1], + "vrex_loss": agg['losses'][-1][2], + }) + running_losses = [0]*len(losses) + # agg['losses'].append([l.item() for l in losses]) + agg['train_loss'].append(running_loss / 
args.print_iters) + agg['train_iters'].append(i+1+epoch*total_iters) + print('iteration {:05d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / args.print_iters)) + if i % args.eval_iters == 0 and args.eval_iters != -1: + test(val_loader, agg, loader_type='val') + test(test_loader, agg, loader_type='test') + if not args.no_wandb: + wandb.log({"val_acc":agg['val_stat'][-1]}) + wandb.log({"test_acc":agg['test_stat'][-1]}) + model.train() + save_best_model(model, runPath, agg, args) + +def train_vrex(train_loader, epoch, agg): + model.train() + train_loader.dataset.reset_batch() + i = 0 + print('\n====> Epoch: {:03d} '.format(epoch)) + running_loss = 0 + total_iters = len(train_loader) + running_losses = [] + while sum([l > 1 for l in train_loader.dataset.batches_left.values()]) >= args.meta_steps: + model.train() + i += 1 + # sample `meta_steps` number of domains to use for the inner loop + domains = sample_domains(train_loader, args.meta_steps, args.stratified).tolist() + # print(domains) + avg_loss = 0. + penalty = 0. 
+ # overall_losses = F.cross_entropy(scale * results['y_pred'],results['y_true'],reduction="none") + losses_bygroup = [] + + # inner loop update + for domain in domains: + data = train_loader.dataset.get_batch(domain) + x, y = unpack_data(data, device) + y_hat = model(x,frozen_mode=args.frozen) + # loss = criterion(y_hat, y) + if 'poverty'in args.dataset.lower(): + loss = F.mse_loss(scale*y_hat,y,reduction="none") + else: + loss = F.cross_entropy(scale * y_hat,y,reduction="none") + losses_bygroup.append(loss.mean()) + + penalty += irm_penalty(loss) + avg_loss += loss.mean() + avg_loss /= args.meta_steps + penalty /= args.meta_steps + losses = [avg_loss, penalty, torch.stack(losses_bygroup).var()] + if len(running_losses)==0: + running_losses = [0]*len(losses) + for (j,loss) in enumerate(running_losses): + running_losses[j]+=losses[j].item() + + # print([l.item() for l in losses],sol) + optimiserC.zero_grad() + # loss = scales.dot(torch.stack(losses)) + loss = avg_loss+args.penalty_weight*torch.stack(losses_bygroup).var() + # print(loss) + loss.backward() + optimiserC.step() + if scheduler is not None and scheduler.step_every_batch: + scheduler.step() + running_loss += loss.item() + + # log the number of batches left for each domain + for domain in domains: + train_loader.dataset.batches_left[domain] = \ + train_loader.dataset.batches_left[domain] - 1 \ + if train_loader.dataset.batches_left[domain] > 1 else 1 + # print(i) + if i % args.print_iters == 0 and args.print_iters != -1: + print(avg_loss,penalty) + agg['losses'].append([l / args.print_iters for l in running_losses]) + if not args.no_wandb: + wandb.log({ "loss": loss.item(), + "erm_loss": agg['losses'][-1][0], + "irm_loss": agg['losses'][-1][1], + "vrex_loss": agg['losses'][-1][2], + }) + running_losses = [0]*len(losses) + # agg['losses'].append([l.item() for l in losses]) + agg['train_loss'].append(running_loss / args.print_iters) + agg['train_iters'].append(i+1+epoch*total_iters) + print('iteration 
def train_pair(train_loader, epoch, agg):
    """Run one training epoch driven by the module-level ``pair_optimizer``.

    Each iteration samples ``args.meta_steps`` domains, computes a per-sample
    loss on one batch from each, and builds three objectives that are handed
    to ``pair_optimizer``: the mean loss, the mean ``irm_penalty`` and the
    variance of the per-domain mean losses (logged as erm/irm/vrex).

    Args:
        train_loader: loader whose ``dataset`` exposes ``reset_batch``,
            ``get_batch`` and ``batches_left``.
        epoch: zero-based epoch index, used for printing and for the global
            iteration counter stored in ``agg['train_iters']``.
        agg: dict of lists collecting ``losses``, ``train_loss``,
            ``train_iters`` and the validation/test statistics.

    Relies on module-level state: ``model``, ``args``, ``device``, ``scale``,
    ``amplify``, ``pair_optimizer``, ``scheduler``, ``val_loader``,
    ``test_loader``, ``runPath`` and (optionally) ``wandb``.
    """
    model.train()
    train_loader.dataset.reset_batch()
    i = 0
    print('\n====> Epoch: {:03d} '.format(epoch))
    running_loss = 0
    total_iters = len(train_loader)
    running_losses = []
    is_regression = 'poverty' in args.dataset.lower()
    # Keep going while at least `meta_steps` domains still have batches left.
    while sum(left > 1 for left in train_loader.dataset.batches_left.values()) >= args.meta_steps:
        model.train()
        i += 1
        # sample `meta_steps` number of domains to use for the inner loop
        domains = sample_domains(train_loader, args.meta_steps, args.stratified).tolist()

        avg_loss = 0.
        penalty = 0.
        losses_bygroup = []
        # inner loop update
        for domain in domains:
            data = train_loader.dataset.get_batch(domain)
            x, y = unpack_data(data, device)
            y_hat = model(x, frozen_mode=args.frozen)
            if is_regression:
                sample_loss = F.mse_loss(scale * y_hat, y, reduction="none")
            else:
                sample_loss = F.cross_entropy(scale * y_hat, y, reduction="none")
            domain_loss = sample_loss.mean()
            losses_bygroup.append(domain_loss)

            if args.adjust_loss:
                penalty += irm_penalty(sample_loss)
            else:
                penalty += irm_penalty(sample_loss, pos=amplify, adjust=args.adjust_irm)
            avg_loss += domain_loss
        avg_loss /= args.meta_steps
        penalty /= args.meta_steps
        losses = [avg_loss, penalty, torch.stack(losses_bygroup).var()]
        if len(running_losses) == 0:
            running_losses = [0] * len(losses)
        # Iterate over `losses` with a dedicated name: the previous loop
        # variable was called `loss` and clobbered the tensor logged below.
        for j, objective in enumerate(losses):
            running_losses[j] += objective.item()
        pair_optimizer.zero_grad()
        pair_optimizer.set_losses(losses)
        pair_loss, _moo_losses, mu_rl, alphas = pair_optimizer.step()

        if scheduler is not None and scheduler.step_every_batch:
            scheduler.step()
        running_loss += pair_loss

        # log the number of batches left for each domain
        batches_left = train_loader.dataset.batches_left
        for domain in domains:
            batches_left[domain] = batches_left[domain] - 1 if batches_left[domain] > 1 else 1

        if i % args.print_iters == 0 and args.print_iters != -1:
            agg['losses'].append([l / args.print_iters for l in running_losses])
            # compensate the irm penalty
            if not args.adjust_irm and not args.adjust_loss:
                agg['losses'][-1][1] -= amplify

            if not args.no_wandb:
                # float() accepts both a Python number and a 0-dim tensor.
                wandb.log({"loss": float(pair_loss),
                           "erm_loss": agg['losses'][-1][0],
                           "irm_loss": agg['losses'][-1][1],
                           "vrex_loss": agg['losses'][-1][2],
                           "mu_rl": mu_rl,
                           "erm_alpha": alphas[0],
                           "irm_alpha": alphas[1],
                           "vrex_alpha": alphas[2]
                           })
            running_losses = [0] * len(losses)
            agg['train_loss'].append(running_loss / args.print_iters)
            agg['train_iters'].append(i + 1 + epoch * total_iters)
            print('iteration {:05d}/{:d}: loss: {:6.3f}'.format(i + 1, total_iters, running_loss / args.print_iters))
            # Start a fresh window so the next report is an average over the
            # last `print_iters` steps rather than a running total.
            running_loss = 0
        if i % args.eval_iters == 0 and args.eval_iters != -1:
            test(val_loader, agg, loader_type='val')
            test(test_loader, agg, loader_type='test')
            if not args.no_wandb:
                wandb.log({"val_acc": agg['val_stat'][-1]})
                wandb.log({"test_acc": agg['test_stat'][-1]})
            # `test` switches the model to eval mode; switch back.
            model.train()
            save_best_model(model, runPath, agg, args)
+ avg_loss = 0 + losses_bygroup = [] + # inner loop update + for domain in domains: + data = train_loader.dataset.get_batch(domain) + x, y = unpack_data(data, device) + opt_inner.zero_grad() + y_hat = model_inner(x) + loss = criterion(y_hat, y) + loss.backward() + opt_inner.step() + losses_bygroup.append(loss.mean()) + if 'poverty'in args.dataset.lower(): + cur_loss = F.mse_loss(scale*y_hat,y,reduction="none") + else: + cur_loss = F.cross_entropy(scale * y_hat,y,reduction="none") + penalty += irm_penalty(cur_loss) + avg_loss += loss.mean() + + avg_loss /= args.meta_steps + penalty /= args.meta_steps + # losses = losses_bygroup+[ penalty, torch.stack(losses_bygroup).var()] + losses = [avg_loss, penalty, torch.stack(losses_bygroup).var()] + # agg['losses'].append([l.item() for l in losses]) + if len(running_losses)==0: + running_losses = [0]*len(losses) + for (j,loss) in enumerate(running_losses): + running_losses[j]+=losses[j].item() + + opt_inner_pre = opt_inner.state_dict() + # fish update + meta_weights = fish_step(meta_weights=model.state_dict(), + inner_weights=model_inner.state_dict(), + meta_lr=args.meta_lr / args.meta_steps) + model.reset_weights(meta_weights) + # log the number of batches left for each domain + for domain in domains: + train_loader.dataset.batches_left[domain] = \ + train_loader.dataset.batches_left[domain] - 1 \ + if train_loader.dataset.batches_left[domain] > 1 else 1 + + if (i + 1) % args.print_iters == 0 and args.print_iters != -1: + agg['losses'].append([l / args.print_iters for l in running_losses]) + if not args.no_wandb: + wandb.log({ "loss": avg_loss, + "erm_loss": agg['losses'][-1][0], + "irm_loss": agg['losses'][-1][1], + "vrex_loss": agg['losses'][-1][2], + }) + print(f"iteration {(i + 1):05d}: {agg['losses'][-1]}") + running_losses = [0]*len(losses) + if i % args.eval_iters == 0 and args.eval_iters != -1: + test(val_loader, agg, loader_type='val') + test(test_loader, agg, loader_type='test') + if not args.no_wandb: + 
def test(test_loader, agg, loader_type='test', verbose=True, save_ypred=False, return_last=False):
    """Evaluate the module-level ``model`` on ``test_loader`` and record the result.

    Args:
        test_loader: iterable of ``(x, y, meta)`` batches whose ``dataset``
            provides an ``eval(y_pred, y_true, metadata)`` method.
        agg: dict of lists; the selected metric is appended to
            ``agg[f'{loader_type}_stat']``.
        loader_type: split name used for the ``agg`` key, the printed banner
            and the prediction file name.
        verbose: print the last element returned by ``dataset.eval``.
        save_ypred: also write the predictions, one per row, to a CSV file
            under ``runPath``.
        return_last: return the selected metric instead of ``None``.

    Returns:
        The value of ``args.selection_metric`` when ``return_last`` is true,
        otherwise ``None``.
    """
    model.eval()
    yhats, ys, metas = [], [], []
    with torch.no_grad():
        for x, y, meta in test_loader:
            # get the inputs
            x, y = x.to(device), y.to(device)
            ys.append(y)
            yhats.append(model(x))
            metas.append(meta)
        ypreds, ys, metas = predict_fn(torch.cat(yhats)), torch.cat(ys), torch.cat(metas)
    if save_ypred:
        if args.dataset == 'poverty':
            save_name = f"{args.dataset}_split:{loader_type}_fold:" \
                        f"{['A', 'B', 'C', 'D', 'E'][args.seed]}" \
                        f"_epoch:best_pred.csv"
        else:
            save_name = f"{args.dataset}_split:{loader_type}_seed:" \
                        f"{args.seed}_epoch:best_pred.csv"
        # newline='' lets the csv module control line endings itself;
        # without it, platforms that translate "\n" emit blank rows.
        with open(f"{runPath}/{save_name}", 'w', newline='') as f:
            writer = csv.writer(f)
            writer.writerows(ypreds.unsqueeze(1).cpu().tolist())
    test_val = test_loader.dataset.eval(ypreds.cpu(), ys.cpu(), metas)
    metric = test_val[0][args.selection_metric]
    agg[f'{loader_type}_stat'].append(metric)
    if verbose:
        print(f"=============== {loader_type} ===============\n{test_val[-1]}")
    if return_last:
        return metric
"="*30) + train = locals()[f'train_{args.algorithm}'] + agg = defaultdict(list) + agg['val_stat'] = [0.] + + for epoch in range(args.epochs): + train(train_loader, epoch, agg) + + test(val_loader, agg, loader_type='val') + test(test_loader, agg, loader_type='test') + save_best_model(model, runPath, agg, args) + if not args.no_wandb: + wandb.log({"val_acc":agg['val_stat'][-1]}) + wandb.log({"test_acc":agg['test_stat'][-1]}) + if scheduler is not None and not scheduler.step_every_batch: + scheduler.step() + print(optimiserC) + + model.load_state_dict(torch.load(runPath + '/model.rar')) + print('Finished training! Loading best model...') + test_acc = 0 + for split, loader in tv_loaders.items(): + tmp_acc = test(loader, agg, loader_type=split, save_ypred=True,return_last=True) + if split=="test": + test_acc = tmp_acc + + import matplotlib.pyplot as plt + if not args.no_plot: + folder_name = os.path.join(args.exp_dir,"plots",f"{args.dataset}_{args.algorithm}") + if not os.path.exists(folder_name): + os.mkdir(folder_name) + num_epochs = len(agg['losses']) + plt.title(exp_name) + fig, ax1 = plt.subplots() + ax1.set_xlabel("epoch") + ax1.set_ylabel("test acc") + + + if args.algorithm in ['pair','irm','vrex','fish',"irmx"]: + ax2 = ax1.twinx() + ax2.set_ylabel("penalty") + if len(agg['losses'][0])>=3: + irm_pens = np.array([log_i[-2] for log_i in agg['losses']]) + vrex_pens = np.array([log_i[-1] for log_i in agg['losses']]) + ax2.plot(np.arange(num_epochs),irm_pens,label=f'irm_pen',c='r',alpha=0.2) + ax2.plot(np.arange(num_epochs),vrex_pens,label=f'vrex_pen',c='g',alpha=0.2) + erm_pens = np.array([log_i[-3] for log_i in agg['losses']]) + else: + irm_pens = np.array([log_i[-1] for log_i in agg['losses']]) + ax2.plot(np.arange(num_epochs),irm_pens,label=f'{args.algorithm}_pen',c='r',alpha=0.2) + erm_pens = np.array([log_i[-2] for log_i in agg['losses']]) + + max_ratio = irm_pens.max() + ax2.set_title(f"{exp_name}: {test_acc}") + # control the scale of erm loss w.r.t. 
others + erm_pens = np.clip(erm_pens,erm_pens.min(),erm_pens.min()*max_ratio) + ax1.plot(np.arange(num_epochs),erm_pens,label=f'erm_pen') + # ask matplotlib for the plotted objects and their labels + lines, labels = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax2.legend(lines + lines2, labels + labels2, loc=0) + else: + erm_pens = np.array([log_i[0] for log_i in agg['losses']]) + ax1.plot(np.arange(num_epochs),erm_pens,label=f'erm_pen') + + plt.savefig(os.path.join(folder_name,f"{exp_name}.png")) + plt.close() + torch.save(agg,os.path.join(folder_name,f"{exp_name}_agg.pt")) + if not args.no_wandb: + wandb.finish() + except Exception as e: + traceback.print_exc() + print(e) + if not args.no_wandb: + wandb.finish(-1) + print("Exceptions found, delete all wandb files") + import shutil + shutil.rmtree(wandb_run.dir.replace("/files","")) diff --git a/WILDS/src/misc.py b/WILDS/src/misc.py new file mode 100644 index 0000000..e8223b9 --- /dev/null +++ b/WILDS/src/misc.py @@ -0,0 +1,400 @@ +# Copyright (c) Facebook, Inc. and its affiliates. 
def make_weights_for_balanced_classes(dataset):
    """Return one sampling weight per example so that every class gets equal total weight.

    ``dataset`` is iterated as ``(input, label)`` pairs and must support
    ``len``. An example of class ``c`` receives the weight
    ``1 / (count(c) * number_of_classes)``.
    """
    labels = [int(target) for _, target in dataset]
    frequency = Counter(labels)
    num_classes = len(frequency)
    per_class = {
        label: 1 / (count * num_classes)
        for label, count in frequency.items()
    }

    weights = torch.zeros(len(dataset))
    for position, label in enumerate(labels):
        weights[position] = per_class[label]
    return weights
class ParamDict(OrderedDict):
    """Code adapted from https://github.com/Alok/rl_implementations/tree/master/reptile.
    A dictionary where the values are Tensors, meant to represent weights of
    a model. This subclass lets you perform arithmetic on weights directly.

    Every operator works element-wise over the keys of ``self``. The other
    operand may be a number (applied to every value) or any ``dict`` holding
    the same keys (combined key by key). Any other operand type raises
    ``NotImplementedError``.
    """
    # NOTE(review): this module defines ParamDict more than once; the last
    # definition is the one that is bound at import time.

    def __init__(self, *args, **kwargs):
        # Keyword arguments have to be forwarded as keywords; splatting them
        # positionally would pass their *names* to OrderedDict.
        super().__init__(*args, **kwargs)

    def _prototype(self, other, op):
        """Apply ``op(self[k], other_k)`` for every key and wrap the result."""
        if isinstance(other, Number):
            return ParamDict({k: op(v, other) for k, v in self.items()})
        elif isinstance(other, dict):
            return ParamDict({k: op(self[k], other[k]) for k in self})
        else:
            raise NotImplementedError

    def __add__(self, other):
        return self._prototype(other, operator.add)

    # Addition is commutative, so ``0 + params`` (e.g. a bare ``sum``) works.
    __radd__ = __add__

    def __rmul__(self, other):
        return self._prototype(other, operator.mul)

    __mul__ = __rmul__

    def __neg__(self):
        return ParamDict({k: -v for k, v in self.items()})

    def __sub__(self, other):
        # Subtract directly instead of negating ``other`` first, so plain
        # dicts (which have no ``__neg__``) are accepted as well.
        return self._prototype(other, operator.sub)

    def __rsub__(self, other):
        # Reflected form: ``other - self``, operands swapped.
        return self._prototype(other, lambda mine, theirs: theirs - mine)

    def __truediv__(self, other):
        return self._prototype(other, operator.truediv)
def seed_hash(*args):
    """
    Map the given arguments to a reproducible integer in [0, 2**31),
    for use as a random seed.
    """
    encoded = str(args).encode("utf-8")
    digest = hashlib.md5(encoded).digest()
    # Reading the 16 digest bytes as a big-endian integer gives the same
    # number as parsing the hexadecimal digest in base 16.
    return int.from_bytes(digest, "big") % (2**31)
def accuracy(network, loader, weights, device):
    """Return the (optionally weighted) accuracy of ``network`` over ``loader``.

    Args:
        network: object exposing ``eval()``, ``train()`` and ``predict(x)``.
            ``predict`` returns a 2-D tensor: one column is treated as a
            binary logit (positive means class 1), several columns as
            per-class scores.
        loader: iterable of ``(x, y)`` batches.
        weights: ``None`` for uniform weighting, otherwise a 1-D tensor of
            per-example weights consumed in loader order.
        device: device the batches and weights are moved to.

    Returns:
        Sum of the weights of correctly classified examples divided by the
        sum of all weights.

    Raises:
        ZeroDivisionError: if the loader is empty or all weights are zero.
    """
    correct = 0
    total = 0
    weights_offset = 0

    network.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                p = network.predict(x)
                if weights is None:
                    batch_weights = torch.ones(len(x))
                else:
                    batch_weights = weights[weights_offset: weights_offset + len(x)]
                    weights_offset += len(x)
                batch_weights = batch_weights.to(device)
                if p.size(1) == 1:
                    # Flatten both sides before comparing: an (N, 1) prediction
                    # against an (N,) label would broadcast to (N, N) and
                    # count every prediction against every label.
                    hits = p.flatten().gt(0).eq(y.flatten()).float()
                    correct += (hits * batch_weights.flatten()).sum().item()
                else:
                    correct += (p.argmax(1).eq(y).float() * batch_weights).sum().item()
                total += batch_weights.sum().item()
    finally:
        # Restore training mode even when evaluation fails part-way.
        network.train()

    return correct / total
from .camelyon import Model as camelyon
from .cdsprites import Model as cdsprites
from .civil import Model as civil
from .fmow import Model as fmow
from .iwildcam import Model as iwildcam
from .poverty import Model as poverty
from .rxrx import Model as rxrx

# ``__all__`` must hold the *names* of the public objects as strings.
# ``amazon`` used to be listed here although nothing imports it, which made
# importing this package fail with NameError; it is left out until such a
# model is actually imported above.
__all__ = ["cdsprites", "iwildcam", "camelyon", "civil", "fmow", "poverty", "rxrx"]
class Model(nn.Module):
    """DenseNet-121 convolutional trunk followed by a linear classification head."""

    def __init__(self, args, weights):
        super().__init__()
        self.num_classes = NUM_CLASSES
        # Only the feature extractor is kept; the stock fc layer is dropped.
        self.enc = densenet121(pretrained=False).features
        self.classifier = nn.Linear(1024, self.num_classes)
        if weights is not None:
            self.load_state_dict(deepcopy(weights))

    def reset_weights(self, weights):
        """Overwrite all parameters with a private copy of ``weights``."""
        self.load_state_dict(deepcopy(weights))

    @staticmethod
    def getDataLoaders(args, device):
        """Build the training loader, the per-split evaluation loaders and the dataset.

        Returns a tuple ``(train_loaders, tv_loaders, full_dataset)`` where
        ``tv_loaders`` maps every non-train split name to its eval loader.
        """
        wilds_root = os.path.join(args.data_dir, 'wilds')
        full_dataset = Camelyon17Dataset(root_dir=wilds_root, download=True)

        preprocess = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])
        ])

        # Training data, wrapped so batches can be drawn per domain.
        train_subset = full_dataset.get_subset('train', transform=preprocess)
        train_sets = GeneralWilds_Batched_Dataset(
            train_subset, args.batch_size, domain_idx=0,
            drop_last=not args.no_drop_last)

        # Every split other than 'train' is used for evaluation.
        held_out = {
            split: full_dataset.get_subset(split, transform=preprocess)
            for split in full_dataset.split_dict
            if split != 'train'
        }

        loader_opts = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': False}
        # The training loader only receives these options on CUDA devices.
        train_opts = loader_opts if device.type == "cuda" else {}
        train_loaders = DataLoader(train_sets, batch_size=args.batch_size,
                                   shuffle=True, **train_opts)

        tv_loaders = {
            split: get_eval_loader('standard', subset, batch_size=256, **loader_opts)
            for split, subset in held_out.items()
        }
        return train_loaders, tv_loaders, full_dataset

    def _pooled_features(self, x):
        """Encode ``x``, apply ReLU, global-average-pool and flatten to 2-D."""
        out = F.relu(self.enc(x), inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        return torch.flatten(out, 1)

    def forward(self, x, get_feat=False, frozen_mode=False):
        """Return the logits, plus the pooled features when ``get_feat`` is true.

        With ``frozen_mode`` the encoder is put in eval mode and run without
        gradient tracking, while the classifier is put in train mode.
        """
        if frozen_mode:
            self.enc.eval()
            self.classifier.train()
            with torch.no_grad():
                out = self._pooled_features(x)
        else:
            out = self._pooled_features(x)
        pred = self.classifier(out)
        return (pred, out) if get_feat else pred
output_hidden_states=output_hidden_states, + ) + + if output_hidden_states: + outputs = outputs + else: + outputs = outputs[0] + return outputs + +class Model(nn.Module): + def __init__(self, args, weights): + super(Model, self).__init__() + self.num_classes = NUM_CLASSES + + try: + self.model = DistilBertClassifier.from_pretrained( + os.path.join(args.data_dir, 'wilds',args.dataset), + num_labels=2, + ) + except Exception as e: + self.model = DistilBertClassifier.from_pretrained( + 'distilbert-base-uncased', + num_labels=2, + ) + self.model.save_pretrained(os.path.join(args.data_dir, 'wilds',args.dataset)) + # self.model = DistilBertClassifier.from_pretrained( + # 'distilbert-base-uncased', + # num_labels=2, + # cache_dir=os.path.join(args.data_dir, 'wilds',args.dataset) + # ) + if weights is not None: + self.load_state_dict(deepcopy(weights)) + self.classifier = self.model.classifier + + def reset_weights(self, weights): + self.model.load_state_dict(deepcopy(weights)) + self.classifier = self.model.classifier + + @staticmethod + def getDataLoaders(args, device): + dataset = CivilCommentsDataset(root_dir=os.path.join(args.data_dir, 'wilds'), download=True) + # get all train data + transform = initialize_bert_transform(root_dir=os.path.join(args.data_dir, 'wilds', args.dataset)) + train_data = dataset.get_subset('train', transform=transform) + # separate into subsets by distribution + train_sets = CivilComments_Batched_Dataset(train_data, batch_size=args.batch_size, drop_last=not args.no_drop_last) + # take subset of test and validation, making sure that only labels appeared in train + # are included + datasets = {} + for split in dataset.split_dict: + if split != 'train': + datasets[split] = dataset.get_subset(split, transform=transform) + # get the loaders + kwargs = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': False} \ + if device.type == "cuda" else {} + train_loaders = DataLoader(train_sets, batch_size=args.batch_size, shuffle=True, 
**kwargs) + tv_loaders = {} + for split, sep_dataset in datasets.items(): + tv_loaders[split] = get_eval_loader('standard', sep_dataset, batch_size=256, num_workers=args.num_workers) + return train_loaders, tv_loaders, dataset + + def forward(self, x, get_feat=False,frozen_mode=False): + if frozen_mode: + self.model.eval() + self.classifier.train() + with torch.no_grad(): + outs = self.model(x,output_hidden_states=True) + pooled_output = outs[1][-1][:, 0] + # pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.model.pre_classifier(pooled_output) # (bs, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs, dim) + pooled_output = self.model.dropout(pooled_output) # (bs, dim) + outs = self.classifier(pooled_output) + return outs + if get_feat: + # print(self.model) + outs = self.model(x,output_hidden_states=True) + with torch.no_grad(): + pooled_output = outs[1][-1][:, 0] + # pooled_output = hidden_state[:, 0] # (bs, dim) + pooled_output = self.model.pre_classifier(pooled_output) # (bs, dim) + pooled_output = nn.ReLU()(pooled_output) # (bs, dim) + pooled_output = self.model.dropout(pooled_output) # (bs, dim) + # print(pooled_output.size()) + + # print(self.model.classifier) + # print(outs) + # exit() + return outs[0],pooled_output + return self.model(x) +import os +from copy import deepcopy + +import torch +from torch import nn +from torch.utils.data import DataLoader +from transformers import BertForSequenceClassification +from transformers import BertTokenizerFast +from transformers import logging +from wilds.common.data_loaders import get_eval_loader +from wilds.datasets.civilcomments_dataset import CivilCommentsDataset + +from .datasets import CivilComments_Batched_Dataset + +logging.set_verbosity_error() + +MAX_TOKEN_LENGTH = 300 +NUM_CLASSES = 2 + +# def initialize_bert_transform(): +# """Adapted from the Wilds library, available at: https://github.com/p-lambda/wilds""" +# tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') +# 
def transform(text): +# tokens = tokenizer( +# text, +# padding='max_length', +# truncation=True, +# max_length=MAX_TOKEN_LENGTH, +# return_tensors='pt') +# x = torch.stack( +# (tokens['input_ids'], +# tokens['attention_mask'], +# tokens['token_type_ids']), +# dim=2) +# x = torch.squeeze(x, dim=0) # First shape dim is always 1 +# return x +# return transform + +# class BertClassifier(BertForSequenceClassification): +# """Adapted from the Wilds library, available at: https://github.com/p-lambda/wilds""" +# def __init__(self, args): +# super().__init__(args) +# self.d_out = 2 + +# def __call__(self, x): +# input_ids = x[:, :, 0] +# attention_mask = x[:, :, 1] +# token_type_ids = x[:, :, 2] +# outputs = super().__call__( +# input_ids=input_ids, +# attention_mask=attention_mask, +# token_type_ids=token_type_ids +# )[0] +# return outputs + +# class Model(nn.Module): +# def __init__(self, args, weights): +# super(Model, self).__init__() +# self.num_classes = NUM_CLASSES +# self.model = BertClassifier.from_pretrained( +# 'bert-base-uncased', +# num_labels=2, +# ) +# if weights is not None: +# self.load_state_dict(deepcopy(weights)) + +# def reset_weights(self, weights): +# self.load_state_dict(deepcopy(weights)) + +# @staticmethod +# def getDataLoaders(args, device): +# dataset = CivilCommentsDataset(root_dir=os.path.join(args.data_dir, 'wilds'), download=True) +# # get all train data +# transform = initialize_bert_transform() +# train_data = dataset.get_subset('train', transform=transform) +# # separate into subsets by distribution +# train_sets = CivilComments_Batched_Dataset(train_data, batch_size=args.batch_size) +# # take subset of test and validation, making sure that only labels appeared in train +# # are included +# datasets = {} +# for split in dataset.split_dict: +# if split != 'train': +# datasets[split] = dataset.get_subset(split, transform=transform) +# # get the loaders +# kwargs = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': False} \ 
+# if device.type == "cuda" else {} +# train_loaders = DataLoader(train_sets, batch_size=args.batch_size, shuffle=True, **kwargs) +# tv_loaders = {} +# for split, sep_dataset in datasets.items(): +# tv_loaders[split] = get_eval_loader('standard', sep_dataset, batch_size=256) +# return train_loaders, tv_loaders, dataset + +# def forward(self, x): +# return self.model(x) diff --git a/WILDS/src/models/datasets.py b/WILDS/src/models/datasets.py new file mode 100644 index 0000000..9182ae9 --- /dev/null +++ b/WILDS/src/models/datasets.py @@ -0,0 +1,501 @@ +import copy +import os + +import numpy as np +import torch +from PIL import Image +from torch.utils.data import Dataset + + +class Poverty_Batched_Dataset(Dataset): + """ + Batched dataset for Poverty. Allows for getting a batch of data given + a specific domain index. + """ + def __init__(self, dataset, split, batch_size, transform=None, drop_last=True): + self.split_array = dataset.split_array + self.split_dict = dataset.split_dict + split_mask = self.split_array == self.split_dict[split] + self.split_idx = np.where(split_mask)[0] + + self.root = dataset.root + self.no_nl = dataset.no_nl + + self.metadata_array = torch.stack([dataset.metadata_array[self.split_idx, i] for i in [0, 2]], -1) + self.y_array = dataset.y_array[self.split_idx] + + self.eval = dataset.eval + self.collate = dataset.collate + self.metadata_fields = dataset.metadata_fields + self.data_dir = dataset.data_dir + + self.transform = transform if transform is not None else lambda x: x + + domains = self.metadata_array[:, 1] + self.domain_indices = [torch.nonzero(domains == loc).squeeze(-1) + for loc in domains.unique()] + self.num_envs = len(domains.unique()) + for did in self.domain_indices: + print(len(did)) + self.domains = domains + self.targets = self.y_array + self.batch_size = batch_size + self.drop_last = drop_last + + min_domain_size = np.min([len(didx) for didx in self.domain_indices]) + self.training_steps = 
int(min_domain_size/self.batch_size)+\ + (not self.drop_last*(min_domain_size//self.batch_size>0)) + + def reset_batch(self): + """Reset batch indices for each domain.""" + self.batch_indices, self.batches_left = {}, {} + for loc, d_idx in enumerate(self.domain_indices): + self.batch_indices[loc] = torch.split(d_idx[torch.randperm(len(d_idx))], self.batch_size) + # mannually drop last + if self.drop_last and len(self.batch_indices[loc][-1])0)) + + def reset_batch(self): + """Reset batch indices for each domain.""" + self.batch_indices, self.batches_left = {}, {} + for loc, d_idx in enumerate(self.domain_indices): + self.batch_indices[loc] = torch.split(d_idx[torch.randperm(len(d_idx))], self.batch_size) + # mannually drop last + if self.drop_last and len(self.batch_indices[loc][-1])0)) + + def reset_batch(self): + """Reset batch indices for each domain.""" + self.batch_indices, self.batches_left = {}, {} + for loc, d_idx in enumerate(self.domain_indices): + print(len(d_idx)) + self.batch_indices[loc] = torch.split(d_idx[torch.randperm(len(d_idx))], self.batch_size) + # mannually drop last + if self.drop_last and len(self.batch_indices[loc][-1])0)) + + def reset_batch(self): + """Reset batch indices for each domain.""" + self.batch_indices, self.batches_left = {}, {} + for loc, d_idx in enumerate(self.domain_indices): + self.batch_indices[loc] = torch.split(d_idx[torch.randperm(len(d_idx))], self.batch_size) + # mannually drop last + if self.drop_last and len(self.batch_indices[loc][-1])0)) + + def reset_batch(self): + """Reset batch indices for each domain.""" + self.batch_indices, self.batches_left = {}, {} + for loc, env_idx in enumerate(self.domain_indices): + self.batch_indices[loc] = torch.split(env_idx[torch.randperm(len(env_idx))], self.batch_size) + # mannually drop last + if self.drop_last and len(self.batch_indices[loc][-1])3: + canvas = torch.zeros_like(image).repeat(1,3,1,1) + for c, (img, l) in enumerate(zip(image, latent)): + canvas[c, ...] 
= self.get_domain_color_palatte(img, l, env) + else: + canvas = self.get_domain_color_palatte(image, latent, env) + return (canvas, latent, None) + + + def __len__(self): + return len(self.latents) + + def __getitem__(self, idx): + image = torch.Tensor(self.images[idx]) + latent = torch.Tensor(self.latents[idx]) + + if len(image.shape)>3: + canvas = torch.zeros_like(image).repeat(1,3,1,1) + for c, (img, l) in enumerate(zip(image, latent)): + canvas[c, ...] = self.get_color_palatte(img, l) + else: + canvas = self.get_color_palatte(image, latent) + return (canvas, latent[0].long()-1, latent[0:]) + + def get_color_palatte(self, image, latent): + chosen_color = torch.randint(high=len(self.color_palattes) - 1, size=(1,)).item() + cc = int(latent[-1].long()) if self.split == 'train' else \ + torch.randint(high=2, size=(1,)).item() + canvas = self.color_palattes[chosen_color][cc] + return canvas*image + + def get_domain_color_palatte(self, image, latent, chosen_color): + cc = int(latent[-1].long()) + canvas = self.color_palattes[chosen_color][cc] + return canvas*image + + def eval(self, ypreds, ys, metas): + total = ys.size(0) + correct = (ypreds == ys).sum().item() + test_val = [ + {'acc_avg': correct/total}, + f"Accuracy: {correct/total*100:6.2f}%" + ] + return test_val + + +DspritesDataSize = torch.Size([1, 64, 64]) +class DspritesDataset(Dataset): + """2D shapes dataset. + More info here: + https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb + """ + def __init__(self, data_root, train=True, train_fract=0.8, split=True, clip=False, drop_last=True): + """ + Args: + npz_file (string): Path to the npz file. 
+ """ + filename = 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz' + self.npz_file = data_root + '/' + filename + self.npz_train_file = data_root + '/train_' + filename + self.npz_test_file = data_root + '/test_' + filename + if not os.path.isfile(self.npz_file): + self.download_dataset(self.npz_file) + if split: + if not (os.path.isfile(self.npz_train_file) and os.path.isfile(self.npz_test_file)): + self.split_dataset(data_root, self.npz_file, self.npz_train_file, + self.npz_test_file, train_fract, clip) + dataset = np.load(self.npz_train_file if train else self.npz_test_file, + mmap_mode='r') + else: + rdataset = np.load(self.npz_file, encoding='latin1', mmap_mode='r') + dataset = {'latents': rdataset['latents_values'][:, 1:], # drop colour + 'images': rdataset['imgs']} + + self.latents = dataset['latents'] + self.images = dataset['images'] + + def download_dataset(self, npz_file): + from urllib import request + url = 'https://github.com/deepmind/dsprites-dataset/blob/master/' \ + 'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz?raw=true' + print('Downloading ' + url) + data = request.urlopen(url) + with open(npz_file, 'wb') as f: + f.write(data.read()) + + def split_dataset(self, data_root, npz_file, npz_train_file, npz_test_file, train_fract, clip): + print('Splitting dataset') + dataset = np.load(npz_file, encoding='latin1', mmap_mode='r') + latents = dataset['latents_values'][:, 1:] + images = np.array(dataset['imgs'], dtype='float32') + images = images.reshape(-1, *DspritesDataSize) + if clip: + images = np.clip(images, 1e-6, 1 - 1e-6) + + split_idx = np.int(train_fract * len(latents)) + shuffled_range = np.random.permutation(len(latents)) + train_idx = shuffled_range[range(0, split_idx)] + test_idx = shuffled_range[range(split_idx, len(latents))] + + np.savez(npz_train_file, images=images[train_idx], latents=latents[train_idx]) + np.savez(npz_test_file, images=images[test_idx], latents=latents[test_idx]) + + def __len__(self): + return len(self.latents) + + 
def __getitem__(self, idx): + image = torch.Tensor(self.images[idx]).unsqueeze(0) + latent = torch.Tensor(self.latents[idx]) + return (image, latent) diff --git a/WILDS/src/models/fmow.py b/WILDS/src/models/fmow.py new file mode 100644 index 0000000..372fd39 --- /dev/null +++ b/WILDS/src/models/fmow.py @@ -0,0 +1,72 @@ +import os +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from torchvision.models import densenet121 +from wilds.common.data_loaders import get_eval_loader +from wilds.datasets.fmow_dataset import FMoWDataset + +from .datasets import FMoW_Batched_Dataset + +IMG_HEIGHT = 224 +NUM_CLASSES = 62 + +class Model(nn.Module): + def __init__(self, args, weights): + super(Model, self).__init__() + self.num_classes = NUM_CLASSES + self.enc = densenet121(pretrained=True).features + self.classifier = nn.Linear(1024, self.num_classes) + if weights is not None: + self.load_state_dict(deepcopy(weights)) + + def reset_weights(self, weights): + self.load_state_dict(deepcopy(weights)) + + @staticmethod + def getDataLoaders(args, device): + dataset = FMoWDataset(root_dir=os.path.join(args.data_dir, 'wilds'), download=True) + # get all train data + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ]) + train_sets = FMoW_Batched_Dataset(dataset, 'train', args.batch_size, transform, drop_last=not args.no_drop_last) + datasets = {} + for split in dataset.split_dict: + if split != 'train': + datasets[split] = dataset.get_subset(split, transform=transform) + + # get the loaders + kwargs = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': False} \ + if device.type == "cuda" else {} + train_loaders = DataLoader(train_sets, batch_size=args.batch_size, shuffle=True, **kwargs) + tv_loaders = {} + for split, sep_dataset in datasets.items(): + 
tv_loaders[split] = get_eval_loader('standard', sep_dataset, batch_size=256, num_workers=args.num_workers) + return train_loaders, tv_loaders, dataset + + def forward(self, x, get_feat=False,frozen_mode=False): + + if frozen_mode: + self.enc.eval() + self.classifier.train() + with torch.no_grad(): + features = self.enc(x) + out = F.relu(features, inplace=True) + out = F.adaptive_avg_pool2d(out, (1, 1)) + out_features = torch.flatten(out, 1) + else: + features = self.enc(x) + out = F.relu(features, inplace=True) + out = F.adaptive_avg_pool2d(out, (1, 1)) + out_features = torch.flatten(out, 1) + out = self.classifier(out_features) + if get_feat: + return out, out_features + return out diff --git a/WILDS/src/models/iwildcam.py b/WILDS/src/models/iwildcam.py new file mode 100644 index 0000000..8d7c265 --- /dev/null +++ b/WILDS/src/models/iwildcam.py @@ -0,0 +1,99 @@ +import os +from copy import deepcopy + +import torch.nn as nn +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from torchvision.models import resnet50 +from wilds.common.data_loaders import get_eval_loader +from wilds.datasets.iwildcam_dataset import IWildCamDataset + +from .datasets import GeneralWilds_Batched_Dataset +import torch + +IMG_HEIGHT = 224 +NUM_CLASSES = 186 + +def get_image_base_transform_steps(dataset, target_resolution=None): + transform_steps = [] + + if dataset.original_resolution is not None and min( + dataset.original_resolution + ) != max(dataset.original_resolution): + crop_size = min(dataset.original_resolution) + transform_steps.append(transforms.CenterCrop(crop_size)) + + if target_resolution is not None: + transform_steps.append(transforms.Resize(target_resolution)) + + return transform_steps + +class Model(nn.Module): + def __init__(self, args, weights): + super(Model, self).__init__() + self.num_classes = NUM_CLASSES + pretrain_path=os.path.join(args.data_dir,'wilds',args.dataset) + if os.path.exists(pretrain_path): + resnet = 
resnet50(pretrained=False) + resnet.load_state_dict(torch.load(pretrain_path + f'/resnet50.rar')) + print(f"Load pretrained resnet from {pretrain_path}") + else: + resnet = resnet50(pretrained=True) + if not os.path.exists(pretrain_path): + os.makedirs(pretrain_path) + torch.save(resnet.state_dict(),pretrain_path+ f'/resnet50.rar') + print(f"Load pretrained resnet from url") + self.enc = nn.Sequential(*list(resnet.children())[:-1]) # remove fc layer + self.fc = nn.Linear(2048, self.num_classes) + if weights is not None: + self.load_state_dict(deepcopy(weights)) + + def reset_weights(self, weights): + self.load_state_dict(deepcopy(weights)) + + + @staticmethod + def getDataLoaders(args, device): + dataset = IWildCamDataset(root_dir=os.path.join(args.data_dir, 'wilds'), download=True) + # get all train data + transform = transforms.Compose([ + # transforms.Resize((224, 224)), + transforms.Resize((448, 448)), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], + [0.229, 0.224, 0.225]) + ]) + train_data = dataset.get_subset('train', transform=transform) + # separate into subsets by distribution + train_sets = GeneralWilds_Batched_Dataset(train_data, args.batch_size, domain_idx=0, drop_last=not args.no_drop_last) + # take subset of test and validation, making sure that only labels appeared in train + # are included + datasets = {} + for split in dataset.split_dict: + if split != 'train': + datasets[split] = dataset.get_subset(split, transform=transform) + + # get the loaders + kwargs = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': False} \ + if device.type == "cuda" else {} + train_loaders = DataLoader(train_sets, batch_size=args.batch_size, shuffle=True, **kwargs) + tv_loaders = {} + for split, sep_dataset in datasets.items(): + tv_loaders[split] = get_eval_loader('standard', sep_dataset, batch_size=256, num_workers=args.num_workers) + return train_loaders, tv_loaders, dataset + + def forward(self, 
x,get_feat=False,frozen_mode=False): + # x = x.expand(-1, 3, -1, -1) # reshape MNIST from 1x32x32 => 3x32x32 + if len(x.shape) == 3: + x.unsqueeze_(0) + if frozen_mode: + self.enc.eval() + self.fc.train() + with torch.no_grad(): + e = self.enc(x) + else: + e = self.enc(x) + out = self.fc(e.squeeze(-1).squeeze(-1)) + if get_feat: + return out, e.squeeze(-1).squeeze(-1) + return out diff --git a/WILDS/src/models/poverty.py b/WILDS/src/models/poverty.py new file mode 100644 index 0000000..1d82056 --- /dev/null +++ b/WILDS/src/models/poverty.py @@ -0,0 +1,87 @@ +import os +from copy import deepcopy +import torch +import torch.nn as nn +import torchvision.transforms as transforms +from torch.utils.data import DataLoader +from wilds.common.data_loaders import get_eval_loader +from wilds.datasets.poverty_dataset import PovertyMapDataset + +from .resnet_multispectral import ResNet18 +from .datasets import Poverty_Batched_Dataset + +IMG_HEIGHT = 224 +NUM_CLASSES = 1 + +def initialize_poverty_train_transform(): + """Adapted from the Wilds library, available at: https://github.com/p-lambda/wilds""" + transforms_ls = [ + transforms.ToPILImage(), + transforms.RandomHorizontalFlip(), + transforms.RandomVerticalFlip(), + transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.1), + transforms.ToTensor()] + rgb_transform = transforms.Compose(transforms_ls) + + def transform_rgb(img): + # bgr to rgb and back to bgr + img[:3] = rgb_transform(img[:3][[2,1,0]])[[2,1,0]] + return img + transform = transforms.Lambda(lambda x: transform_rgb(x)) + return transform + + +class Model(nn.Module): + def __init__(self, args, weights): + super(Model, self).__init__() + self.num_classes = NUM_CLASSES + + self.enc = ResNet18(num_classes=1, num_channels=8) + if weights is not None: + self.load_state_dict(deepcopy(weights)) + + def reset_weights(self, weights): + self.load_state_dict(deepcopy(weights)) + + + + @staticmethod + def getDataLoaders(args, device): + kwargs = {'no_nl': 
False, + 'fold': ['A', 'B', 'C', 'D', 'E'][args.seed], + # 'oracle_training_set': False, + 'use_ood_val': True} + dataset = PovertyMapDataset(root_dir=os.path.join(args.data_dir, 'wilds'), + download=True, **kwargs) + # get all train data + # transform = initialize_poverty_train_transform() + # In latest wilds example code, no transfromation is applied + transform = transforms.Compose([]) + + train_sets = Poverty_Batched_Dataset(dataset, 'train', args.batch_size, transform, drop_last=not args.no_drop_last) + datasets = {} + for split in dataset.split_dict: + if split != 'train': + datasets[split] = dataset.get_subset(split, transform=transform) + print(split, len(datasets[split])) + + kwargs = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': False, 'shuffle': True} \ + if device.type == "cuda" else {} + train_loaders = DataLoader(train_sets, batch_size=args.batch_size, **kwargs) + tv_loaders = {} + for split, sep_dataset in datasets.items(): + tv_loaders[split] = get_eval_loader('standard', sep_dataset, batch_size=256, num_workers=args.num_workers) + return train_loaders, tv_loaders, dataset + + def forward(self, x, get_feat=False,frozen_mode=False): + if frozen_mode: + self.enc.eval() + self.enc.fc.train() + with torch.no_grad(): + pred, feat = self.enc(x, with_feats=True) + pred = self.enc.fc(feat) + if get_feat: + return pred, feat + else: + return pred + return self.enc(x,with_feats=get_feat) diff --git a/WILDS/src/models/resnet_multispectral.py b/WILDS/src/models/resnet_multispectral.py new file mode 100644 index 0000000..d2f6b19 --- /dev/null +++ b/WILDS/src/models/resnet_multispectral.py @@ -0,0 +1,248 @@ +# Adapted from the WILDS library +import torch +import torch.nn as nn + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def 
conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + __constants__ = ['downsample'] + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = 
nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None, num_channels=3): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(num_channels, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + if num_classes is not None: + self.fc = nn.Linear(512 * 
block.expansion, num_classes) + self.d_out = num_classes + else: + self.fc = None + self.d_out = 512 * block.expansion + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def get_feats(self, x, layer=4): + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + if layer == 1: + return x + x = self.layer2(x) + if layer == 2: + return x + x = self.layer3(x) + if layer == 3: + return x + x = self.layer4(x) + + x = self.avgpool(x) + x = torch.flatten(x, 1) + return x + 
+ + def _forward_impl(self, x, with_feats=False): + x = feats = self.get_feats(x) + if self.fc is not None: + x = self.fc(feats) + + if with_feats: + return x, feats + else: + return x + + def forward(self, x, with_feats=False): + return self._forward_impl(x, with_feats) + + +class ResNet18(ResNet): + def __init__(self, num_classes=10, num_channels=3): + super().__init__( + BasicBlock, [2, 2, 2, 2], num_classes=num_classes, num_channels=num_channels) + +class ResNet34(ResNet): + def __init__(self, num_classes=10, num_channels=3): + super().__init__( + BasicBlock, [3, 4, 6, 3], num_classes=num_classes, num_channels=num_channels) + +class ResNet50(ResNet): + def __init__(self, num_classes=10, num_channels=3): + super().__init__( + Bottleneck, [3, 4, 23, 3], num_classes=num_classes, num_channels=num_channels) + +class ResNet101(ResNet): + def __init__(self, num_classes=10, num_channels=3): + super().__init__( + Bottleneck, [3, 4, 23, 3], num_classes=num_classes, num_channels=num_channels) + +class ResNet152(ResNet): + def __init__(self, num_classes=10, num_channels=3): + super().__init__( + Bottleneck, [3, 8, 36, 3], num_classes=num_classes, num_channels=num_channels) + + +DEPTH_TO_MODEL = {18: ResNet18, 34: ResNet34, 50: ResNet50, 101: ResNet101, 152: ResNet152} + diff --git a/WILDS/src/models/rxrx.py b/WILDS/src/models/rxrx.py new file mode 100644 index 0000000..6f9abbd --- /dev/null +++ b/WILDS/src/models/rxrx.py @@ -0,0 +1,110 @@ +import numpy as np +import os +from copy import deepcopy + +import torch +import torch.nn as nn +import torchvision.transforms as transforms +import torchvision.transforms.functional as TF +from torch.utils.data import DataLoader +from torchvision.models import resnet50 +from wilds.common.data_loaders import get_eval_loader +try: + from wilds.datasets.rxrx1_dataset import RxRx1Dataset +except Exception as e: + print("RxRx1 Dataset not supported") + +from .datasets import GeneralWilds_Batched_Dataset + +IMG_HEIGHT = 224 +NUM_CLASSES = 
1139 + +def initialize_rxrx1_transform(is_training): + def standardize(x: torch.Tensor) -> torch.Tensor: + mean = x.mean(dim=(1, 2)) + std = x.std(dim=(1, 2)) + std[std == 0.] = 1. + return TF.normalize(x, mean, std) + t_standardize = transforms.Lambda(lambda x: standardize(x)) + + angles = [0, 90, 180, 270] + def random_rotation(x: torch.Tensor) -> torch.Tensor: + angle = angles[torch.randint(low=0, high=len(angles), size=(1,))] + if angle > 0: + x = TF.rotate(x, angle) + return x + t_random_rotation = transforms.Lambda(lambda x: random_rotation(x)) + + if is_training: + transforms_ls = [ + t_random_rotation, + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + t_standardize, + ] + else: + transforms_ls = [ + transforms.ToTensor(), + t_standardize, + ] + transform = transforms.Compose(transforms_ls) + return transform + +class Model(nn.Module): + def __init__(self, args, weights): + super(Model, self).__init__() + self.num_classes = NUM_CLASSES + resnet = resnet50(pretrained=True) + self.enc = nn.Sequential(*list(resnet.children())[:-1]) # remove fc layer + self.fc = nn.Linear(2048, self.num_classes) + if weights is not None: + self.load_state_dict(deepcopy(weights)) + + def reset_weights(self, weights): + self.load_state_dict(deepcopy(weights)) + + + @staticmethod + def getDataLoaders(args, device): + dataset = RxRx1Dataset(root_dir=os.path.join(args.data_dir, 'wilds'), download=True) + + # initialize transform + train_transform = initialize_rxrx1_transform(is_training=True) + eval_transform = initialize_rxrx1_transform(is_training=False) + + # get all train data + train_data = dataset.get_subset('train', transform=train_transform) + + # separate into subsets by distribution + train_sets = GeneralWilds_Batched_Dataset(train_data, args.batch_size, domain_idx=1, drop_last=not args.no_drop_last) + # take subset of test and validation, making sure that only labels appeared in train + # are included + datasets = {} + for split in dataset.split_dict: + if 
split != 'train': + datasets[split] = dataset.get_subset(split, transform=eval_transform) + + # get the loaders + kwargs = {'num_workers': args.num_workers, 'pin_memory': True, 'drop_last': False} \ + if device.type == "cuda" else {} + train_loaders = DataLoader(train_sets, batch_size=args.batch_size, shuffle=True, **kwargs) + tv_loaders = {} + for split, sep_dataset in datasets.items(): + tv_loaders[split] = get_eval_loader('standard', sep_dataset, batch_size=256, num_workers=args.num_workers) + return train_loaders, tv_loaders, dataset + + def forward(self, x,get_feat=False,frozen_mode=False): + # x = x.expand(-1, 3, -1, -1) # reshape MNIST from 1x32x32 => 3x32x32 + if len(x.shape) == 3: + x.unsqueeze_(0) + if frozen_mode: + self.enc.eval() + self.fc.train() + with torch.no_grad(): + e = self.enc(x) + else: + e = self.enc(x) + out = self.fc(e.squeeze(-1).squeeze(-1)) + if get_feat: + return out, e.squeeze(-1).squeeze(-1) + return out diff --git a/WILDS/src/pair.py b/WILDS/src/pair.py new file mode 100644 index 0000000..947f69d --- /dev/null +++ b/WILDS/src/pair.py @@ -0,0 +1,452 @@ +import copy +import imp +from pickletools import optimize +import torch +from torch.optim.optimizer import Optimizer, required +from torch.autograd import Variable +import traceback +import torch.nn.functional as F +from torch.optim import SGD + +class PAIR(Optimizer): + r""" + Implements Pareto Invariant Risk Minimization (PAIR) algorithm. + It is proposed in the ICLR 2023 paper + `Pareto Invariant Risk Minimization: Towards Mitigating the Optimization Dilemma in Out-of-Distribution Generalization` + https://arxiv.org/abs/2206.07766 . 
def step(self, closure=None):
    r"""Performs a single optimization step.

    One backward pass is run per objective to collect its flattened
    gradient, the balancer is asked for objective weights ``alpha`` from
    the gradient inner products, and the inner optimizer then steps on
    the ``alpha``-weighted sum of the losses.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the losses. When given, its result replaces the
            losses registered through ``set_losses``.

    Returns:
        ``(pair_loss, moo_losses, mu_rl, alpha)`` on a regular step.
        NOTE: when no losses are registered the inner optimizer is stepped
        as-is and the 3-tuple ``(-1, 233, alphas)`` is returned instead
        (``alphas`` puts all weight on the first objective); callers must
        handle both arities.
    """
    if len(self.losses) == 0:
        self.optimizer.step()
        alphas = np.zeros(len(self.preference))
        alphas[0] = 1
        return -1, 233, alphas

    losses = self.losses if closure is None else closure()
    mu_rl = 0  # reported as-is when the balancer fails

    params = [p for group in self.param_groups for p in group['params']]
    per_loss = []
    for cur_loss in losses:
        self.optimizer.zero_grad()
        cur_loss.backward(retain_graph=True)
        per_loss.append([
            None if p.grad is None else p.grad.detach().clone().flatten()
            for p in params
        ])

    # A parameter untouched by one loss has ``grad is None`` for that loss
    # only. Skipping it there would give rows of different lengths, so keep
    # every parameter that got a gradient from at least one loss and fill
    # the gaps with zeros (this leaves the inner products unchanged).
    used = [j for j in range(len(params))
            if any(row[j] is not None for row in per_loss)]
    grads = [
        torch.cat([
            row[j] if row[j] is not None
            else params[j].new_zeros(params[j].numel())
            for j in used
        ])
        for row in per_loss
    ]

    G = torch.stack(grads)
    # ``get_grad_sim`` is an optional flag that is not set by the
    # constructor; treat a missing attribute as "disabled".
    if getattr(self, "get_grad_sim", False):
        self.grad_sim = get_grad_sim(G, losses, preference=self.preference, is_G=True)
    GG = G @ G.T
    moo_losses = np.stack([l.item() for l in losses])
    reset_optimizer = False
    try:
        # Calculate the alphas from the LP solver
        alpha, mu_rl, reset_optimizer = self.balancer.get_alpha(
            moo_losses, G=GG.cpu().numpy(), C=True, get_mu=True)
        if self.balancer.last_move == "dom":
            self.descent += 1
            print("dom")
    except Exception:
        # Deliberate best-effort: a solver failure must not abort training.
        print(traceback.format_exc())
        alpha = None
    if alpha is None:  # A patch for the issue in cvxpy
        alpha = self.preference / np.sum(self.preference)

    scales = torch.from_numpy(alpha).float().to(losses[-1].device)
    pair_loss = scales.dot(losses)
    if reset_optimizer:
        self.optimizer.param_groups[0]["lr"] /= 5
    self.optimizer.zero_grad()
    pair_loss.backward()
    self.optimizer.step()

    return pair_loss, moo_losses, mu_rl, alpha
r/np.sum(r) + self.eps = eps + self.last_move = None + self.a = cp.Parameter(m) # Adjustments + self.C = cp.Parameter((m, m)) # C: Gradient inner products, G^T G + self.Ca = cp.Parameter(m) # d_bal^TG + self.rhs = cp.Parameter(m) # RHS of constraints for balancing + + self.alpha = cp.Variable(m) # Variable to optimize + self.last_alpha = np.zeros_like(r)-1 + self.coe = coe + + obj_bal = cp.Maximize(self.alpha @ self.Ca) # objective for balance + constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1, # Simplex + self.C @ self.alpha >= self.rhs] + self.prob_bal = cp.Problem(obj_bal, constraints_bal) # LP balance + + obj_dom = cp.Maximize(cp.sum(self.alpha @ self.C)-self.coe*cp.sum_squares(self.alpha-self.last_alpha)) # obj for descent + constraints_dom = [self.alpha >= 0, cp.sum(self.alpha) == 1, # Restrict + self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)), + self.C @ self.alpha >= 0] + self.prob_dom = cp.Problem(obj_dom, constraints_dom) # LP dominance + + + self.gamma = 0 # Stores the latest Optimum value of the LP problem + self.mu_rl = 0 # Stores the latest non-uniformity + + self.verbose = verbose + + + def get_alpha(self, l, G, r=None, C=False, get_mu=False): + """calculate weights for all objectives given the gradient information + + Args: + l (ndarray): the values of objective losses + G (ndarray): inner products of the gradients of each objective loss w.r.t. params + r (ndarray, optional): adopt this preference if specified + C (bool, optional): True if the input gradients are inner products + get_mu (bool, optional): return detailed information if True. 
def mu(self, rl, normed=False):
    """Return the non-uniformity of ``rl``: its KL divergence from uniform.

    Args:
        rl: non-negative 1-D values (array or sequence).
        normed: True if ``rl`` already sums to one; otherwise it is
            normalised by its sum first.

    Returns:
        ``sum(l_hat * log(l_hat * m))`` over the entries of the normalised
        vector ``l_hat`` that exceed machine epsilon, with ``m = len(rl)``.

    Raises:
        ValueError: if any entry of ``rl`` is negative.
    """
    rl = np.asarray(rl)
    if not np.issubdtype(rl.dtype, np.floating):
        # np.finfo below only accepts floating dtypes.
        rl = rl.astype(float)
    if np.any(rl < 0):
        raise ValueError(f"rl<0 \n rl={rl}")
    m = len(rl)
    l_hat = rl if normed else rl / rl.sum()
    eps = np.finfo(rl.dtype).eps
    # Drop (near-)zero entries: they contribute 0 in the limit and would
    # otherwise produce log(0).
    l_hat = l_hat[l_hat > eps]
    return np.sum(l_hat * np.log(l_hat * m))
def __init__(self, m, r, eps=1e-4, coe=0, verbose=False):
    """Build the cvxpy problems used by ``get_alpha``.

    Args:
        m: number of objectives; size of the weight variable ``alpha``
            and of every cvxpy parameter created here.
        r: preference vector; stored normalised so that it sums to one.
        eps: threshold stored as ``self.eps``; ``get_alpha`` compares the
            non-uniformity ``mu_rl`` against it to pick the balance or
            the dominance problem.
        coe: weight of the squared-distance penalty between ``alpha`` and
            ``self.last_alpha`` in the regularised objectives.
        verbose: stored as ``self.verbose``; enables diagnostic prints in
            ``get_alpha`` and ``adjustments``.
    """
    # self.solver = cp.GLPK
    # NOTE(review): presumably requires the Gurobi backend to be installed
    # and licensed for cvxpy — confirm for the target environment.
    self.solver = cp.GUROBI
    # cvxopt.glpk.options["msg_lev"] = "GLP_MSG_OFF"
    self.m = m
    self.r = r/np.sum(r)
    self.eps = eps
    self.last_move = None
    self.a = cp.Parameter(m)        # Adjustments
    self.C = cp.Parameter((m, m))   # C: Gradient inner products, G^T G
    self.Ca = cp.Parameter(m)       # d_bal^TG
    self.rhs = cp.Parameter(m)      # RHS of constraints for balancing

    self.alpha = cp.Variable(m)     # Variable to optimize
    # All entries -1, so the sum is negative; get_alpha tests
    # ``last_alpha.sum() < 0`` to pick the unregularised balance problem.
    # NOTE(review): last_alpha is a plain NumPy array when the objectives
    # below are built, so it presumably enters them as a fixed constant;
    # later reassignments of self.last_alpha would then not change the
    # penalty term — confirm against cvxpy semantics (a cp.Parameter may
    # be what is intended).
    self.last_alpha = np.zeros_like(r)-1
    self.coe = coe

    obj_bal = cp.Maximize(self.alpha @ self.Ca-self.coe*cp.sum_squares(self.alpha-self.last_alpha))  # objective for balance
    obj_bal_orig = cp.Maximize(self.alpha @ self.Ca)  # objective for balance
    constraints_bal = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Simplex
                       self.C @ self.alpha >= self.rhs]
    self.prob_bal = cp.Problem(obj_bal, constraints_bal)  # LP balance
    self.prob_bal_orig = cp.Problem(obj_bal_orig, constraints_bal)  # LP balance

    obj_dom = cp.Maximize(cp.sum(self.alpha @ self.C)-self.coe*cp.sum_squares(self.alpha-self.last_alpha))  # obj for descent
    constraints_res = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Restrict
                       self.alpha @ self.Ca >= -cp.neg(cp.max(self.Ca)),
                       self.C @ self.alpha >= 0]
    constraints_rel = [self.alpha >= 0, cp.sum(self.alpha) == 1,  # Relaxed
                       self.C @ self.alpha >= 0]
    self.prob_dom = cp.Problem(obj_dom, constraints_res)  # LP dominance
    # Same objective as prob_dom, without the constraint on alpha @ Ca.
    self.prob_rel = cp.Problem(obj_dom, constraints_rel)  # LP dominance

    self.gamma = 0     # Stores the latest Optimum value of the LP problem
    self.mu_rl = 0     # Stores the latest non-uniformity

    self.verbose = verbose
def getNumParams(params):
    """Count the scalar entries of ``params``.

    Args:
        params: iterable of parameters exposing ``.data.shape`` and
            ``.requires_grad``; it is traversed exactly once.

    Returns:
        ``(total, trainable)``: the element count over all parameters and
        over those with ``requires_grad`` set.
    """
    sized = [(np.prod(tensor.data.shape), tensor.requires_grad) for tensor in params]
    total = sum(count for count, _ in sized)
    trainable = sum(count for count, wants_grad in sized if wants_grad)
    return total, trainable
def get_grad_sim(params, losses, preference=None, is_G=False, cosine=True):
    """Weighted similarity between the first loss' gradient and the others.

    Args:
        params: the stacked gradient matrix (one row per loss) when
            ``is_G`` is True; otherwise an iterable of parameters with
            respect to which each loss is differentiated.
        losses: one entry per objective. Only its length is used when
            ``is_G`` is True; otherwise each entry is back-propagated.
        preference: optional per-objective weights; entries ``[1:]`` are
            normalised and used to average the similarities. Uniform
            weights are used when omitted.
        is_G: whether ``params`` already is the gradient matrix.
        cosine: normalise gradient rows first, so that inner products are
            cosine similarities.

    Returns:
        float: weighted average of the (cosine) inner products between the
        gradient of ``losses[0]`` and the gradients of the other losses.

    Note:
        When ``is_G`` is False the ``.grad`` fields of ``params`` are
        overwritten; on return they hold the gradient of the last loss.
    """
    num_ood_losses = len(losses) - 1
    if is_G:
        G = params
    else:
        # Materialise once: a generator such as ``model.parameters()``
        # would otherwise be exhausted before the gradients are read.
        params = list(params)
        per_loss = []
        for cur_loss in losses:
            for p in params:
                p.grad = None
            cur_loss.backward(retain_graph=True)
            per_loss.append([
                None if p.grad is None else p.grad.detach().clone().flatten()
                for p in params
            ])
        # Keep parameters reached by at least one loss and zero-fill the
        # gaps so that every row has the same length.
        used = [j for j in range(len(params))
                if any(row[j] is not None for row in per_loss)]
        G = torch.stack([
            torch.cat([
                row[j] if row[j] is not None
                else params[j].new_zeros(params[j].numel())
                for j in used
            ])
            for row in per_loss
        ])
    if cosine:
        G = F.normalize(G, dim=1)
    GG = (G @ G.T).detach().cpu()
    if preference is not None:
        tail = np.asarray(preference)[1:]
        G_weights = tail / np.sum(tail)
    else:
        G_weights = np.ones(num_ood_losses) / num_ood_losses
    grad_sim = G_weights.dot(GG[0, 1:].numpy())
    return grad_sim.item()
class LinearScheduleWithWarmupAndThreshold():
    """
    Warmup-then-linear-ramp schedule for a scalar that is not a learning rate.

    The scheduled ``value`` stays at 0 while the step counter is below T1,
    grows linearly over [T1, T2), and equals ``max_value`` from T2 onwards.
    It exposes ``step_every_batch`` and ``use_metric`` like the optimizer
    schedulers, so it can be driven by step_scheduler().

    Args:
        - max_value: value held once the threshold step is reached
        - last_warmup_step: T1; ``value`` is 0 for steps in [0, T1)
        - threshold_step: T2; ``value`` ramps linearly over [T1, T2)
        - step_every_batch: stored flag telling the caller how often to step
    """
    def __init__(self, max_value, last_warmup_step=0, threshold_step=1, step_every_batch=False):
        self.max_value = max_value
        self.T1, self.T2 = last_warmup_step, threshold_step
        assert 0 <= self.T1 < self.T2

        # fields read by the caller when stepping schedulers
        self.use_metric = False
        self.step_every_batch = step_every_batch

        # position in the schedule and the value that goes with it
        self.value = 0
        self.current_step = 0

    def step(self):
        """Advance one step, then set ``value`` for the step that follows."""
        self.current_step += 1
        now = self.current_step
        if now < self.T1:
            self.value = 0
            return
        if now < self.T2:
            self.value = (now - self.T1) / (self.T2 - self.T1) * self.max_value
            return
        self.value = self.max_value
class Logger(object):
    """Tee-like stream: everything written goes to stdout and to a log file.

    The log file is opened in append mode and flushed after every write.
    Call ``close()`` (or use the instance as a context manager) to release
    the file handle; after that, writes only reach the terminal.
    """

    def __init__(self, filename):
        # The stdout object current at construction time is captured, so
        # an instance can itself be assigned to ``sys.stdout``.
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        if not self.log.closed:
            self.log.write(message)
            self.log.flush()

    def flush(self):
        # Required by the file-like protocol; forward to both sinks
        # instead of silently dropping the request.
        self.terminal.flush()
        if not self.log.closed:
            self.log.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        if not self.log.closed:
            self.log.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False
class Subset(Dataset):
    r"""
    View of a dataset restricted to the given indices.

    Besides item access, a few per-sample attributes of the parent are
    copied for the selected indices: ``images`` and ``latents`` when the
    parent has an ``images`` attribute, otherwise ``targets``, ``writers``
    (taken from the parent's ``domains``) and ``data``.

    Arguments:
        dataset (Dataset): The whole Dataset
        indices (sequence): Indices in the whole set selected for subset
    """
    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices
        if not hasattr(dataset, 'images'):
            self.targets = dataset.targets[indices]
            self.writers = dataset.domains[indices]
            picked = []
            for position in indices:
                picked.append(dataset.data[position])
            self.data = picked
            return
        self.images = dataset.images[indices]
        self.latents = dataset.latents[indices, :]

    def __getitem__(self, idx):
        # Translate the subset position into a position in the parent.
        parent_idx = self.indices[idx]
        return self.dataset[parent_idx]

    def __len__(self):
        return len(self.indices)
class ParamDict(OrderedDict):
    """A dictionary where the values are Tensors, meant to represent weights of
    a model. This subclass lets you perform arithmetic on weights directly.

    Binary operators accept either a plain number (applied to every value)
    or a dict with the same keys (applied key by key) and return a new
    ``ParamDict``; any other operand raises ``NotImplementedError``.
    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments must be forwarded as keywords; unpacking them
        # positionally made ``ParamDict(a=1)`` fail.
        super().__init__(*args, **kwargs)

    def _prototype(self, other, op):
        """Apply ``op(value, other)`` for a number, or key-wise for a dict."""
        if isinstance(other, Number):
            return ParamDict({k: op(v, other) for k, v in self.items()})
        elif isinstance(other, dict):
            return ParamDict({k: op(self[k], other[k]) for k in self})
        else:
            raise NotImplementedError

    def __add__(self, other):
        return self._prototype(other, operator.add)

    # Addition commutes, so ``0 + d`` (e.g. inside ``sum``) can reuse it.
    __radd__ = __add__

    def __rmul__(self, other):
        return self._prototype(other, operator.mul)

    __mul__ = __rmul__

    def __neg__(self):
        return ParamDict({k: -v for k, v in self.items()})

    def __sub__(self, other):
        # a - b, for b a number or any dict with matching keys.
        return self._prototype(other, operator.sub)

    def __rsub__(self, other):
        # other - self := (-self) + other; subtraction does not commute, so
        # this cannot share the implementation of __sub__.
        return self.__neg__().__add__(other)

    def __truediv__(self, other):
        return self._prototype(other, operator.truediv)