diff --git a/test_instancesegmentation.py b/test_instancesegmentation.py
new file mode 100644
index 00000000000..784a48201f1
--- /dev/null
+++ b/test_instancesegmentation.py
@@ -0,0 +1,123 @@
+import torch
+import lightning.pytorch as pl
+from lightning.pytorch import LightningModule
+from torch.utils.data import DataLoader
+from torchgeo.datasets import VHR10
+from torchgeo.trainers import InstanceSegmentationTask
+import torch.nn.functional as F
+from pycocotools import mask as coco_mask
+from torch.utils.data import Subset
+import matplotlib.pyplot as plt
+import torchvision.transforms.functional as TF
+from PIL import Image, ImageDraw, ImageOps
+import numpy as np
+from matplotlib.patches import Rectangle
+from torchvision.transforms.functional import to_pil_image
+
+# Custom collate function for DataLoader (required for Mask R-CNN models)
+def collate_fn(batch):
+    max_height = max(sample['image'].shape[1] for sample in batch)
+    max_width = max(sample['image'].shape[2] for sample in batch)
+
+    images = torch.stack([
+        F.pad(sample['image'], (0, max_width - sample['image'].shape[2], 0, max_height - sample['image'].shape[1]))
+        for sample in batch
+    ])
+
+    targets = [
+        {
+            "labels": sample["labels"].to(torch.int64),
+            "boxes": sample["boxes"].to(torch.float32),
+            "masks": F.pad(
+                sample["masks"],
+                (0, max_width - sample["masks"].shape[2], 0, max_height - sample["masks"].shape[1]),
+            ).to(torch.uint8),
+        }
+        for sample in batch
+    ]
+
+    return {"image": images, "target": targets}
+
+# Visualization function
+def visualize_predictions(image, predictions, targets):
+    """Visualize model predictions and ground truth."""
+    image = to_pil_image(image)
+
+    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+    ax.imshow(image)
+    
+    # Plot predictions
+    for box, label in zip(predictions['boxes'], predictions['labels']):
+        x1, y1, x2, y2 = box
+        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='red', facecolor='none')
+        ax.add_patch(rect)
+        ax.text(x1, y1, str(label.item()), color='red', fontsize=12)
+    
+    # Plot ground truth
+    for box, label in zip(targets['boxes'], targets['labels']):
+        x1, y1, x2, y2 = box
+        rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2, edgecolor='blue', facecolor='none')
+        ax.add_patch(rect)
+        ax.text(x1, y1, str(label.item()), color='blue', fontsize=12)
+    
+    plt.show()
+
+# Initialize the VHR10 dataset
+train_dataset = VHR10(root="data", split="positive", transforms=None, download=True)
+val_dataset = VHR10(root="data", split="positive", transforms=None)
+
+# Select a small subset of the dataset 
+N = 100  # Number of samples to use
+train_subset = Subset(train_dataset, list(range(N)))
+val_subset = Subset(val_dataset, list(range(N)))
+
+if __name__ == '__main__':
+    import multiprocessing
+    multiprocessing.set_start_method('spawn', force=True)
+
+    train_loader = DataLoader(train_subset, batch_size=1, shuffle=True, num_workers=1, collate_fn=collate_fn, persistent_workers=True)
+    val_loader = DataLoader(val_subset, batch_size=1, shuffle=False, num_workers=1, collate_fn=collate_fn, persistent_workers=True)
+
+    print('\nDEBUG TRAIN LOADER\n')
+    for batch in train_loader:
+        print(f"Image shape: {batch['image'].shape}")
+        print(f"Target: {batch['target']}")
+        break
+
+    for batch in train_loader:
+        print(batch)
+        break
+
+    trainer = pl.Trainer(
+        max_epochs=10,
+        accelerator="gpu" if torch.cuda.is_available() else "cpu",
+        devices=1
+    )
+
+    task = InstanceSegmentationTask(
+        model="mask_rcnn",         
+        backbone="resnet50",       
+        weights=True,              
+        num_classes=11,            
+        lr=1e-3,                   
+        freeze_backbone=False      
+    )
+
+    print('\nTRAIN THE MODEL\n')
+
+    trainer.fit(task, train_dataloaders=train_loader, val_dataloaders=val_loader)
+
+    print('\nEVALUATE THE MODEL\n')
+
+    trainer.test(task, dataloaders=val_loader)
+
+    print('\nINFERENCE AND VISUALIZATION\n')
+
+    test_sample = train_dataset[0]
+    test_image = test_sample["image"].unsqueeze(0)  # Add batch dimension
+    predictions = task.predict_step({"image": test_image}, batch_idx=0)
+
+    visualize_predictions(test_image, predictions[0], test_sample)
+
+
+