5  Pixels as Vectors: k-NN Segmentation

In Chapters 1–4 we converted each image to grayscale and applied a threshold to segment nucleus, cytoplasm, and background. That approach threw away two-thirds of the available information — the G and B channels — and replaced the classifier with a hard rule. Bayesian optimization could tune those thresholds, but the fundamental limitation remained: every pixel is judged by a single intensity value.

In this chapter we take a completely different view: every pixel is a point in 3-D colour space (R, G, B). We flatten all training images into a table of pixels, attach their ground-truth labels, and train a k-Nearest Neighbours (k-NN) classifier on them. For each test pixel the classifier finds its k nearest neighbours in the training data and takes a majority vote.

By the end of this chapter you will be able to:

  • represent every pixel as a point in 3-D (R, G, B) colour space and judge class separability visually,
  • build a stratified pixel-level training table from labelled images,
  • train a k-NN classifier and reason about the bias–variance trade-off in choosing k,
  • predict full segmentation masks and evaluate them with the Dice score,
  • recognise the failure modes of colour-only classification and argue why spatial context is needed.

5.1 Setup

Code
import glob
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap

# ── Load dataset ──────────────────────────────────────────────────────────────
# X: (N, 3, H, W)  float32  0–1   (channels first)
# y: (N, H, W)     int64    0/1/2
try:
    _base = "C:/projects/VocEd"
    N = len(glob.glob(f"{_base}/imagedata/X/*.npy"))
    X = np.stack([np.load(f"{_base}/imagedata/X/{i}.npy") for i in range(N)])
    y = np.stack([np.load(f"{_base}/imagedata/y/{i}.npy") for i in range(N)])
except Exception:
    import subprocess
    if not os.path.isdir("VocEd"):
        subprocess.run(["git", "clone", "https://github.com/emilsar/VocEd.git"], check=True)
    N = len(glob.glob("VocEd/imagedata/X/*.npy"))
    X = np.stack([np.load(f"VocEd/imagedata/X/{i}.npy") for i in range(N)])
    y = np.stack([np.load(f"VocEd/imagedata/y/{i}.npy") for i in range(N)])

# ── Visualisation helpers ─────────────────────────────────────────────────────
mask_cmap = ListedColormap(['black', 'steelblue', 'crimson'])
legend_patches = [
    mpatches.Patch(color='black',     label='0 — background'),
    mpatches.Patch(color='steelblue', label='1 — cytoplasm'),
    mpatches.Patch(color='crimson',   label='2 — nucleus'),
]

# ── Dice helper (reused throughout the chapter) ───────────────────────────────
def dice_score(pred, target, cls):
    pred_mask   = (pred   == cls)
    target_mask = (target == cls)
    intersection = (pred_mask & target_mask).sum()
    denom = pred_mask.sum() + target_mask.sum()
    return 1.0 if denom == 0 else 2 * intersection / denom

print(f"Loaded {N} images.  Image shape: {X[0].shape}  Label shape: {y[0].shape}")
Loaded 200 images.  Image shape: (3, 256, 256)  Label shape: (256, 256)
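As a quick sanity check, the Dice helper gives the expected values on a toy pair of 2×2 masks (the function is restated here so the snippet runs standalone):

```python
import numpy as np

def dice_score(pred, target, cls):
    pred_mask   = (pred   == cls)
    target_mask = (target == cls)
    intersection = (pred_mask & target_mask).sum()
    denom = pred_mask.sum() + target_mask.sum()
    return 1.0 if denom == 0 else 2 * intersection / denom

pred   = np.array([[0, 1], [1, 2]])
target = np.array([[0, 1], [2, 2]])

print(dice_score(pred, target, 0))   # 1 overlap, 1 + 1 pixels → 2·1/2 = 1.0
print(dice_score(pred, target, 1))   # 1 overlap, 2 + 1 pixels → 2·1/3 ≈ 0.667
print(dice_score(pred, target, 2))   # 1 overlap, 1 + 2 pixels → 2·1/3 ≈ 0.667
```

Note the convention in the last line of the helper: if a class is absent from both prediction and target, the score is defined as 1.0 rather than 0/0.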

5.2 Train / Test Split

We use the same reproducible 80/20 split that will be used in all subsequent chapters.

Code
from sklearn.model_selection import train_test_split

# Stratify not needed here — each image has all 3 classes
train_idx, test_idx = train_test_split(
    np.arange(N), test_size=0.2, random_state=42
)

print(f"Train: {len(train_idx)} images    Test: {len(test_idx)} images")
Train: 160 images    Test: 40 images

5.3 Visualising Colour Space

Before training anything, let’s check whether the three classes actually form separable clusters in RGB space. We use image 7 as a representative example — first viewing the raw image and its ground-truth mask, then projecting its pixels into colour space.

The key question: do background, cytoplasm, and nucleus pixels occupy different regions in colour space? If they do, a colour-based classifier has a chance.

Code
import json
from IPython.display import HTML, display

IDX = 7
img7    = X[IDX]                                      # (3, H, W)  float32  0–1
mask7   = y[IDX]                                      # (H, W)     0/1/2
pixels7 = img7.transpose(1, 2, 0).reshape(-1, 3)     # (H*W, 3)
labels7 = mask7.reshape(-1)                           # (H*W,)

rng        = np.random.default_rng(0)
cls_colors = ['#333333', 'steelblue', 'crimson']
cls_names  = ['background', 'cytoplasm', 'nucleus']

# ── Panel 1: raw image + ground-truth mask ────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(img7.transpose(1, 2, 0))
axes[0].set_title(f"Image {IDX} — RGB", fontsize=12)
axes[0].axis('off')

axes[1].imshow(mask7, cmap=mask_cmap, vmin=0, vmax=2, interpolation='nearest')
axes[1].set_title(f"Image {IDX} — Ground-truth mask", fontsize=12)
axes[1].axis('off')

fig.legend(handles=legend_patches, loc='lower center', ncol=3, fontsize=10,
           bbox_to_anchor=(0.5, 0.0))
plt.tight_layout(rect=[0, 0.07, 1, 1])
plt.show()

# ── Panel 2: 2-D colour-space projections ─────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for cls in range(3):
    samp   = pixels7[labels7 == cls]
    n      = min(3000, len(samp))
    chosen = rng.choice(len(samp), n, replace=False)
    s      = samp[chosen]
    axes[0].scatter(s[:, 0], s[:, 1], c=cls_colors[cls], s=1, alpha=0.5, label=cls_names[cls])
    axes[1].scatter(s[:, 0], s[:, 2], c=cls_colors[cls], s=1, alpha=0.5, label=cls_names[cls])

axes[0].set_xlabel('Red');   axes[0].set_ylabel('Green')
axes[0].set_title('2-D projection — Red vs Green'); axes[0].legend(markerscale=6)
axes[1].set_xlabel('Red');   axes[1].set_ylabel('Blue')
axes[1].set_title('2-D projection — Red vs Blue');  axes[1].legend(markerscale=6)
plt.tight_layout()
plt.show()

# ── Serialise for the 3-D widget below ───────────────────────────────────────
pts = []
for cls in range(3):
    samp = pixels7[labels7 == cls]
    n    = min(1500, len(samp))
    sel  = rng.choice(len(samp), n, replace=False)
    s    = samp[sel]
    pts.append({'r': s[:, 0].round(4).tolist(),
                'g': s[:, 1].round(4).tolist(),
                'b': s[:, 2].round(4).tolist(),
                'color': cls_colors[cls],
                'name':  cls_names[cls]})

display(HTML(f'<script>window._ch5ColourData={json.dumps(pts)};</script>'))
print(f"Image {IDX}{pixels7.shape[0]:,} total pixels, "
      f"{sum(len(p['r']) for p in pts):,} sampled for 3-D colour space widget.")

Image 7 — 65,536 total pixels, 4,500 sampled for 3-D colour space widget.

The 2-D projections above each collapse one colour channel — you can only see two dimensions at a time. The interactive plot below shows all three channels simultaneously. Drag to rotate and find the angle where the three classes separate most cleanly.

Note: What to look for

If the three clouds of points are well-separated in 3-D, a k-NN classifier operating purely on colour should work well. Rotate the plot to find the viewing angle that best separates the classes. If the clouds overlap significantly, the classifier will make systematic errors at those boundary regions — and no amount of tuning k will fix the underlying ambiguity.


5.4 Flattening Training Images into a Pixel Matrix

To train k-NN we need every training pixel in a single table:

  • Feature matrix X_train_px: shape (total_pixels, 3) — each row is one pixel’s (R, G, B) values
  • Label vector y_train_px: shape (total_pixels,) — each entry is 0, 1, or 2

With a 256×256 image and 160 training images, the full table would have ~10.5 million rows — training k-NN on all of them would be slow. We subsample — picking a fixed number of random pixels from each training image — to keep the problem manageable.

A naive uniform subsample mirrors each image’s natural class distribution, which is heavily skewed toward background (most of the image is empty space). To avoid training a classifier that barely ever sees nucleus pixels, we use stratified sampling: draw the same fixed quota from each class in every image.

Code
PIXELS_PER_CLASS = 150   # 150 pixels × 3 classes = 450 per training image
rng = np.random.default_rng(42)

px_list  = []
lbl_list = []

for i in train_idx:
    pixels = X[i].transpose(1, 2, 0).reshape(-1, 3)  # (H*W, 3)
    lbls   = y[i].reshape(-1)                         # (H*W,)

    for cls in [0, 1, 2]:
        cls_idx = np.where(lbls == cls)[0]              # positions of this class
        n       = min(PIXELS_PER_CLASS, len(cls_idx))   # guard: tiny regions
        chosen  = rng.choice(cls_idx, n, replace=False)
        px_list.append(pixels[chosen])
        lbl_list.append(lbls[chosen])

X_train_px = np.vstack(px_list)
y_train_px = np.hstack(lbl_list)

print(f"Training pixel matrix : {X_train_px.shape}  dtype: {X_train_px.dtype}")
print(f"Training label vector : {y_train_px.shape}")

# Class balance check — should be approximately equal across all three classes
for cls, name in enumerate(['background', 'cytoplasm', 'nucleus']):
    n = (y_train_px == cls).sum()
    print(f"  {name:>12s}: {n:6d}  ({100*n/len(y_train_px):.1f}%)")
Training pixel matrix : (70800, 3)  dtype: float32
Training label vector : (70800,)
    background:  24000  (33.9%)
     cytoplasm:  24000  (33.9%)
       nucleus:  22800  (32.2%)
Note: Why stratified subsampling matters

Each image is 256×256 = 65,536 pixels. Background typically dominates — often 70–80% of the image is empty space around the cell. A uniform random sample of 500 pixels would therefore give the classifier ~375 background examples and only ~25 nucleus examples per image. Stratified sampling overrides this imbalance by enforcing equal quotas: 150 pixels per class per image. The classifier then sees the minority class (nucleus) as often as the majority class, which is especially important for getting accurate predictions in the clinically meaningful regions.
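The arithmetic above is easy to verify with a small simulation. The class proportions below (75% background, 20% cytoplasm, 5% nucleus) are hypothetical illustrative values, not measured from the dataset:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical per-image label distribution: heavily skewed toward background
lbls = rng.choice([0, 1, 2], size=65536, p=[0.75, 0.20, 0.05])

# Uniform random sample of 500 pixels mirrors the skew (~375 / ~100 / ~25)
uni = rng.choice(lbls, 500, replace=False)
print("uniform   :", np.bincount(uni, minlength=3))

# Stratified sample: a fixed quota of 150 positions per class
strat = np.hstack([rng.choice(np.where(lbls == c)[0], 150, replace=False)
                   for c in range(3)])
print("stratified:", np.bincount(lbls[strat], minlength=3))
```

The stratified counts are exactly 150 per class by construction, while the uniform sample hands the classifier roughly fifteen background pixels for every nucleus pixel.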


5.5 Train the k-NN Classifier

KNeighborsClassifier stores the entire training set in memory. At prediction time, for each query pixel it:

  1. Computes the Euclidean distance to every stored training pixel in 3-D colour space.
  2. Identifies the n_neighbors closest stored pixels.
  3. Returns the majority class label among those neighbours.

No learning happens during fit — the training data is the model. This is called a lazy learner.

Note: What “majority class label” means

After fit(), the model is just a lookup table of (R, G, B, label) tuples — the training pixels that were randomly sampled. The phrase “returns the majority class label among those neighbours” describes what happens at prediction time for each individual query pixel, not for the training pixels themselves.

For a single query pixel \(\mathbf{q} = (R, G, B)\), the classifier returns a single integer label \(\hat{y} \in \{0, 1, 2\}\):

\[\hat{y}(\mathbf{q}) = \underset{c \;\in\; \{0,1,2\}}{\arg\max} \sum_{i \;\in\; \mathcal{N}_k(\mathbf{q})} \mathbf{1}[y_i = c]\]

where \(\mathcal{N}_k(\mathbf{q})\) is the set of \(k\) training pixels nearest to \(\mathbf{q}\) in Euclidean distance and \(y_i\) is the ground-truth label of training pixel \(i\).

When knn.predict() is called on all \(H \times W\) pixels of an image at once, it returns a flat array of \(H \times W\) integers — one label per pixel. Reshaping that array to \((H, W)\) gives the segmentation mask: a complete label map covering every pixel in the image, including pixels whose exact RGB colour was never seen during training.
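The argmax vote in the equation above can be written out in a few lines of NumPy. This is a toy single-query implementation for intuition only, not what scikit-learn does internally (it uses spatial index structures rather than a full distance scan):

```python
import numpy as np

def knn_vote(q, X_train, y_train, k=5, n_classes=3):
    # Euclidean distance from query q to every stored training pixel
    d = np.linalg.norm(X_train - q, axis=1)
    neighbours = np.argsort(d)[:k]                       # indices of the k nearest
    votes = np.bincount(y_train[neighbours], minlength=n_classes)
    return int(np.argmax(votes))                         # majority class (ties → lowest index)

# Four stored pixels: two dark (nucleus, 2) and two bright (background, 0)
X_tr = np.array([[0.1, 0.1, 0.1],
                 [0.1, 0.2, 0.1],
                 [0.8, 0.8, 0.9],
                 [0.9, 0.8, 0.8]])
y_tr = np.array([2, 2, 0, 0])

print(knn_vote(np.array([0.15, 0.15, 0.1]), X_tr, y_tr, k=3))   # → 2
```

With k = 3, the dark query's neighbourhood contains both nucleus pixels and one background pixel, so the vote goes 2:1 to nucleus.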

Code
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)   # n_jobs=-1 → all CPU cores

print("Fitting k-NN classifier...")
knn.fit(X_train_px, y_train_px)
print("Done.  The classifier has memorised all training pixels.")
print(f"Number of stored training points: {len(X_train_px):,}")
Fitting k-NN classifier...
Done.  The classifier has memorised all training pixels.
Number of stored training points: 70,800

5.6 Choosing k

The number of neighbours \(k\) (set by n_neighbors) is the only hyperparameter in k-NN. It controls a fundamental trade-off between variance and bias:

  • Small (\(k = 1\)): every query is decided by its single nearest training neighbour. Failure mode: high variance — a single mislabelled training pixel can corrupt all nearby predictions.
  • Moderate (\(k = 5\)–\(15\)): majority vote over a small neighbourhood smooths out isolated noise. A good balance for most problems.
  • Large (\(k > 50\)): the vote is taken from a very wide region of colour space. Failure mode: high bias — fine boundaries are blurred; large blobs of uniform colour replace crisp edges.

This is the bias–variance trade-off: small \(k\) gives a wiggly, high-variance decision boundary that fits every quirk in the training data; large \(k\) gives a smooth, high-bias boundary that may miss genuine structure.
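The trade-off is easy to see on synthetic data. The 2-D blobs below are made up for illustration (they are not the cell images); the point is that accuracy first rises and then collapses as \(k\) grows, because a very large \(k\) simply votes for the majority class:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
# 300 majority-class points vs 100 minority-class points, overlapping blobs
X_demo = np.vstack([rng.normal(0.0, 1.0, (300, 2)),
                    rng.normal(2.0, 1.0, (100, 2))])
y_demo = np.array([0] * 300 + [1] * 100)

accs = {}
for k in (1, 15, 300):
    clf = KNeighborsClassifier(n_neighbors=k)
    accs[k] = cross_val_score(clf, X_demo, y_demo, cv=5).mean()
    print(f"k={k:3d}  CV accuracy = {accs[k]:.3f}")
```

With \(k = 300\) out of 320 training points per fold, the minority class can never win a vote, so the classifier degenerates to predicting class 0 everywhere — exactly the "large blobs of uniform colour" failure mode, in miniature.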

5.6.1 Choosing k via cross-validation

The principled approach is k-fold cross-validation on the training set. For each candidate \(k\), split the training pixels into \(F\) folds, train on \(F-1\) folds, and evaluate on the held-out fold. Average the Dice (or accuracy) score across folds, then pick the \(k\) that maximises it.

In \(F\)-fold cross-validation the data is split into \(F\) equal folds. Each iteration holds out one fold as the validation set and trains on the remaining \(F-1\) folds. Every data point is validated exactly once, giving a robust estimate of generalisation performance. (We write \(F\) for the number of folds to avoid a clash with the \(k\) of k-NN.)

For each candidate \(k\):

from sklearn.model_selection import cross_val_score

for k in [1, 3, 5, 7, 11, 15, 21]:
    clf    = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)
    scores = cross_val_score(clf, X_train_px, y_train_px, cv=5, scoring='accuracy')
    print(f"k={k:3d}  accuracy = {scores.mean():.4f} ± {scores.std():.4f}")

A plot of validation accuracy vs \(k\) typically shows an elbow: rapid improvement from \(k = 1\) up to some moderate value, followed by a plateau or gentle decline. The elbow is the natural choice.

Rule of thumb: \(k \approx \sqrt{N_{\text{train}}}\) (where \(N_{\text{train}}\) is the number of training points) is a common starting point. With 70,800 training pixels that gives \(k \approx 266\) — too large for boundary-heavy segmentation tasks. In practice, treat it as an upper bound and search from small values upward.

In this chapter we use n_neighbors = 5. Exercise 5.1 asks you to verify how the Dice score changes as \(k\) varies.


5.7 Predicting Masks for Test Images

To predict a segmentation mask for a new image we:

  1. Flatten the image from (3, H, W) → (H*W, 3): transpose to channels-last, then reshape.
  2. Pass all pixels at once to knn.predict().
  3. Reshape the flat predictions back to (H, W).
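These three steps can be wrapped in a small helper. The evaluation code in Section 5.8 calls a function named predict_mask_knn; the definition below is a minimal sketch consistent with the steps above and with the channels-first layout used throughout the chapter:

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

def predict_mask_knn(img, knn):
    """Predict a (H, W) label mask for one channels-first (3, H, W) image."""
    _, H, W = img.shape
    pixels = img.transpose(1, 2, 0).reshape(-1, 3)   # (3, H, W) → (H*W, 3)
    flat = knn.predict(pixels)                        # one class label per pixel
    return flat.reshape(H, W)

# Smoke test on synthetic data (stands in for the real X / knn of this chapter)
rng = np.random.default_rng(0)
knn_demo = KNeighborsClassifier(n_neighbors=1).fit(rng.random((30, 3)),
                                                   rng.integers(0, 3, 30))
mask = predict_mask_knn(rng.random((3, 8, 8)).astype(np.float32), knn_demo)
print(mask.shape)   # (8, 8)
```

Passing all H*W pixels to knn.predict() in one call is important: a Python loop over pixels would pay scikit-learn's dispatch overhead 65,536 times per image.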

5.7.1 How knn.predict() works

knn.predict() treats every pixel in the input image as a point in 3-D colour space and answers one question: given this (R, G, B) colour, what class do the nearest training pixels say it belongs to?

Step by step for one query pixel \(\mathbf{q} = (R, G, B)\):

  1. Compute the Euclidean distance from \(\mathbf{q}\) to every stored training pixel \(\mathbf{x}_i\):

\[d(\mathbf{q},\, \mathbf{x}_i) = \sqrt{(R - x_{iR})^2 + (G - x_{iG})^2 + (B - x_{iB})^2}\]

  2. Rank all \(N_{\text{train}}\) training pixels by distance. Select the \(k\) smallest — the \(k\) nearest neighbours \(\mathcal{N}_k(\mathbf{q})\).

  3. Tally the class labels of those \(k\) neighbours and return the class that received the most votes:

\[\hat{y}(\mathbf{q}) = \underset{c \;\in\; \{0,1,2\}}{\arg\max} \sum_{i \;\in\; \mathcal{N}_k(\mathbf{q})} \mathbf{1}[y_i = c]\]

For an image with \(H \times W\) pixels, this procedure runs across all pixels in parallel (scikit-learn uses all CPU cores when n_jobs=-1), producing a flat array of \(H \times W\) class labels.

Non-sampled pixels (pixels from training images that were not selected during subsampling): the model has no record of them — it only stores the (R, G, B, label) pairs that were randomly drawn. When those pixels later appear as query points (e.g. if you run knn.predict() on a full training image to visualise the result), they are treated like any other input: the model finds their \(k\) nearest neighbours in the stored set and votes. A skipped pixel will be classified correctly as long as its colour falls near a well-represented region of colour space; it may be misclassified if it lies in a sparse or ambiguous region.

Test image pixels (pixels from images the model has never encountered): k-NN makes no distinction between a query from a training image and one from a test image. Every pixel is just a vector \((R, G, B)\). The model finds its nearest neighbours in the stored training set and votes. This is how k-NN generalises: it is not memorising spatial layouts or specific images — it is memorising a colour-to-class mapping. Any pixel whose colour falls in a region of colour space densely populated by, say, nucleus training pixels will be predicted as nucleus, regardless of which image it came from or where in that image it sits. The model’s generalisation ability therefore depends entirely on whether the training colour distribution is representative of the test colour distribution.

5.7.2 Try it: 2-D k-NN in action

The widget below lets you watch the k-NN vote happen for a single query pixel. We drop the blue channel and work in 2-D — so the decision is easy to see — using ~50 training points per class sampled from image 7. Drag the yellow query pixel across the plane (or use the R/G sliders), and change k with its slider. Training points that end up among the \(k\) nearest are highlighted with a ring, and thin lines trace the distances from the query to each neighbour. The shaded background shows the decision region for the current \(k\): the class the classifier would assign at every point on the plane.

Code
# ── Serialise 2-D (R, G) training data for the widget below ─────────────────
_widget_rng = np.random.default_rng(0)
widget_pts  = []
for cls in range(3):
    samp = pixels7[labels7 == cls]
    n    = min(50, len(samp))
    sel  = _widget_rng.choice(len(samp), n, replace=False)
    s    = samp[sel, :2]                        # keep R, G only
    widget_pts.append({
        'r':     s[:, 0].round(4).tolist(),
        'g':     s[:, 1].round(4).tolist(),
        'class': cls,
        'color': cls_colors[cls],
        'name':  cls_names[cls],
    })

display(HTML(f'<script>window._ch5KnnWidget={json.dumps(widget_pts)};</script>'))
print(f"Widget training data: {sum(len(p['r']) for p in widget_pts)} pixels from image {IDX} (R, G only).")
Widget training data: 150 pixels from image 7 (R, G only).

5.8 Test-Set Dice Scores and Comparison to Thresholding

We now run k-NN on all test images and compute the mean Dice score across cytoplasm and nucleus — the two clinically meaningful classes. We compare against the hand-picked grayscale threshold from Chapter 4.

Code
# ── k-NN Dice on full test set ────────────────────────────────────────────────
print("Running k-NN inference on all test images (may take a minute)...")
knn_scores = []
for i in test_idx:
    pred = predict_mask_knn(X[i], knn)
    d = (dice_score(pred, y[i], 1) + dice_score(pred, y[i], 2)) / 2
    knn_scores.append(d)

mean_knn = np.mean(knn_scores)
print(f"k-NN mean Dice (cyto+nuc): {mean_knn:.4f}  ±  {np.std(knn_scores):.4f}")

# ── Reference: simple grayscale threshold (Lab 01 / Chapter 4 style) ─────────
def segment_threshold(img, t_nucleus=0.3, t_cytoplasm_max=0.7):
    gray = img.mean(axis=0)   # (3, H, W) → (H, W), values already in [0, 1]
    pred = np.zeros(gray.shape, dtype=np.int64)
    pred[gray < t_nucleus]                                          = 2
    pred[(gray >= t_nucleus) & (gray <= t_cytoplasm_max)]          = 1
    return pred

thresh_scores = []
for i in test_idx:
    pred = segment_threshold(X[i])
    d = (dice_score(pred, y[i], 1) + dice_score(pred, y[i], 2)) / 2
    thresh_scores.append(d)

mean_thresh = np.mean(thresh_scores)

# ── Summary table ─────────────────────────────────────────────────────────────
print("\nCumulative Dice comparison (test set, mean of cytoplasm + nucleus):")
print("=" * 55)
print(f"{'Chapter 4 — grayscale threshold':<40}  {mean_thresh:.4f}")
print(f"{'Chapter 5 — k-NN (RGB colour only)':<40}  {mean_knn:.4f}")
print("=" * 55)
Running k-NN inference on all test images (may take a minute)...
k-NN mean Dice (cyto+nuc): 0.7602  ±  0.1061

Cumulative Dice comparison (test set, mean of cytoplasm + nucleus):
=======================================================
Chapter 4 — grayscale threshold           0.5224
Chapter 5 — k-NN (RGB colour only)        0.7602
=======================================================

5.9 Where k-NN Fails: Spatial Blindness

k-NN makes decisions based only on pixel colour. It has no idea whether a pixel is near the centre of the cell, at a boundary, or in the background. This leads to two characteristic failure modes:

  • Speckle noise: isolated pixels are assigned a different class from all their neighbours because their colour happens to be closer to another class in training data.
  • Boundary errors: pixels with ambiguous colours (where two regions meet) are decided purely by which training pixels are closest in colour space — with no smoothing from spatial context.

The error maps below make this concrete.

Code
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
samples = test_idx[:2]

for row, idx in enumerate(samples):
    pred  = predict_mask_knn(X[idx], knn)
    error = (pred != y[idx])   # True where prediction is wrong

    axes[row, 0].imshow(X[idx].transpose(1, 2, 0))
    axes[row, 0].set_title(f"Image {idx} — RGB");  axes[row, 0].axis('off')

    axes[row, 1].imshow(pred, cmap=mask_cmap, vmin=0, vmax=2, interpolation='nearest')
    axes[row, 1].set_title("k-NN prediction");  axes[row, 1].axis('off')

    axes[row, 2].imshow(error, cmap='Reds', interpolation='nearest')
    axes[row, 2].set_title(f"Errors (red) — {error.mean()*100:.1f}% wrong")
    axes[row, 2].axis('off')

plt.suptitle("Error Maps: Where k-NN Gets It Wrong", fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("Are errors clustered at boundaries, or scattered throughout?")
print("Scattered errors → spatial context would help. Boundary errors → need better features.")

Are errors clustered at boundaries, or scattered throughout?
Scattered errors → spatial context would help. Boundary errors → need better features.
Note: Spatial context is the missing ingredient

k-NN with colour features treats every pixel as independent. A nucleus pixel at the very edge of the nucleus has the same colour as one in the centre — but they sit in very different neighbourhoods. In the next two chapters we will introduce models that do look at neighbourhoods: convolutional neural networks read a patch of pixels at a time, giving them the spatial awareness k-NN lacks.


5.10 Summary

Aspect                     Grayscale threshold (Ch. 4)    k-NN colour (Ch. 5)
Features used              1 (luminance)                  3 (R, G, B)
Training data needed       None                           Labelled pixels
Spatial context            None                           None
Speed                      Very fast                      Slow (at prediction time)
Test-set Dice (cyto+nuc)   ~0.52                          ~0.76

Key takeaways:

  • Using all three colour channels instead of greyscale gives the classifier richer information, which typically improves segmentation.
  • k-NN is easy to implement and requires no hyperparameter tuning beyond k, but it is slow at prediction time and ignores spatial context entirely.
  • Errors cluster at class boundaries — where two classes share similar colours — confirming that colour alone is not always sufficient.
  • The next step toward better segmentation is to give the model spatial awareness. Convolutional networks (Chapters 6–7) do exactly that.

5.11 Exercises

Exercise 5.1 — Varying k

The number of neighbours k is a hyperparameter. Too small (k=1) → noisy; too large → over-smoothed boundaries.

Test k ∈ {1, 10, 25}. For each value, retrain the classifier and compute the mean Dice on the first 5 test images. Plot the error map for one image.

  • Which value of k gives the best Dice?
  • Does the error map look more or less speckly as k increases? Why?

Exercise 5.2 — Feature normalisation

The images in this dataset are already scaled to [0, 1] (see Section 5.1), but raw image files typically store each channel as an integer in 0–255. Try training and predicting on unscaled 0–255 values. Does Dice change? Why might feature scaling matter — or not — for a distance-based classifier when every feature shares the same scale?

Exercise 5.3 — Adding spatial features

Modify the pixel feature vector to include the pixel’s (row, col) position alongside its RGB values, giving a 5-D feature vector. Retrain and compare Dice. Do the segmentations look more spatially coherent? What is the downside of including position?

Exercise 5.4 — Class balance check

Inspect the class balance in y_train_px. If one class dominates, the classifier may be biased toward it. Many scikit-learn classifiers accept class_weight='balanced' — does KNeighborsClassifier? Check the scikit-learn docs, then suggest an alternative strategy for handling class imbalance with k-NN.

Sign in to save progress