# Outer Loop Backward with Inner Loop no_grad #212
**Question**

I encountered an issue where I define a tensor (`biggest_mask`) inside the inner loop's `objective()` under `torch.no_grad()`, intending it to be treated as a constant. However, in the outer loop, `alignloss.backward()` fails inside TorchOpt's implicit backward pass. I have tried using `torchopt.stop_gradient` and `torch.no_grad()` at every place `biggest_mask` is touched (marked with the `todo` comments below), but the error persists.

**Inner Loop Code**

```python
def get_implicit_model(solver):
    class ImplicitModel(torchopt.nn.ImplicitMetaGradientModule, linear_solve=solver):
        def __init__(self, clser, atter, lamb, max_epochs=200, optimizer='Adam',
                     lr=0.01, device=None, biggest_mask=None):
            super().__init__()
            self.register_meta_module('atter', atter)
            object.__setattr__(self, 'clser', clser)
            object.__setattr__(self, 'biggest_mask', biggest_mask)
            self.lr = lr
            self.lamb = lamb
            self.max_epochs = max_epochs
            self.optimizer = optimizer
            self.device = device

        def reset(self, z):
            self.z0 = z
            self.z = nn.Parameter(z.clone().detach_(), requires_grad=True)

        def objective(self):
            mask_ = self.atter(self.z)
            # ------------ todo: biggest_mask is here ------------
            with torch.no_grad():
                self.biggest_mask = save_biggest_mask(mask_.detach().clone(), self.device)
            torchopt.stop_gradient(self.biggest_mask)
            mask_ = mask_ * self.biggest_mask
            # ----------------------------------------------------
            pred = self.clser(mask_ * self.z)
            celoss = malign_celoss(pred, self.target)
            norm = torch.norm(self.z - self.z0)
            return celoss + self.lamb * norm

        @torch.enable_grad()
        def solve(self):
            optimizer = getattr(torch.optim, self.optimizer)(params=[self.z], lr=self.lr)
            mask_ = self.atter(self.z0)
            pred = self.clser(mask_ * self.z0)
            mscore = nn.Softmax(dim=1)(pred)[:, 1]
            self.target = 1 - (mscore >= 0.5).long().clone().detach()
            for epoch in range(self.max_epochs):
                optimizer.zero_grad()
                # ------------ todo: biggest_mask is here ------------
                torchopt.stop_gradient(self.biggest_mask)
                loss = self.objective()
                # ----------------------------------------------------
                norm = torch.norm(self.z - self.z0)
                celoss = loss - self.lamb * norm
                loss.backward(inputs=[self.z])
                optimizer.step()
            return
    return ImplicitModel
```

**Outer Loop Code**

```python
for batch_idx, (train_z, train_mask, _) in enumerate(train_loader):
    rand_flip_(train_z, train_mask)
    train_z, train_mask = train_z.to(device), train_mask.to(device)
    optim_atter.zero_grad()
    imp.reset(train_z)
    # ------------ todo: biggest_mask is here ------------
    inner_log = imp.solve()  # todo: imp inner loop
    torchopt.stop_gradient(imp.biggest_mask)
    # ----------------------------------------------------
    diff = imp.z - train_z
    reward = inside_nod_loss(diff, train_mask) * args.ratio
    punish = beyond_nod_loss(diff, train_mask)
    alignloss = punish - reward
    alignloss.backward()  # todo: outer loop loss backward
    optim_atter.step()
```

**save_biggest_mask**

```python
def save_biggest_mask(A, device):
    with torch.no_grad():
        A_ = digitize(minmax_normalize(A, (2, 3))).to(torch.uint8).to(device)
        B = torch.full_like(A_, fill_value=0.1, dtype=torch.float32).to(device)
        A_ = cc_torch.connected_components_labeling(A_.squeeze(1)).unsqueeze(1)
        for i in range(A_.size(0)):
            A_i = A_[i, 0, :, :]
            B_i = B[i, 0, :, :]
            unique_values, counts = torch.unique(A_i, return_counts=True)
            counts[torch.where(unique_values == 0)] = 0
            biggest_value = unique_values[torch.argmax(counts).item()]
            B_i[torch.where(A_i == biggest_value)] = 1
    return B
```

**Error message**

```
Traceback (most recent call last):
File "/home/lijingwen/Projects/Counter_align/DDSM_hierarchical/hierachical_model/img2label/main_cross_avg3.py", line 161, in <module>
alignloss.backward()
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torch/_tensor.py", line 487, in backward
torch.autograd.backward(
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torch/autograd/__init__.py", line 197, in backward
Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torch/autograd/function.py", line 267, in apply
return user_fn(self, *args)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torchopt/diff/implicit/decorator.py", line 341, in backward
vjps = _root_vjp(
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torchopt/diff/implicit/decorator.py", line 106, in _root_vjp
_, optimality_cond_vjp_fn, *_ = functorch.vjp(optimality_cond, solution)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 262, in vjp
return _vjp_with_argnums(func, *primals, has_aux=has_aux)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/functorch/_src/vmap.py", line 35, in fn
return f(*args, **kwargs)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 289, in _vjp_with_argnums
primals_out = func(*primals)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torchopt/diff/implicit/decorator.py", line 104, in optimality_cond
return optimality_fn(solution, *args)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torchopt/diff/implicit/nn/module.py", line 73, in _stateless_optimality_fn
return self.optimality(*input, **kwargs)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torchopt/diff/implicit/nn/module.py", line 91, in optimality
flat_grads = objective_grad_fn(
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 1241, in wrapper
results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/functorch/_src/vmap.py", line 35, in fn
return f(*args, **kwargs)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/functorch/_src/eager_transforms.py", line 1111, in wrapper
output = func(*args, **kwargs)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/torchopt/diff/implicit/nn/module.py", line 54, in _stateless_objective_fn
return self.objective(*input, **kwargs)
File "/home/lijingwen/Projects/Counter_align/DDSM_hierarchical/hierachical_model/model.py", line 41, in objective
self.biggest_mask=save_biggest_mask(mask_.detach().clone(), self.device)
File "/home/lijingwen/Projects/Counter_align/DDSM_hierarchical/hierachical_model/tools/util.py", line 177, in save_biggest_mask
A_ = cc_torch.connected_components_labeling(A_.squeeze(1)).unsqueeze(1)
File "/home/lijingwen/Miniconda3/envs/toy/lib/python3.8/site-packages/cc_torch-0.1-py3.8-linux-x86_64.egg/cc_torch/connected_components.py", line 18, in connected_components_labeling
return _C.cc_3d(x)
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
Process finished with exit code 1
**Replies**
It seems that the `objective()` method is re-evaluated by functorch during the implicit backward pass (`_root_vjp` → `functorch.vjp` → `optimality_cond` in the traceback). The tensors it receives there are functorch wrapper tensors without real storage, so `cc_torch.connected_components_labeling`, whose C++ kernel needs a raw data pointer, raises `RuntimeError: Cannot access data pointer of Tensor that doesn't have storage`. `torch.no_grad()` and `torchopt.stop_gradient` cannot prevent that re-evaluation, and `.detach().clone()` does not unwrap the tensor. See also:
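
A minimal sketch of the failure mode, assuming functorch 0.x semantics (matching the traceback); the `data_ptr()` call is a hypothetical stand-in for what `cc_torch`'s C++ kernel does internally:

```python
import torch
from functorch import vjp  # functorch 0.x, as in the traceback


def f(x):
    # When vjp runs this function, x is a functorch wrapper tensor without
    # real storage; detach()/clone() do not unwrap it (as seen in the
    # question's save_biggest_mask call), and asking for a raw pointer --
    # what a C++ extension like cc_torch does -- raises:
    #   RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
    _ = x.detach().clone().data_ptr()  # hypothetical stand-in for cc_torch
    return (x ** 2).sum()


x = torch.randn(3)
out, vjp_fn = vjp(f, x)  # raises inside f, with x wrapped
```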
Move the `save_biggest_mask` call out of `self.objective()` and put it into `self.solve()` instead: `solve()` only runs during the forward inner loop, while `objective()` is re-evaluated under functorch during the outer `backward()`. A sketch follows.
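
A minimal sketch of that rearrangement, assuming `biggest_mask` can be recomputed from `self.atter(self.z)` at each inner step before `objective()` is called (all other names as in the question):

```python
def objective(self):
    mask_ = self.atter(self.z)
    # biggest_mask is now a plain constant computed in solve(); objective()
    # has no side effects, so functorch can safely re-evaluate it during
    # the implicit backward pass
    mask_ = mask_ * self.biggest_mask
    pred = self.clser(mask_ * self.z)
    celoss = malign_celoss(pred, self.target)
    norm = torch.norm(self.z - self.z0)
    return celoss + self.lamb * norm

@torch.enable_grad()
def solve(self):
    optimizer = getattr(torch.optim, self.optimizer)(params=[self.z], lr=self.lr)
    mask_ = self.atter(self.z0)
    pred = self.clser(mask_ * self.z0)
    mscore = nn.Softmax(dim=1)(pred)[:, 1]
    self.target = 1 - (mscore >= 0.5).long().clone().detach()
    for epoch in range(self.max_epochs):
        optimizer.zero_grad()
        with torch.no_grad():
            # run cc_torch here, on a plain tensor and outside objective(),
            # so it never sees a storage-less functorch wrapper
            self.biggest_mask = save_biggest_mask(
                self.atter(self.z).detach().clone(), self.device)
        loss = self.objective()
        loss.backward(inputs=[self.z])
        optimizer.step()
    return
```

With this layout, `objective()` depends only on `self.z`, the meta-module, and stored constants, which matches what the implicit-gradient machinery assumes when it differentiates it.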