Skip to content

Commit

Permalink
bug fixing
Browse files Browse the repository at this point in the history
  • Loading branch information
htjb committed Jul 25, 2023
1 parent 2d4506d commit f8735d3
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 10 deletions.
21 changes: 12 additions & 9 deletions margarine/maf.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,13 @@ def _train_step(self, x, w, loss_type, maf):
maf.trainable_variables))
return loss

@tf.function(jit_compile=True)
def call_bij(self, bij, u, min, max):
x = _forward_transform(u)
x = bij(x)
x = _inverse_transform(x, min, max)
return x

def __call__(self, u):

r"""
Expand All @@ -434,13 +441,6 @@ def __call__(self, u):
"""

@tf.function(jit_compile=True)
def call_bij(bij, u, min=self.theta_min, max=self.theta_max):
x = _forward_transform(u)
x = bij(x)
x = _inverse_transform(x, min, max)
return x


if self.cluster_number is not None:
len_thetas = [len(self.theta[i]) for i in range(len(self.theta))]
Expand All @@ -456,12 +456,15 @@ def call_bij(bij, u, min=self.theta_min, max=self.theta_max):

values = []
for i in range(len(options)):
x = call_bij(self.bij[i], u, min=self.theta_min[i], max=self.theta_max[i]).numpy()
x = self.call_bij(self.bij[i], u,
min=self.theta_min[i],
max=self.theta_max[i]).numpy()
values.append(x)

x = np.concatenate(values)
else:
x = call_bij(self.bij, u).numpy()
x = self.call_bij(self.bij, u, self.theta_min,
self.theta_max).numpy()

mask = np.isfinite(x).all(axis=-1)
return x[mask, ...]
Expand Down
3 changes: 2 additions & 1 deletion margarine/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ def _forward_transform(x, min=0, max=1):
"""
x = tf.cast(x, tf.float32)
min = tf.cast(min, tf.float32)
max = tf.cast(max, tf.float32)
x = tfd.Uniform(min, max).cdf(x)
x = tfd.Normal(0, 1).quantile(x)
return x
Expand Down Expand Up @@ -54,7 +56,6 @@ def _inverse_transform(x, min, max):
"""
x = tfd.Normal(0, 1).cdf(x)
print(min)
min = tf.cast(min, tf.float32)
max = tf.cast(max, tf.float32)
x = tfd.Uniform(min, max).quantile(x)
Expand Down

0 comments on commit f8735d3

Please sign in to comment.