Update dys_opt_net.py
howardheaton authored Feb 15, 2024
1 parent d4be81a commit a902af0
Showing 1 changed file with 38 additions and 39 deletions.
77 changes: 38 additions & 39 deletions src/dys_opt_net.py
@@ -101,54 +101,53 @@ def test_time_forward(self, d):


     def apply_DYS(self, z, w):
-        """
-            Davis-Yin Splitting.
-        """
-
-        x = self.project_C1(z)
-        y = self.project_C2(2.0*x - z - self.alpha*self.F(z, w))
-        z = z - x + y
-
-        return z
+        ''' Apply a single update step from Davis-Yin Splitting.
+            Args:
+                z (tensor): Point in Euclidean space
+                w (tensor): Parameters defining function and its gradient
+            Returns:
+                z (tensor): Updated estimate of solution
+        '''
+        x = self.project_C1(z)
+        y = self.project_C2(2.0 * x - z - self.alpha*self.F(z, w))
+        z = z - x + y
+        return z
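
For reference, the sketch below works through one run of the same Davis-Yin Splitting update on a small toy problem. Everything in it (the projections onto the nonnegative orthant and a box, the quadratic whose gradient plays the role of F, and the step size alpha) is an illustrative assumption, not the operators defined in dys_opt_net.py; note also that this sketch evaluates the gradient at x, as in the standard Davis-Yin scheme, whereas apply_DYS passes z to self.F.

import torch

def project_C1(z):
    # Illustrative stand-in: projection onto the nonnegative orthant.
    return torch.clamp(z, min=0.0)

def project_C2(z):
    # Illustrative stand-in: projection onto the box [-1, 1]^n.
    return torch.clamp(z, min=-1.0, max=1.0)

def F(x, b):
    # Illustrative stand-in: gradient of f(x) = 0.5 * ||x - b||^2.
    return x - b

alpha = 0.1
b = torch.tensor([0.5, -2.0, 3.0])
z = torch.zeros(3)

for _ in range(200):
    x = project_C1(z)                               # project onto C1
    y = project_C2(2.0 * x - z - alpha * F(x, b))   # project the reflected, gradient-shifted point onto C2
    z = z - x + y                                   # Davis-Yin fixed-point update

print(project_C1(z))  # approx. tensor([0.5, 0., 1.]): the point of C1 intersect C2 closest to b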


     def train_time_forward(self, d, eps=1.0e-2, max_depth=int(1e4),
                            depth_warning=True):
-        """
-            Default forward behaviour.
-        """
-        with torch.no_grad():
-            w = self.data_space_forward(d)
-            self.depth = 0.0
-
-            z = torch.rand((self.n2), device=self.device)
-            z_prev = z.clone()
+        ''' Default forward behaviour.
+        '''
+        with torch.no_grad():
+            w = self.data_space_forward(d)
+            self.depth = 0.0
+
+            z = torch.rand((self.n2), device=self.device)
+            z_prev = z.clone()

-            all_samp_conv = False
-            while not all_samp_conv and self.depth < max_depth:
-                z_prev = z.clone()
-                z = self.apply_DYS(z, w)
-                diff_norm = torch.norm(z - z_prev)
-                diff_norm = torch.norm( diff_norm)
-                diff_norm = torch.max( diff_norm ) # take norm along the latter two dimensions then max
-                self.depth += 1.0
-                all_samp_conv = diff_norm <= eps
+            all_samp_conv = False
+            while not all_samp_conv and self.depth < max_depth:
+                z_prev = z.clone()
+                z = self.apply_DYS(z, w)
+                diff_norm = torch.norm(z - z_prev)
+                diff_norm = torch.norm( diff_norm)
+                diff_norm = torch.max( diff_norm ) # take norm along the latter two dimensions then max
+                self.depth += 1.0
+                all_samp_conv = diff_norm <= eps

-            if self.depth >= max_depth and depth_warning:
-                print("\nWarning: Max Depth Reached - Break Forward Loop\n")
-
-        if self.training:
-            w = self.data_space_forward(d)
-            z = self.apply_DYS(z.detach(), w)
-            return self.project_C1(z)
-        else:
-            return self.project_C1(z).detach()
+            if self.depth >= max_depth and depth_warning:
+                print("\nWarning: Max Depth Reached - Break Forward Loop\n")
+        if self.training:
+            w = self.data_space_forward(d)
+            z = self.apply_DYS(z.detach(), w)
+            return self.project_C1(z)
+        else:
+            return self.project_C1(z).detach()
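
The training-time pass above follows a pattern common to implicit/fixed-point models: run the update to (approximate) convergence inside torch.no_grad(), then recompute w and apply one final, gradient-tracked step to the detached iterate, so backpropagation only flows through that last step. A generic sketch of the pattern is below; the map T and the parameter theta are illustrative stand-ins, not the networks defined in dys_opt_net.py.

import torch

theta = torch.nn.Parameter(torch.tensor(0.5))

def T(z, theta):
    # Illustrative contraction map whose fixed point depends on theta.
    return 0.5 * z + theta

z = torch.zeros(3)
with torch.no_grad():               # no computational graph is built during the loop
    depth = 0
    while depth < 10000:
        z_prev = z.clone()
        z = T(z, theta)
        depth += 1
        if torch.norm(z - z_prev) <= 1.0e-6:
            break

z = T(z.detach(), theta)            # one differentiable step from the detached fixed point
loss = z.sum()
loss.backward()
print(theta.grad)                   # gradient w.r.t. theta through the final step only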

     def forward(self, d, eps=1.0e-2, max_depth=int(1e4),
                 depth_warning=True):
-        '''
-            Includes a switch for using different behaviour at
-            test/deployment.
+        ''' Includes a switch for using different behaviour at test/deployment.
+        '''
         if not self.training:
             return self.test_time_forward(d)
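
The switch in forward uses the standard torch.nn.Module training flag: calling .train() or .eval() on the model toggles self.training, which selects between the training-time and test-time passes. A minimal, hypothetical sketch of the same dispatch follows (the placeholder bodies are not the ones in this file).

import torch

class TinyNet(torch.nn.Module):
    def forward(self, d):
        if not self.training:
            return self.test_time_forward(d)
        return self.train_time_forward(d)

    def train_time_forward(self, d):
        return d + 1.0              # placeholder training-time behaviour

    def test_time_forward(self, d):
        return d - 1.0              # placeholder test/deployment behaviour

net = TinyNet()
d = torch.zeros(2)
net.train()
print(net(d))                       # tensor([1., 1.])  -> train_time_forward
net.eval()
print(net(d))                       # tensor([-1., -1.]) -> test_time_forward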
