Context-Based Contrastive Learning for Scene Text Recognition

An unofficial code repository that attempts to implement ConCLR (AAAI 2022).

Runtime Environment

Install the dependencies:

pip install -r requirements.txt


Train ABINet-Vision-ConCLR model

CUDA_VISIBLE_DEVICES=0,1,2,3 python --config=configs/conclr_pretrain_vision_model.yaml


Evaluate ABINet-Vision-ConCLR model

CUDA_VISIBLE_DEVICES=0 python --config=configs/conclr_pretrain_vision_model.yaml --phase test --image_only

Run Demo

python --config=configs/conclr_pretrain_vision_model.yaml --input=test/img/path

Additional flags:
--config /path/to/config :: Set the path of configuration file
--input /path/to/image-directory :: Set the path of image directory or wildcard path
--checkpoint /path/to/checkpoint :: Set the path of trained model
--cuda [-1|0|1|2|3...] :: Set the cuda id, by default -1 is set and stands for cpu
--model_eval [alignment|vision] :: Specify which sub-model to use
--image_only :: Disable dumping visualization of attention masks


The authors’ key insight is that by pulling together embeddings of the same character in different contexts and pushing apart embeddings of different characters, we can guide models to learn a representation that better balances intrinsic and context information. Here are the key components of ConCLR.

Context-based Data Augmentation

Before the batch is fed into the model, Context-based Data Augmentation (ConAug) randomly permutes the input batch twice and, for each permutation, concatenates the permuted batch to the left or right of the original batch, as shown in the figure below:


Because ConAug needs to process the data after it has been collated into a batch and before it is fed into the model, I use fastai's Callback mechanism to implement it. I implemented the “on_batch_begin” method, which is called before each batch is fed into the model. Here is my implementation:

import torch
import torch.nn.functional as F
from fastai.callback import Callback  # fastai v1


class ConAugPretransform(Callback):
    def __init__(self, *args, **kwargs):
        self.aug_type = kwargs["aug_type"]
        self.max_length = kwargs["max_length"]
        self.test_conaug = kwargs["test_conaug"]
        self.charset_path = kwargs["charset_path"]

    def _augment_images(self, images, permuted_indices, left):
        # Concatenate each image with its permuted partner along the width axis.
        permuted_images = images[permuted_indices]
        augmented_images = torch.cat(
            [permuted_images if left else images, images if left else permuted_images],
            dim=3,
        )
        # Resize back to the original size so the model input shape is unchanged.
        original_height, original_width = images.size(2), images.size(3)
        augmented_images = F.interpolate(
            augmented_images,
            size=(original_height, original_width),
        )  # further arguments (e.g. mode) are not shown in the source
        return augmented_images

    def _visualize_augmented_batch(self, x, x1, x2, y, y1, y2, yl, yl1, yl2):
        ...  # plotting code is not shown in the source

    def _process_gt_labels(self, gt_labels, gt_lengths, permuted_indices, left):
        new_gt_labels = gt_labels.clone()
        new_gt_lengths = gt_lengths.clone()
        for i, perm in enumerate(permuted_indices):
            if left:
                # permuted label first, then the original label, then the end token
                new_gt = torch.cat(
                    [
                        gt_labels[perm][: gt_lengths[perm] - 1],
                        gt_labels[i][: self.max_length + 2 - gt_lengths[perm] - 1],
                        torch.unsqueeze(gt_labels[perm][-1], dim=0),
                    ],
                    dim=0,
                )
            else:
                # original label first, then the permuted label, then the end token
                new_gt = torch.cat(
                    [
                        gt_labels[i][: gt_lengths[i] - 1],
                        gt_labels[perm][: self.max_length + 2 - gt_lengths[i] - 1],
                        torch.unsqueeze(gt_labels[i][-1], dim=0),
                    ],
                    dim=0,
                )
            new_gt_labels[i] = new_gt
        # merged length = both lengths minus one shared end token, capped at the maximum
        new_gt_lengths += new_gt_lengths[permuted_indices] - 1
        new_gt_lengths = torch.minimum(
            new_gt_lengths, torch.tensor(self.max_length + 1)
        )
        return new_gt_labels, new_gt_lengths

    def on_batch_begin(self, last_input, last_target, **kwargs) -> dict:
        images = last_input[0]
        permuted_indices_one = torch.randperm(images.size(0))
        permuted_indices_two = torch.randperm(images.size(0))

        left_one = torch.randint(0, 2, (1,), dtype=torch.bool)
        left_two = torch.randint(0, 2, (1,), dtype=torch.bool)

        augmented_batch_one = self._augment_images(
            images, permuted_indices_one, left_one
        )
        augmented_batch_two = self._augment_images(
            images, permuted_indices_two, left_two
        )

        last_input = ((images, augmented_batch_one, augmented_batch_two), last_input[1])

        gt_labels, gt_lengths = last_target[0], last_target[1]
        gt_labels_one, gt_lengths_one = self._process_gt_labels(
            gt_labels, gt_lengths, permuted_indices_one, left_one
        )
        gt_labels_two, gt_lengths_two = self._process_gt_labels(
            gt_labels, gt_lengths, permuted_indices_two, left_two
        )

        last_target[0] = [gt_labels, gt_labels_one, gt_labels_two]
        last_target[1] = [gt_lengths, gt_lengths_one, gt_lengths_two]

        if self.test_conaug:
            # call reconstructed from the method signature; the original line is missing
            self._visualize_augmented_batch(
                images, augmented_batch_one, augmented_batch_two,
                gt_labels, gt_labels_one, gt_labels_two,
                gt_lengths, gt_lengths_one, gt_lengths_two,
            )

        return {"last_input": last_input, "last_target": last_target}

In this function, I process the augmented batches and their labels together. (The augmented batches are also used to calculate the recognition loss, so their labels must be updated as well.)

I plotted the output of ConAug to verify the correctness of my implementation, as follows:
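To make the shapes concrete, here is a minimal standalone sketch of the same idea on toy data. It is not the callback itself: `conaug_images` and `conaug_text` are hypothetical helper names, the tensor sizes are illustrative, and the string version ignores the padding, end-token and maximum-length handling that the real label code performs.

```python
import torch
import torch.nn.functional as F


def conaug_images(images, perm, left):
    """Concatenate a permuted copy of the batch along the width, then resize back."""
    permuted = images[perm]
    pair = [permuted, images] if left else [images, permuted]
    wide = torch.cat(pair, dim=3)  # (B, C, H, 2W)
    return F.interpolate(wide, size=tuple(images.shape[2:]))  # back to (B, C, H, W)


def conaug_text(words, perm, left):
    """String-level view of the label merge (no padding or truncation)."""
    return [words[p] + w if left else w + words[p] for w, p in zip(words, perm.tolist())]


torch.manual_seed(0)
images = torch.rand(4, 3, 32, 128)  # toy batch of 4 images
words = ["cat", "dog", "sun", "tree"]  # toy labels
perm = torch.randperm(4)
left = bool(torch.randint(0, 2, (1,)))

aug = conaug_images(images, perm, left)
merged = conaug_text(words, perm, left)
print(aug.shape, merged)
```

Each augmented image keeps the original input size, while its label becomes the concatenation of two words, so the same character now appears in a different context.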


Projection Head

Feeding the two augmented batches through the backbone and decoder gives their corresponding glimpse vectors. The authors note, citing SimCLR, that using these embeddings directly for contrastive learning harms model performance, so a projection head is needed to filter out irrelevant information. SimCLR also reports that adding a non-linear layer to the projection head improves performance. I therefore implemented the head with the following code.

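import torch.nn as nn
import torch.nn.functional as F


class ProjectionHead(nn.Module):
    def __init__(self, in_channel, out_channel, head="mlp") -> None:
        super().__init__()
        if head == "linear":
            self.head = nn.Linear(in_channel, out_channel)
        elif head == "mlp":
            self.head = nn.Sequential(
                nn.Linear(in_channel, in_channel),
                # non-linearity between the layers; the exact activation is not
                # shown in the source, ReLU is assumed here
                nn.ReLU(inplace=True),
                nn.Linear(in_channel, out_channel),
            )
        else:
            raise ValueError(f"unsupported head: {head}")

    def forward(self, x):
        # L2-normalize so that dot products between embeddings are cosine similarities
        embed = F.normalize(self.head(x), dim=-1, p=2)
        return embed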
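As a quick illustration of what the head produces, the sketch below runs a small stand-in MLP head on random "glimpse vectors". `MLPHead` is a hypothetical, self-contained stand-in, and the sizes (512 input channels, 128 output channels, sequence length 26) and the ReLU activation are assumptions chosen only for the example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MLPHead(nn.Module):
    """Hypothetical stand-in for an MLP projection head."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.head = nn.Sequential(
            nn.Linear(in_channel, in_channel),
            nn.ReLU(),  # assumed activation
            nn.Linear(in_channel, out_channel),
        )

    def forward(self, x):
        # project, then L2-normalize along the channel axis
        return F.normalize(self.head(x), dim=-1, p=2)


torch.manual_seed(0)
glimpses = torch.randn(2, 26, 512)  # (batch, sequence length, channels), illustrative
embed = MLPHead(512, 128)(glimpses)
norms = embed.norm(dim=-1)  # every embedding has unit length
sim = embed @ embed.transpose(2, 1)  # pairwise cosine similarities per sample
print(embed.shape, sim.shape)
```

Because every embedding has unit norm, the dot products used by the contrastive loss are cosine similarities bounded by 1 in absolute value.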

Contrastive Loss

With ConAug, the backbone, the decoder, and the projection head, we obtain two sets of embeddings, one for each augmented batch. How do we then achieve the “pulling together embeddings of the same character in different contexts and pushing apart embeddings of different characters”? The answer is the contrastive loss.


Here are the formulas:

$$ \mathcal{L}_{pair}(\boldsymbol{T}) = \sum_{m \in M} \frac{-1}{|P(m)|} \sum_{p \in P(m)} \left( \log \exp\left(\boldsymbol{z}_m \cdot \boldsymbol{z}_p / \tau\right) - \log \sum_{a \in A(m)} \exp\left(\boldsymbol{z}_m \cdot \boldsymbol{z}_a / \tau\right) \right) $$

$$ \mathcal{L}_{clr} = \frac{1}{N} \sum_{\boldsymbol{T} \in I_{aug}} \mathcal{L}_{pair}(\boldsymbol{T}) $$

I won’t go into the meaning of each symbol here; it can be found in the paper. My implementation is as follows.

import torch
import torch.nn as nn


class ContrastiveLoss(nn.Module):
    def __init__(self, temprature=2):
        super().__init__()
        self.temprature = temprature

    def _clr_loss(self, embed1, embed2, labels1, labels2, loss_name, record=True):
        # concatenate the two augmented views along the sequence axis
        embed = torch.cat((embed1, embed2), dim=1)
        labels = torch.cat((labels1, labels2), dim=1)

        _, max_length = labels.shape
        s = torch.div(
            embed @ embed.transpose(2, 1), self.temprature
        )  # calculate similarity
        am = ~torch.eye(
            max_length, dtype=torch.bool, device=embed.device
        )  # mask of A(m): every position except m itself (includes pm)
        pm = (
            labels[:, :, None]
            == labels[:, None, :]  # expand label and compare it with its transpose
        ) & am  # positive mask

        # exclude eos
        p_num = torch.where(
            labels.unsqueeze(2) != 0, pm.sum(dim=1).unsqueeze(2), 0
        )  # count pos num of each m

        # include eos
        # p_num = pm.sum(dim=1).unsqueeze(2)  # count pos num of each m

        em = p_num.bool()  # deal with no positive

        s = torch.masked_fill(s, ~em, 0)  # mask no pos
        p = torch.masked_fill(s, ~pm, 0)
        # fill the diagonal with -inf (not 0) so that m itself adds nothing to
        # the sum over A(m), as in the formula
        a = torch.logsumexp(s.masked_fill(~am, float("-inf")), dim=2)
        a = torch.masked_fill(a, ~em.squeeze(2), 0)

        # use the input's device instead of a hard-coded "cuda"
        zero = torch.zeros((), dtype=s.dtype, device=s.device)
        smx = torch.where(p != 0, p - a.unsqueeze(2), zero)
        dv = torch.where(
            em.squeeze(2),  # only anchors that have at least one positive
            -torch.sum(smx, dim=2) / p_num.squeeze(2),
            zero,
        )
        l_pair = torch.sum(dv, dim=1)
        loss = torch.mean(l_pair, dim=0)

        if record and loss_name is not None:
            self.losses[f"{loss_name}_loss"] = loss

        return loss

    def forward(self, X, Y, *args):
        self.losses = {}
        features, loss_name = X[0].get("proj"), X[0].get("name")
        return self._clr_loss(
            # embeddings of the two augmented views; this indexing is
            # reconstructed because the original arguments are missing here
            features[0],
            features[1],
            torch.argmax(Y[0], dim=2),
            torch.argmax(Y[1], dim=2),
            loss_name,
        )

To improve parallel efficiency, I compute the contrastive loss for the entire batch at once using matrix operations. Applying masks gives a result equivalent to the loss function above.
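One way to sanity-check a masked, batched computation like this is to compare it with a slow per-anchor loop that follows the formula for $\mathcal{L}_{pair}$ literally. The sketch below is such a reference for a single concatenated sequence; `pair_loss_naive` is a hypothetical helper, and treating label id 0 as the token to skip (matching the "exclude eos" branch) is an assumption.

```python
import torch
import torch.nn.functional as F


def pair_loss_naive(z, labels, tau=2.0, ignore_id=0):
    """Loop-based L_pair for one sequence.

    z: (L, D) L2-normalized embeddings, labels: (L,) integer class ids.
    Anchors whose label equals ignore_id, or that have no positive, are skipped.
    """
    ids = labels.tolist()
    sim = (z @ z.t()) / tau
    total = z.new_zeros(())
    for m, label_m in enumerate(ids):
        if label_m == ignore_id:
            continue
        others = [a for a in range(len(ids)) if a != m]  # A(m)
        positives = [p for p in others if ids[p] == label_m]  # P(m)
        if not positives:
            continue
        log_denom = torch.logsumexp(sim[m, others], dim=0)
        total = total - sum(sim[m, p] - log_denom for p in positives) / len(positives)
    return total


torch.manual_seed(0)
z = F.normalize(torch.randn(6, 8), dim=-1)
labels = torch.tensor([3, 5, 3, 7, 5, 0])  # two repeated characters and one ignored token
loss = pair_loss_naive(z, labels)
no_positive_loss = pair_loss_naive(z, torch.tensor([1, 2, 3, 4, 5, 6]))
print(loss.item(), no_positive_loss.item())
```

Averaging this value over the samples of a batch should agree with a vectorized version up to floating-point error, provided both use the same rules for which anchors are skipped.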


To evaluate the reproduction, I trained the ConCLR-Vision and ABINet-Vision models with the same training settings. Due to constraints on time and computational resources, both models were trained on the ST dataset for 4 epochs using four NVIDIA 2080 Ti GPUs (on AutoDL 🥲). The learning rate was initialized at 1e-4 and decayed to 1e-5 by the final epoch.

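| Model | Dataset | CCR | CWR |
| --- | --- | --- | --- |
| conclr-vision | IIIT5k | 0.942 | 0.891 |
| conclr-vision | IC13 | 0.963 | 0.890 |
| conclr-vision | CUTE80 | 0.928 | 0.910 |
| abinet-vision | IIIT5k | 0.926 | 0.872 |
| abinet-vision | IC13 | 0.938 | 0.852 |
| abinet-vision | CUTE80 | 0.771 | 0.882 |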
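The table does not define CCR and CWR. Assuming they denote character-level and word-level correct rates (an assumption, not something stated above), they could be computed roughly as in this sketch; `ccr_cwr` is a hypothetical helper, and the position-by-position character comparison is one possible convention.

```python
def ccr_cwr(preds, gts):
    """Return (character correct rate, word correct rate) for paired strings.

    Characters are compared position by position; a word counts as correct
    only when the whole string matches.
    """
    correct_chars = total_chars = correct_words = 0
    for pred, gt in zip(preds, gts):
        total_chars += len(gt)
        correct_chars += sum(p == g for p, g in zip(pred, gt))
        correct_words += int(pred == gt)
    return correct_chars / max(total_chars, 1), correct_words / max(len(gts), 1)


# 12 of 14 characters and 1 of 3 words are correct in this toy example
ccr, cwr = ccr_cwr(["hello", "wor1d", "text"], ["hello", "world", "test"])
print(round(ccr, 3), round(cwr, 3))
```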

The results indicate that ConCLR improved both CCR and CWR on all three benchmarks. However, the training dataset was not very large and the training duration was relatively short. Due to resource constraints, I have not yet been able to run a more comprehensive experiment; more comprehensive experiments are underway…

By the way, the training logs are stored in the “log/” directory; you can view them with TensorBoard or just read the .txt files.

(There might be errors in my implementation, and I would greatly appreciate any feedback on this matter 🥰)

