
Commit

update
felixdittrich92 committed Jan 9, 2024
1 parent 2efb797 commit 9ec4ce5
Showing 3 changed files with 54 additions and 52 deletions.
12 changes: 7 additions & 5 deletions doctr/models/detection/differentiable_binarization/pytorch.py
@@ -253,6 +253,7 @@ def compute_loss(
l1_loss = torch.zeros(1, device=out_map.device)
if torch.any(seg_mask):
# Focal loss
+ focal_scale = 5.0
bce_loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")

if gamma < 0:
@@ -264,17 +265,18 @@ def compute_loss(
# Class reduced
focal_loss = (seg_mask * focal_loss).sum() / seg_mask.sum()

- # Dice loss
- inter = (seg_mask * prob_map * seg_target).sum()
- cardinality = (seg_mask * (prob_map + seg_target)).sum()
+ # Compute dice loss for approx binary_map
+ binary_map = 1 / (1 + torch.exp(-50.0 * (prob_map - thresh_map)))
+ inter = (seg_mask * binary_map * seg_target).sum()
+ cardinality = (seg_mask * (binary_map + seg_target)).sum()
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if torch.any(thresh_mask):
- l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
+ l1_loss = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)

- return l1_scale * l1_loss + focal_loss + dice_loss # type: ignore[return-value]
+ return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss # type: ignore[return-value]


def _dbnet(
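
The PyTorch change above reworks the DBNet loss in three ways: the dice term is now computed on the approximate binary map B = 1 / (1 + exp(-50 * (P - T))) instead of the raw probability map, the threshold-map L1 term is normalised by the number of supervised pixels, and the focal term gets an explicit scale of 5. The sketch below is not the library's exact code; tensor shapes, the gamma value, and the omission of the empty-mask guards are assumptions. It only shows how the three terms combine:

    import torch
    import torch.nn.functional as F

    def db_loss_sketch(out_map, thresh_map, seg_target, seg_mask, thresh_target, thresh_mask,
                       gamma=2.0, eps=1e-8):
        # Focal term on the segmentation logits, averaged over the valid mask
        prob_map = torch.sigmoid(out_map)
        bce = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction="none")
        p_t = prob_map * seg_target + (1 - prob_map) * (1 - seg_target)
        focal = ((1 - p_t) ** gamma * bce * seg_mask).sum() / (seg_mask.sum() + eps)

        # Dice term on the approximate binary map B = sigmoid(50 * (P - T))
        binary_map = torch.sigmoid(50.0 * (prob_map - thresh_map))
        inter = (seg_mask * binary_map * seg_target).sum()
        cardinality = (seg_mask * (binary_map + seg_target)).sum()
        dice = 1 - 2 * (inter + eps) / (cardinality + eps)

        # L1 term on the threshold map, normalised by the number of supervised pixels
        l1 = (torch.abs(thresh_map - thresh_target) * thresh_mask).sum() / (thresh_mask.sum() + eps)

        # Combination used after this commit: l1_scale = 10, focal_scale = 5
        return 10.0 * l1 + 5.0 * focal + dice

    # Dummy example (shapes are assumptions)
    b, h, w = 2, 64, 64
    out_map, thresh_map = torch.randn(b, 1, h, w), torch.rand(b, 1, h, w)
    seg_target = (torch.rand(b, 1, h, w) > 0.5).float()
    mask = torch.ones(b, 1, h, w)
    loss = db_loss_sketch(out_map, thresh_map, seg_target, mask, torch.rand(b, 1, h, w), mask)

The steepness factor 50 is the amplification constant k from the Differentiable Binarization paper: it turns the hard P > T step into a smooth approximation, so the dice term can back-propagate through both the probability map and the threshold map.
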
18 changes: 11 additions & 7 deletions doctr/models/detection/differentiable_binarization/tensorflow.py
@@ -196,9 +196,9 @@ def compute_loss(
thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype)
thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool)

- bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)

# Focal loss
+ focal_scale = 5.0
+ bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map[..., None], from_logits=True)
if gamma < 0:
raise ValueError("Value of gamma should be greater than or equal to zero.")
# Convert logits to prob, compute gamma factor
@@ -209,19 +209,23 @@ def compute_loss(
# Class reduced
focal_loss = tf.reduce_sum(seg_mask * focal_loss, (0, 1, 2, 3)) / tf.reduce_sum(seg_mask, (0, 1, 2, 3))

- # Dice loss
- inter = tf.math.reduce_sum(seg_mask * prob_map * seg_target, (0, 1, 2, 3))
- cardinality = tf.math.reduce_sum((prob_map + seg_target), (0, 1, 2, 3))
+ # Compute dice loss for approx binary_map
+ binary_map = 1.0 / (1.0 + tf.exp(-50 * (prob_map - thresh_map)))
+ inter = tf.reduce_sum(seg_mask * binary_map * seg_target, (0, 1, 2, 3))
+ cardinality = tf.reduce_sum((binary_map + seg_target), (0, 1, 2, 3))
dice_loss = 1 - 2 * (inter + eps) / (cardinality + eps)

# Compute l1 loss for thresh_map
l1_scale = 10.0
if tf.reduce_any(thresh_mask):
- l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask]))
+ thresh_mask = tf.cast(thresh_mask, tf.float32)
+ l1_loss = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * thresh_mask) / (
+     tf.reduce_sum(thresh_mask) + eps
+ )
else:
l1_loss = tf.constant(0.0)

- return l1_scale * l1_loss + focal_loss + dice_loss
+ return l1_scale * l1_loss + focal_scale * focal_loss + dice_loss

def call(
self,
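
The TensorFlow version applies the same dice-on-binary-map change, and its threshold L1 loss switches from a mean over boolean-indexed pixels to a masked sum divided by the pixel count plus eps, which stays finite even when the mask is empty. A small standalone check (not part of the commit; shapes and mask density are assumptions) comparing the two formulations:

    import tensorflow as tf

    eps = 1e-8
    thresh_map = tf.random.uniform((2, 32, 32, 1))
    thresh_target = tf.random.uniform((2, 32, 32, 1))
    thresh_mask = tf.random.uniform((2, 32, 32, 1)) > 0.5

    # Previous formulation: mean over the boolean-selected pixels (NaN if the mask is empty)
    old = tf.reduce_mean(tf.abs(tf.boolean_mask(thresh_map, thresh_mask)
                                - tf.boolean_mask(thresh_target, thresh_mask)))

    # New formulation: masked sum divided by the number of masked pixels (+ eps)
    mask_f = tf.cast(thresh_mask, tf.float32)
    new = tf.reduce_sum(tf.abs(thresh_map - thresh_target) * mask_f) / (tf.reduce_sum(mask_f) + eps)

    print(float(old), float(new))  # nearly identical for a non-empty mask
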
76 changes: 36 additions & 40 deletions doctr/models/detection/linknet/tensorflow.py
@@ -46,22 +46,20 @@

def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential:
"""Creates a LinkNet decoder block"""
- return Sequential(
- [
- *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs),
- layers.Conv2DTranspose(
- filters=in_chan // 4,
- kernel_size=3,
- strides=stride,
- padding="same",
- use_bias=False,
- kernel_initializer="he_normal",
- ),
- layers.BatchNormalization(),
- layers.Activation("relu"),
- *conv_sequence(out_chan, "relu", True, kernel_size=1),
- ]
- )
+ return Sequential([
+ *conv_sequence(in_chan // 4, "relu", True, kernel_size=1, **kwargs),
+ layers.Conv2DTranspose(
+ filters=in_chan // 4,
+ kernel_size=3,
+ strides=stride,
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ ),
+ layers.BatchNormalization(),
+ layers.Activation("relu"),
+ *conv_sequence(out_chan, "relu", True, kernel_size=1),
+ ])


class LinkNetFPN(Model, NestedObject):
@@ -131,30 +129,28 @@ def __init__(
self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape])
self.fpn.build(self.feat_extractor.output_shape)

- self.classifier = Sequential(
- [
- layers.Conv2DTranspose(
- filters=32,
- kernel_size=3,
- strides=2,
- padding="same",
- use_bias=False,
- kernel_initializer="he_normal",
- input_shape=self.fpn.decoders[-1].output_shape[1:],
- ),
- layers.BatchNormalization(),
- layers.Activation("relu"),
- *conv_sequence(32, "relu", True, kernel_size=3, strides=1),
- layers.Conv2DTranspose(
- filters=num_classes,
- kernel_size=2,
- strides=2,
- padding="same",
- use_bias=True,
- kernel_initializer="he_normal",
- ),
- ]
- )
+ self.classifier = Sequential([
+ layers.Conv2DTranspose(
+ filters=32,
+ kernel_size=3,
+ strides=2,
+ padding="same",
+ use_bias=False,
+ kernel_initializer="he_normal",
+ input_shape=self.fpn.decoders[-1].output_shape[1:],
+ ),
+ layers.BatchNormalization(),
+ layers.Activation("relu"),
+ *conv_sequence(32, "relu", True, kernel_size=3, strides=1),
+ layers.Conv2DTranspose(
+ filters=num_classes,
+ kernel_size=2,
+ strides=2,
+ padding="same",
+ use_bias=True,
+ kernel_initializer="he_normal",
+ ),
+ ])

self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh)

