Skip to content

Commit

Permalink
update loss tf and pt
Browse files (browse the repository at this point in the history)
  • Loading branch information
felixdittrich92 committed Jan 8, 2024
1 parent bd03394 commit ed1457a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 82 deletions.
92 changes: 35 additions & 57 deletions doctr/models/detection/differentiable_binarization/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,15 @@ def forward(

return out

def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target: List[np.ndarray]) -> torch.Tensor:
def compute_loss(
self,
out_map: torch.Tensor,
thresh_map: torch.Tensor,
target: List[np.ndarray],
gamma: float = 2.0,
alpha: float = 0.5,
eps: float = 1e-8,
) -> torch.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Expand All @@ -222,81 +230,51 @@ def compute_loss(self, out_map: torch.Tensor, thresh_map: torch.Tensor, target:
out_map: output feature map of the model of shape (N, C, H, W)
thresh_map: threshold map of shape (N, C, H, W)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
gamma: modulating factor in the focal loss formula
alpha: balancing factor in the focal loss formula
eps: epsilon factor in dice loss
Returns:
-------
A loss tensor
"""
prob_map = torch.sigmoid(out_map)
thresh_map = torch.sigmoid(thresh_map)
binary_mask = torch.reciprocal(1.0 + torch.exp(-50 * (prob_map - thresh_map)))

# TODO: needs also some checks again shrunk masks
targets = self.build_target(target, out_map.shape[1:], False) # type: ignore[arg-type]

seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)

balanced_bce_loss = torch.zeros(1, device=out_map.device)
focal_loss = torch.zeros(1, device=out_map.device)
dice_loss = torch.zeros(1, device=out_map.device)
l1_loss = torch.zeros(1, device=out_map.device)

# TODO: Still in progress @Oliver @Charles

if torch.any(seg_mask):
# Compute balanced bce loss
bce_loss = F.binary_cross_entropy_with_logits(
out_map,
seg_target,
reduction="none",
)

positive = (seg_target * seg_mask).float()
negative = ((1 - seg_target) * seg_mask).float()
positive_count = int(positive.sum())
if positive_count == 0:
negative_count = min(int(negative.sum()), 0)
else:
negative_count = min(int(negative.sum()), int(positive_count * 3))

positive_loss = bce_loss * positive
negative_loss = bce_loss * negative

negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

balanced_bce_loss = (
(positive_loss.sum() + negative_loss.sum()) / (positive_count + negative_count + 1e-6) * 5.0
)

# Compute dice loss
binary_mask = binary_mask.contiguous().view(binary_mask.size(0), -1)
seg_target = seg_target.contiguous().view(seg_target.size(0), -1)

seg_mask = seg_mask.contiguous().view(seg_mask.size(0), -1)
binary_mask = binary_mask * seg_mask
seg_target = seg_target * seg_mask

dice_coeff = (2 * (binary_mask * seg_target).sum()) / (binary_mask.sum() + seg_target.sum() + 1e-6)

dice_loss = 1 - dice_coeff

# Focal loss
bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")

if gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")
p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
alpha_t = alpha * seg_target + (1 - alpha) * (1 - seg_target)
# Unreduced version
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum()

# Dice loss
inter = (seg_mask * prob_map * seg_target).sum()
cardinality = (seg_mask * (prob_map + seg_target)).sum()
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if torch.any(thresh_mask):
# Compute l1 loss
assert thresh_map.size() == thresh_target.size() and thresh_target.numel() > 0
assert thresh_mask.size() == thresh_target.size()
x = thresh_map * thresh_mask
y = thresh_target * thresh_mask

loss = torch.zeros_like(thresh_target)
diff = torch.abs(x - y)
mask_beta = diff < 1
loss[mask_beta] = 0.5 * torch.square(diff)[mask_beta] / 1
loss[~mask_beta] = diff[~mask_beta] - 0.5 * 1
l1_loss = loss.sum() / (thresh_mask.sum() + 1e-6) * 10.0

return (balanced_bce_loss + dice_loss + l1_loss) / 3
l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))

return l1_scale * l1_loss + focal_loss + dice_loss # type: ignore[return-value]


def _dbnet(
Expand Down
50 changes: 25 additions & 25 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,9 @@ def compute_loss(
out_map: tf.Tensor,
thresh_map: tf.Tensor,
target: List[Dict[str, np.ndarray]],
gamma: float = 2.0,
alpha: float = 0.5,
eps: float = 1e-8,
) -> tf.Tensor:
"""Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes
and a list of masks for each image. From there it computes the loss with the model output
Expand All @@ -175,6 +178,9 @@ def compute_loss(
out_map: output feature map of the model of shape (N, H, W, C)
thresh_map: threshold map of shape (N, H, W, C)
target: list of dictionary where each dict has a `boxes` and a `flags` entry
gamma: modulating factor in the focal loss formula
alpha: balancing factor in the focal loss formula
eps: epsilon factor in dice loss
Returns:
-------
Expand All @@ -186,33 +192,27 @@ def compute_loss(
seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[1:], True)
seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype)
seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool)
seg_mask = tf.cast(seg_mask, tf.float32)
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

# Compute balanced BCE loss for proba_map
bce_scale = 5.0
bce_loss = tf.keras.losses.binary_crossentropy(
seg_target[..., None],
out_map[..., None],
from_logits=True,
)[seg_mask]

neg_target = 1 - seg_target[seg_mask]
positive_count = tf.math.reduce_sum(seg_target[seg_mask])
negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3.0 * positive_count])
negative_loss = bce_loss * neg_target
negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32))
sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss)
balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6)

# Compute dice loss for approxbin_map
bin_map = 1 / (1 + tf.exp(-50.0 * (prob_map[seg_mask] - thresh_map[seg_mask])))

bce_min = tf.math.reduce_min(bce_loss)
weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1.0
inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights)
union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8
dice_loss = 1 - 2.0 * inter / union
bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)

# Focal loss
if gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")
# Convert logits to prob, compute gamma factor
p_t = (seg_target * prob_map) + ((1 - seg_target) * (1 - prob_map))
alpha_t = seg_target * alpha + (1 - seg_target) * (1 - alpha)
# Unreduced loss
focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
# Class reduced
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))

# Dice loss
inter = tf.math.reduce_sum(seg_mask * prob_map * seg_target, (0, 1, 2, 3))
cardinality = tf.math.reduce_sum((prob_map + seg_target), (0, 1, 2, 3))
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
Expand All @@ -221,7 +221,7 @@ def compute_loss(
else:
l1_loss = tf.constant(0.0)

return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss
return l1_scale * l1_loss + focal_loss + dice_loss

def call(
self,
Expand Down

0 comments on commit ed1457a

Please sign in to comment.