Can You Train a Segmentation Model with Just 5 Labeled Pixels?
If you've ever labeled segmentation data, you know the pain: every single pixel needs a class. For a 512×512 image, that's 262,144 decisions. Multiply by hundreds of images and you're looking at weeks of annotation work.
But what if you could train a decent model by labeling just 5 pixels per image?
That's the idea behind point-label supervision, and it works surprisingly well. In this post, I'll walk through the concept, implement it from scratch in PyTorch, and show real experiments on aerial imagery.
The Setup
Dataset: LoveDA, an aerial-imagery segmentation dataset with 7 classes.
Model: A simple 3-level U-Net with ~300K parameters. Nothing fancy, just enough to prove the concept.
The trick: Partial cross-entropy loss.
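To make "3-level U-Net" concrete, here is a minimal sketch of what such a model could look like. The class name `TinyUNet`, the channel widths (`base=16`), and the use of BatchNorm are my assumptions for illustration, not the exact architecture used in the experiments, so the printed parameter count will not necessarily match the ~300K figure.

```python
import torch
import torch.nn as nn


def conv_block(in_ch, out_ch):
    # Two 3x3 conv + BatchNorm + ReLU layers, the usual U-Net building block.
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
    )


class TinyUNet(nn.Module):
    """Hypothetical 3-level U-Net; channel widths are assumed, not taken from the post."""

    def __init__(self, in_ch=3, num_classes=7, base=16):
        super().__init__()
        self.enc1 = conv_block(in_ch, base)
        self.enc2 = conv_block(base, base * 2)
        self.bottleneck = conv_block(base * 2, base * 4)
        self.pool = nn.MaxPool2d(2)
        self.up2 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.dec2 = conv_block(base * 4, base * 2)
        self.up1 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.dec1 = conv_block(base * 2, base)
        self.head = nn.Conv2d(base, num_classes, 1)

    def forward(self, x):
        e1 = self.enc1(x)                      # full resolution
        e2 = self.enc2(self.pool(e1))          # 1/2 resolution
        b = self.bottleneck(self.pool(e2))     # 1/4 resolution
        d2 = self.dec2(torch.cat([self.up2(b), e2], dim=1))   # skip connection from e2
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))  # skip connection from e1
        return self.head(d1)                   # (B, num_classes, H, W) logits


model = TinyUNet()
logits = model(torch.randn(2, 3, 64, 64))
n_params = sum(p.numel() for p in model.parameters())
print(logits.shape, n_params)
```

The output has one logit map per class at the input resolution, which is what the loss below expects. Input height and width need to be divisible by 4 for the skip connections to line up.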
What Is Partial Cross-Entropy?
Standard cross-entropy computes loss over every pixel:
CE = mean( -log(p_correct) ) for ALL pixels
Partial cross-entropy only computes loss on pixels that have labels:
pCE = sum( -log(p_correct) × M_labeled ) / sum(M_labeled)
Where M_labeled is a binary mask: 1 where we have a label, 0 everywhere else.
Unlabeled pixels contribute zero gradient. The model still predicts all pixels, but only gets feedback on the labeled ones.
In PyTorch, the implementation is almost trivially simple:
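    import torch
    import torch.nn as nn

    class PartialCELoss(nn.Module):
        def __init__(self, ignore_index=255):
            super().__init__()
            self.ignore_index = ignore_index
            # reduction="none" keeps the per-pixel loss map so we can normalize ourselves
            self.ce = nn.CrossEntropyLoss(ignore_index=ignore_index, reduction="none")

        def forward(self, logits, targets):
            # logits: (B, C, H, W) float, targets: (B, H, W) long
            loss_map = self.ce(logits, targets)              # (B, H, W), 0 at ignored pixels
            mask = (targets != self.ignore_index).float()    # 1 where labeled
            n_labeled = mask.sum().clamp(min=1.0)            # avoid 0/0 when nothing is labeled
            return (loss_map * mask).sum() / n_labeled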
Unlabeled pixels are set to 255 in the target mask. PyTorch's ignore_index handles them cleanly.
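A quick sanity check of the claims above, as a standalone sketch on random tensors (the shapes and the three "clicked" pixels are made up for illustration). It computes the pCE formula by hand, compares it with PyTorch's built-in cross-entropy using `ignore_index`, and confirms that unlabeled pixels receive exactly zero gradient.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
IGNORE = 255

logits = torch.randn(1, 7, 8, 8, requires_grad=True)    # (B, C, H, W)
targets = torch.full((1, 8, 8), IGNORE, dtype=torch.long)
targets[0, 2, 3] = 1                                      # three "clicked" pixels
targets[0, 5, 5] = 4
targets[0, 7, 0] = 6

# Formula from the text: sum(-log p_correct * M_labeled) / sum(M_labeled)
mask = targets != IGNORE
log_p = F.log_softmax(logits, dim=1)
safe_targets = targets.masked_fill(~mask, 0)             # any valid class id; masked out below
nll = -log_p.gather(1, safe_targets.unsqueeze(1)).squeeze(1)   # (B, H, W)
pce_manual = (nll * mask).sum() / mask.sum().clamp(min=1)

# Built-in: the default mean reduction averages over non-ignored targets only
pce_builtin = F.cross_entropy(logits, targets, ignore_index=IGNORE)

pce_manual.backward()
grad_per_pixel = logits.grad.abs().sum(dim=1)            # (B, H, W) gradient magnitude

print(torch.allclose(pce_manual, pce_builtin))           # the two losses agree
print(grad_per_pixel[~mask].sum().item())                # gradient at unlabeled pixels
```

So in the common case the built-in mean reduction already gives partial CE. The explicit mask and `clamp` mainly matter for the edge case where a batch contains no labeled pixels at all: the explicit version returns 0, whereas, as far as I know, the built-in mean divides by zero and returns NaN.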
Simulating Point Labels
We start with full dense masks (from LoveDA) and simulate point labels by randomly sampling N pixels per image, setting everything else to 255:
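    import torch

    # mask: (H, W) dense class map; num_points: number of labels to keep
    partial = torch.full_like(mask, 255)                  # start fully unlabeled
    valid = (mask != 255).nonzero(as_tuple=False)         # (K, 2) row/col coords of labeled pixels
    sel = valid[torch.randperm(len(valid))[:num_points]]  # random subset of those coords
    rows, cols = sel[:, 0], sel[:, 1]
    partial[rows, cols] = mask[rows, cols]                # copy labels at the sampled points only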
With 20 points on a 256×256 image, that's 0.03% of pixels labeled. The rest is all unlabeled.
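The same sampling can be wrapped in a function that works on a whole batch, which also makes the label fractions easy to verify. This is a sketch under my own naming (`sample_point_labels` is hypothetical), and the random `dense` tensor is only a stand-in for real LoveDA masks.

```python
import torch


def sample_point_labels(masks, num_points, ignore_index=255, generator=None):
    """masks: (B, H, W) integer class maps. Keeps num_points random labels per image."""
    partial = torch.full_like(masks, ignore_index)
    for b in range(masks.shape[0]):
        valid = (masks[b] != ignore_index).nonzero(as_tuple=False)   # (K, 2) coords
        perm = torch.randperm(len(valid), generator=generator)[:num_points]
        rows, cols = valid[perm, 0], valid[perm, 1]
        partial[b, rows, cols] = masks[b, rows, cols]
    return partial


g = torch.Generator().manual_seed(0)
dense = torch.randint(0, 7, (4, 256, 256))     # stand-in for dense 256x256 masks
for n in (5, 20, 100):
    sparse = sample_point_labels(dense, n, generator=g)
    frac = (sparse != 255).float().mean().item()
    print(f"{n:>3} points -> {100 * frac:.4f}% of pixels labeled")
```

On 256×256 images this works out to 5/65,536 ≈ 0.0076%, 20/65,536 ≈ 0.0305%, and 100/65,536 ≈ 0.1526% of pixels. Passing a seeded generator keeps the sampled points reproducible across runs.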
Experiment 1: How Many Points Do You Need?
I trained three identical U-Nets from scratch, varying only the number of point labels per image: 5, 20, and 100.
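The protocol can be sketched as a loop that fixes the seed, so every run starts from the same weights, and changes only the number of labeled points. Everything below is a toy stand-in: the two-layer conv net replaces the U-Net, the data is random noise, and the optimizer, learning rate, and step count are assumptions, since the post does not state them.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

IGNORE = 255
NUM_CLASSES = 7


def make_sparse_targets(batch, size, num_points):
    # Random class ids at num_points random pixels per image; everything else is IGNORE.
    targets = torch.full((batch, size * size), IGNORE, dtype=torch.long)
    for b in range(batch):
        idx = torch.randperm(size * size)[:num_points]
        targets[b, idx] = torch.randint(0, NUM_CLASSES, (num_points,))
    return targets.view(batch, size, size)


def run(num_points, steps=20, seed=0):
    torch.manual_seed(seed)            # same seed -> same initial weights and images
    model = nn.Sequential(             # stand-in for the U-Net, NOT the post's model
        nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
        nn.Conv2d(16, NUM_CLASSES, 1),
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)   # assumed optimizer
    images = torch.randn(4, 3, 32, 32)
    targets = make_sparse_targets(4, 32, num_points)
    losses = []
    for _ in range(steps):
        optimizer.zero_grad()
        loss = F.cross_entropy(model(images), targets, ignore_index=IGNORE)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


history = {n: run(n) for n in (5, 20, 100)}
for n, losses in history.items():
    print(f"{n:>3} points: loss {losses[0]:.3f} -> {losses[-1]:.3f}")
```

The point of the structure is that the only thing differing between runs is `num_points`, so any gap in the results can be attributed to label density rather than initialization.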
Results
| Points / image | Best val mIoU | Final loss |
|---|---|---|
| 5 | 0.1592 | 1.354 |
| 20 | 0.1918 | 1.198 |
| 100 | 0.1842 | 1.195 |
For reference, random guessing on 7 classes gives ~0.05 mIoU. So even 5 labeled pixels (0.008% of the image) get us roughly 3× above random.
The sweet spot was 20 points: enough signal for the model to learn, without overfitting to the sparse labels. The 100-point model actually scored slightly lower (0.1842 vs. 0.1918), likely because 5 epochs wasn't enough for it to fully leverage the extra supervision.
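For readers who want to reproduce the metric, here is one common way to compute mIoU from a confusion matrix. It is a sketch, not necessarily the exact evaluation code behind the table: the function name `mean_iou` is mine, and averaging only over classes that appear in the prediction or the ground truth is an assumption.

```python
import torch


def mean_iou(pred, target, num_classes, ignore_index=255):
    """pred, target: integer tensors of the same shape. Returns mIoU as a float."""
    keep = target != ignore_index
    pred, target = pred[keep], target[keep]
    # Confusion matrix: rows = true class, cols = predicted class
    cm = torch.bincount(target * num_classes + pred, minlength=num_classes ** 2)
    cm = cm.view(num_classes, num_classes).float()
    inter = cm.diag()                               # true positives per class
    union = cm.sum(0) + cm.sum(1) - inter           # predicted + actual - overlap
    present = union > 0                             # skip classes absent from both
    return (inter[present] / union[present]).mean().item()


target = torch.tensor([[0, 0], [1, 1]])
pred = torch.tensor([[0, 1], [1, 1]])
print(mean_iou(pred, target, num_classes=7))
```

In this toy case class 0 has IoU 1/2 and class 1 has IoU 2/3, so the mean is 7/12 ≈ 0.583. Note that mIoU punishes both missed pixels and false alarms, which is part of why even a near-random model scores far below 1/7.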
What This Means
If you're annotating remote sensing data and can't afford dense masks, just clicking 20 points per image gets you most of the way there. The annotation time drops from hours to seconds per image.
Experiment 2: Can Focal Loss Help?
Standard CE treats all labeled pixels equally. But with only 20 labels per image, some of those pixels are "easy": they sit in the middle of a big forest patch where the answer is obvious. The model learns those quickly but wastes gradient on them.
Focal loss fixes this by down-weighting easy (high-confidence) pixels:
pFCE = sum( (1 - p_t)^γ × (-log p_t) × M_labeled ) / sum(M_labeled)
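The (1 - p_t)^γ term shrinks toward zero for confident predictions. With γ=2, a pixel predicted at 90% confidence has its loss scaled by (1 - 0.9)^2 = 0.01, i.e. 100× lower than under plain CE. A pixel at 50% is scaled by (1 - 0.5)^2 = 0.25, so the confident pixel's weight is 25× smaller than the uncertain one's.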
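The formula translates to PyTorch almost directly. The sketch below is one way to write it; the class name `PartialFocalLoss` is mine, and I only claim it matches the pFCE formula above for γ=2 (and reduces to partial CE for γ=0), not that it is the exact code used in the experiment.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PartialFocalLoss(nn.Module):
    def __init__(self, gamma=2.0, ignore_index=255):
        super().__init__()
        self.gamma = gamma
        self.ignore_index = ignore_index

    def forward(self, logits, targets):
        # Per-pixel CE equals -log p_t; ignored pixels come back as 0
        ce = F.cross_entropy(logits, targets,
                             ignore_index=self.ignore_index, reduction="none")
        p_t = torch.exp(-ce)                              # probability of the true class
        mask = (targets != self.ignore_index).float()     # M_labeled
        focal = (1.0 - p_t) ** self.gamma * ce            # down-weight confident pixels
        return (focal * mask).sum() / mask.sum().clamp(min=1.0)


torch.manual_seed(0)
logits = torch.randn(2, 7, 16, 16)
targets = torch.full((2, 16, 16), 255, dtype=torch.long)
targets[:, ::4, ::4] = torch.randint(0, 7, (2, 4, 4))    # a sparse grid of labels

focal = PartialFocalLoss(gamma=2.0)(logits, targets)
plain = PartialFocalLoss(gamma=0.0)(logits, targets)      # gamma=0 is plain partial CE
print(focal.item(), plain.item())
```

Because the modulating factor is at most 1, the focal value is never larger than the plain partial CE on the same inputs, so the raw loss numbers of the two runs are not directly comparable. Compare them on mIoU instead.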
Results
| Loss | Best val mIoU |
|---|---|
| Partial CE | 0.1509 |
| Partial Focal CE (γ=2) | 0.1898 |
A 25.8% relative improvement just by switching the loss function. Focal loss's mIoU starts slower (the model initially struggles with the reweighted gradients) but catches up and surpasses plain CE by epoch 5, and the curve is still climbing.
Key Takeaways
- Partial CE is dead simple and works. One class, ~10 lines of code, and it handles sparse labels cleanly.
- 20 points per image is a practical sweet spot. You get most of the benefit of dense annotation at a fraction of the cost.
- Focal loss is a free upgrade for sparse labels. Same training setup, better results, because it stops wasting gradient on easy pixels.
- The model generalizes beyond its labels. Despite seeing labels for only 0.03% of pixels, the model learns to segment entire regions: skip connections and spatial convolutions propagate the sparse signal across the image.