In Chapters 1–4 we converted each image to grayscale and applied a threshold to segment nucleus, cytoplasm, and background. That approach collapsed the three colour channels into a single intensity value — discarding two of the three available dimensions of information — and replaced the classifier with a hard rule. Bayesian optimization could tune those thresholds, but the fundamental limitation remained: every pixel is judged by a single intensity value.
In this chapter we take a completely different view: every pixel is a point in 3-D colour space (R, G, B). We flatten all training images into a table of pixels, attach their ground-truth labels, and train a k-Nearest Neighbours (k-NN) classifier on them. For each test pixel the classifier finds its k nearest neighbours in the training data and takes a majority vote.
By the end of this chapter you will be able to:
Explain what it means for a pixel to be a point in 3-D colour space.
Flatten a batch of images into a pixel feature matrix.
Train a KNeighborsClassifier on pixel colours and predict segmentation masks.
Compare k-NN performance to thresholding using Dice scores.
Articulate why k-NN ignores spatial context and why that matters for cell segmentation.
5.1 Setup
Code
import glob
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.colors import ListedColormap

# ── Load dataset ──────────────────────────────────────────────────────────────
# X: (N, 3, H, W) float32 0–1 (channels first)
# y: (N, H, W)    int64   0/1/2
try:
    _base = "C:/projects/VocEd"
    N = len(glob.glob(f"{_base}/imagedata/X/*.npy"))
    X = np.stack([np.load(f"{_base}/imagedata/X/{i}.npy") for i in range(N)])
    y = np.stack([np.load(f"{_base}/imagedata/y/{i}.npy") for i in range(N)])
except Exception:
    import subprocess
    if not os.path.isdir("VocEd"):
        subprocess.run(["git", "clone", "https://github.com/emilsar/VocEd.git"], check=True)
    N = len(glob.glob("VocEd/imagedata/X/*.npy"))
    X = np.stack([np.load(f"VocEd/imagedata/X/{i}.npy") for i in range(N)])
    y = np.stack([np.load(f"VocEd/imagedata/y/{i}.npy") for i in range(N)])

# ── Visualisation helpers ─────────────────────────────────────────────────────
mask_cmap = ListedColormap(['black', 'steelblue', 'crimson'])
legend_patches = [
    mpatches.Patch(color='black',     label='0 — background'),
    mpatches.Patch(color='steelblue', label='1 — cytoplasm'),
    mpatches.Patch(color='crimson',   label='2 — nucleus'),
]

# ── Dice helper (reused throughout the chapter) ───────────────────────────────
def dice_score(pred, target, cls):
    pred_mask = (pred == cls)
    target_mask = (target == cls)
    intersection = (pred_mask & target_mask).sum()
    denom = pred_mask.sum() + target_mask.sum()
    return 1.0 if denom == 0 else 2 * intersection / denom

print(f"Loaded {N} images. Image shape: {X[0].shape}  Label shape: {y[0].shape}")
5.2 Train/Test Split
We use the same reproducible 80/20 split that will be used in all subsequent chapters.
Code
from sklearn.model_selection import train_test_split

# Stratify not needed here — each image has all 3 classes
train_idx, test_idx = train_test_split(
    np.arange(N), test_size=0.2, random_state=42
)
print(f"Train: {len(train_idx)} images  Test: {len(test_idx)} images")
Train: 160 images Test: 40 images
5.3 Visualising Colour Space
Before training anything, let’s check whether the three classes actually form separable clusters in RGB space. We use image 7 as a representative example — first viewing the raw image and its ground-truth mask, then projecting its pixels into colour space.
The key question: do background, cytoplasm, and nucleus pixels occupy different regions in colour space? If they do, a colour-based classifier has a chance.
Code
import json
from IPython.display import HTML, display

IDX = 7
img7 = X[IDX]    # (3, H, W) float32 0–1
mask7 = y[IDX]   # (H, W) 0/1/2
pixels7 = img7.transpose(1, 2, 0).reshape(-1, 3)  # (H*W, 3)
labels7 = mask7.reshape(-1)                        # (H*W,)

rng = np.random.default_rng(0)
cls_colors = ['#333333', 'steelblue', 'crimson']
cls_names = ['background', 'cytoplasm', 'nucleus']

# ── Panel 1: raw image + ground-truth mask ────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
axes[0].imshow(img7.transpose(1, 2, 0))
axes[0].set_title(f"Image {IDX} — RGB", fontsize=12)
axes[0].axis('off')
axes[1].imshow(mask7, cmap=mask_cmap, vmin=0, vmax=2, interpolation='nearest')
axes[1].set_title(f"Image {IDX} — Ground-truth mask", fontsize=12)
axes[1].axis('off')
fig.legend(handles=legend_patches, loc='lower center', ncol=3, fontsize=10,
           bbox_to_anchor=(0.5, 0.0))
plt.tight_layout(rect=[0, 0.07, 1, 1])
plt.show()

# ── Panel 2: 2-D colour-space projections ─────────────────────────────────────
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for cls in range(3):
    samp = pixels7[labels7 == cls]
    n = min(3000, len(samp))
    chosen = rng.choice(len(samp), n, replace=False)
    s = samp[chosen]
    axes[0].scatter(s[:, 0], s[:, 1], c=cls_colors[cls], s=1, alpha=0.5,
                    label=cls_names[cls])
    axes[1].scatter(s[:, 0], s[:, 2], c=cls_colors[cls], s=1, alpha=0.5,
                    label=cls_names[cls])
axes[0].set_xlabel('Red'); axes[0].set_ylabel('Green')
axes[0].set_title('2-D projection — Red vs Green'); axes[0].legend(markerscale=6)
axes[1].set_xlabel('Red'); axes[1].set_ylabel('Blue')
axes[1].set_title('2-D projection — Red vs Blue'); axes[1].legend(markerscale=6)
plt.tight_layout()
plt.show()

# ── Serialise for the 3-D widget below ────────────────────────────────────────
pts = []
for cls in range(3):
    samp = pixels7[labels7 == cls]
    n = min(1500, len(samp))
    sel = rng.choice(len(samp), n, replace=False)
    s = samp[sel]
    pts.append({
        'r': s[:, 0].round(4).tolist(),
        'g': s[:, 1].round(4).tolist(),
        'b': s[:, 2].round(4).tolist(),
        'color': cls_colors[cls],
        'name': cls_names[cls],
    })
display(HTML(f'<script>window._ch5ColourData={json.dumps(pts)};</script>'))
print(f"Image {IDX} — {pixels7.shape[0]:,} total pixels, "
      f"{sum(len(p['r']) for p in pts):,} sampled for 3-D colour space widget.")
Image 7 — 65,536 total pixels, 4,500 sampled for 3-D colour space widget.
The 2-D projections above each collapse one colour channel — you can only see two dimensions at a time. The interactive plot below shows all three channels simultaneously. Drag to rotate and find the angle where the three classes separate most cleanly.
Note: What to look for
If the three clouds of points are well-separated in 3-D, a k-NN classifier operating purely on colour should work well. Rotate the plot to find the viewing angle that best separates the classes. If the clouds overlap significantly, the classifier will make systematic errors at those boundary regions — and no amount of tuning k will fix the underlying ambiguity.
5.4 Flattening Training Images into a Pixel Matrix
To train k-NN we need every training pixel in a single table:
Feature matrixX_train_px: shape (total_pixels, 3) — each row is one pixel’s (R, G, B) values
Label vectory_train_px: shape (total_pixels,) — each entry is 0, 1, or 2
With 256×256 images and 160 training images, the full table would have ~10.5 million rows — training k-NN on all of them would be slow. We subsample — picking a fixed number of random pixels from each training image — to keep the problem manageable.
A naive uniform subsample mirrors each image’s natural class distribution, which is heavily skewed toward background (most of the image is empty space). To avoid training a classifier that barely ever sees nucleus pixels, we use stratified sampling: draw the same fixed quota from each class in every image.
Code
PIXELS_PER_CLASS = 150  # 150 pixels × 3 classes = 450 per training image
rng = np.random.default_rng(42)

px_list = []
lbl_list = []
for i in train_idx:
    pixels = X[i].transpose(1, 2, 0).reshape(-1, 3)  # (H*W, 3)
    lbls = y[i].reshape(-1)                           # (H*W,)
    for cls in [0, 1, 2]:
        cls_idx = np.where(lbls == cls)[0]       # positions of this class
        n = min(PIXELS_PER_CLASS, len(cls_idx))  # guard: tiny regions
        chosen = rng.choice(cls_idx, n, replace=False)
        px_list.append(pixels[chosen])
        lbl_list.append(lbls[chosen])

X_train_px = np.vstack(px_list)
y_train_px = np.hstack(lbl_list)

print(f"Training pixel matrix : {X_train_px.shape}  dtype: {X_train_px.dtype}")
print(f"Training label vector : {y_train_px.shape}")

# Class balance check — should be approximately equal across all three classes
for cls, name in enumerate(['background', 'cytoplasm', 'nucleus']):
    n = (y_train_px == cls).sum()
    print(f"  {name:>12s}: {n:6d}  ({100*n/len(y_train_px):.1f}%)")
Each image is 256×256 = 65,536 pixels. Background typically dominates — often 70–80% of the image is empty space around the cell. A uniform random sample of 500 pixels would therefore give the classifier ~375 background examples and only ~25 nucleus examples per image. Stratified sampling overrides this imbalance by enforcing equal quotas: 150 pixels per class per image. The classifier then sees the minority class (nucleus) as often as the majority class, which is especially important for getting accurate predictions in the clinically meaningful regions.
5.5 Train the k-NN Classifier
KNeighborsClassifier stores the entire training set in memory. At prediction time, for each query pixel it:
Computes the Euclidean distance to every stored training pixel in 3-D colour space.
Identifies the n_neighbors closest stored pixels.
Returns the majority class label among those neighbours.
No learning happens during fit — the training data is the model. This is called a lazy learner.
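Those three steps are simple enough to sketch from scratch. Here is a minimal NumPy version of the vote for a single query pixel, using a toy stored set rather than the chapter's sampled pixels (all colours below are illustrative values):

```python
import numpy as np

# toy stored training set: (R, G, B) colours with labels 0/1/2
train_px = np.array([[0.90, 0.90, 0.90],   # background-ish
                     [0.85, 0.90, 0.95],
                     [0.50, 0.60, 0.80],   # cytoplasm-ish
                     [0.45, 0.55, 0.75],
                     [0.20, 0.10, 0.50]])  # nucleus-ish
train_y = np.array([0, 0, 1, 1, 2])

def knn_vote(q, k=3):
    # 1. Euclidean distance from the query to every stored pixel
    dists = np.linalg.norm(train_px - q, axis=1)
    # 2. indices of the k nearest stored pixels
    nearest = np.argsort(dists)[:k]
    # 3. majority class label among those k neighbours
    return np.bincount(train_y[nearest]).argmax()

print(knn_vote(np.array([0.48, 0.58, 0.77])))  # → 1 (cytoplasm)
```

`KNeighborsClassifier` does the same thing, only with fast spatial index structures (KD-trees or ball trees) instead of a brute-force sort.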
Note: What “majority class label” means
After fit(), the model is just a lookup table of (R, G, B, label) tuples — the training pixels that were randomly sampled. The phrase “returns the majority class label among those neighbours” describes what happens at prediction time for each individual query pixel, not for the training pixels themselves.
For a single query pixel \(\mathbf{q} = (R, G, B)\), the classifier returns a single integer label \(\hat{y} \in \{0, 1, 2\}\):

\[
\hat{y}(\mathbf{q}) = \underset{c \in \{0, 1, 2\}}{\arg\max} \sum_{i \in \mathcal{N}_k(\mathbf{q})} \mathbf{1}[y_i = c]
\]
where \(\mathcal{N}_k(\mathbf{q})\) is the set of \(k\) training pixels nearest to \(\mathbf{q}\) in Euclidean distance and \(y_i\) is the ground-truth label of training pixel \(i\).
When knn.predict() is called on all \(H \times W\) pixels of an image at once, it returns a flat array of \(H \times W\) integers — one label per pixel. Reshaping that array to \((H, W)\) gives the segmentation mask: a complete label map covering every pixel in the image, including pixels whose exact RGB colour was never seen during training.
Code
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(n_neighbors=5, n_jobs=-1)  # n_jobs=-1 → all CPU cores
print("Fitting k-NN classifier...")
knn.fit(X_train_px, y_train_px)
print("Done. The classifier has memorised all training pixels.")
print(f"Number of stored training points: {len(X_train_px):,}")
Fitting k-NN classifier...
Done. The classifier has memorised all training pixels.
Number of stored training points: 70,800
5.6 Choosing k
The number of neighbours \(k\) (set by n_neighbors) is the most important hyperparameter in k-NN (the distance metric and vote weighting can also be changed, but we leave them at their defaults). It controls a fundamental trade-off between variance and bias:
| \(k\) | Behaviour | Failure mode |
|---|---|---|
| Small (\(k = 1\)) | Every query is decided by its single nearest training neighbour | High variance — a single mislabelled training pixel can corrupt all nearby predictions |
| Moderate (\(k = 5\)–\(15\)) | Majority vote over a small neighbourhood smooths out isolated noise | Good balance for most problems |
| Large (\(k > 50\)) | Vote taken from a very wide region of colour space | High bias — fine boundaries are blurred; large blobs of uniform colour replace crisp edges |
This is the bias–variance trade-off: small \(k\) gives a wiggly, high-variance decision boundary that fits every quirk in the training data; large \(k\) gives a smooth, high-bias boundary that may miss genuine structure.
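The variance end of this trade-off is easy to see on synthetic data (a toy sketch, separate from the chapter's pixels): with \(k = 1\) the classifier reproduces every training label perfectly, noise included, while a large \(k\) votes the noise away and therefore no longer fits the training set exactly.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

rng = np.random.default_rng(0)
X_toy = rng.random((200, 3))             # 200 toy "pixels" in colour space
y_toy = (X_toy[:, 0] > 0.5).astype(int)  # label depends on the R channel only
noise = rng.choice(200, 10, replace=False)
y_toy[noise] ^= 1                        # flip 10 labels: simulated label noise

accs = {}
for k in (1, 51):
    clf = KNeighborsClassifier(n_neighbors=k).fit(X_toy, y_toy)
    accs[k] = clf.score(X_toy, y_toy)    # accuracy on the *training* set

print(accs)  # k=1 memorises everything (1.0); k=51 smooths over the flipped labels (<1.0)
```

High training accuracy with \(k = 1\) is not a virtue here: the ten memorised noise points each corrupt the decision boundary around them.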
5.6.1 Choosing k via cross-validation
The principled approach is k-fold cross-validation on the training set. For each candidate \(k\), split the training pixels into \(F\) folds, train on \(F-1\) folds, and evaluate on the held-out fold. Average the Dice (or accuracy) score across folds, then pick the \(k\) that maximises it.
In \(F\)-fold cross-validation the data is split into \(F\) equal folds. Each iteration holds out one fold as the validation set and trains on the remaining \(F - 1\) folds. Every data point is validated exactly once, giving a robust estimate of generalisation performance.
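That "validated exactly once" property can be checked directly with scikit-learn's KFold (a toy verification, independent of the chapter's pixel data):

```python
import numpy as np
from sklearn.model_selection import KFold

data = np.arange(20)
kf = KFold(n_splits=5, shuffle=True, random_state=0)

val_indices = []
for train_ids, val_ids in kf.split(data):
    val_indices.extend(val_ids.tolist())  # collect each fold's validation points

# across all 5 folds, every point appears in a validation set exactly once
print(sorted(val_indices) == list(range(20)))  # → True
```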
For each candidate \(k\):
from sklearn.model_selection import cross_val_score

for k in [1, 3, 5, 7, 11, 15, 21]:
    clf = KNeighborsClassifier(n_neighbors=k, n_jobs=-1)
    scores = cross_val_score(clf, X_train_px, y_train_px, cv=5, scoring='accuracy')
    print(f"k={k:3d}  accuracy = {scores.mean():.4f} ± {scores.std():.4f}")
A plot of validation accuracy vs \(k\) typically shows an elbow: rapid improvement from \(k = 1\) up to some moderate value, followed by a plateau or gentle decline. The elbow is the natural choice.
Rule of thumb: \(k \approx \sqrt{N_{\text{train}}}\) (where \(N_{\text{train}}\) is the number of training points) is a common starting point. With the ~70,800 training pixels sampled above that gives \(k \approx 266\) — too large for boundary-heavy segmentation tasks. In practice, treat it as an upper bound and search from small values upward.
In this chapter we use n_neighbors = 5. Exercise 5.1 asks you to verify how the Dice score changes as \(k\) varies.
5.7 Predicting Masks for Test Images
To predict a segmentation mask for a new image we:
Transpose the image from channels-first (3, H, W) to channels-last (H, W, 3), then flatten it to (H*W, 3).
Pass all pixels at once to knn.predict().
Reshape the flat predictions back to (H, W).
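The inference cells later in this chapter call a helper named predict_mask_knn that wraps these three steps. Its definition does not appear in the cells above, so here is a minimal version consistent with the chapter's data layout (channels-first images, as loaded in Section 5.1):

```python
import numpy as np

def predict_mask_knn(img, clf):
    """Predict an (H, W) label mask for one (3, H, W) image with a fitted classifier."""
    _, H, W = img.shape
    pixels = img.transpose(1, 2, 0).reshape(-1, 3)  # (H*W, 3) pixel table
    flat_pred = clf.predict(pixels)                 # one label per pixel
    return flat_pred.reshape(H, W)                  # back to image layout
```

Usage is then simply pred = predict_mask_knn(X[test_idx[0]], knn), and pred is an (H, W) integer mask ready to pass to dice_score.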
5.7.1 How knn.predict() works
knn.predict() treats every pixel in the input image as a point in 3-D colour space and answers one question: given this (R, G, B) colour, what class do the nearest training pixels say it belongs to?
Step by step for one query pixel \(\mathbf{q} = (R, G, B)\):
Compute the Euclidean distance from \(\mathbf{q}\) to every stored training pixel \(\mathbf{x}_i\):

\[
d(\mathbf{q}, \mathbf{x}_i) = \lVert \mathbf{q} - \mathbf{x}_i \rVert_2 = \sqrt{(R_q - R_i)^2 + (G_q - G_i)^2 + (B_q - B_i)^2}
\]

Keep the \(k\) stored pixels with the smallest distances, then return the majority class label among those \(k\) neighbours.
For an image with \(H \times W\) pixels, this procedure runs across all pixels in parallel (scikit-learn uses all CPU cores when n_jobs=-1), producing a flat array of \(H \times W\) class labels.
Non-sampled pixels (pixels from training images that were not selected during subsampling): the model has no record of them — it only stores the (R, G, B, label) pairs that were randomly drawn. When those pixels later appear as query points (e.g. if you run knn.predict() on a full training image to visualise the result), they are treated like any other input: the model finds their \(k\) nearest neighbours in the stored set and votes. A skipped pixel will be classified correctly as long as its colour falls near a well-represented region of colour space; it may be misclassified if it lies in a sparse or ambiguous region.
Test image pixels (pixels from images the model has never encountered): k-NN makes no distinction between a query from a training image and one from a test image. Every pixel is just a vector \((R, G, B)\). The model finds its nearest neighbours in the stored training set and votes. This is how k-NN generalises: it is not memorising spatial layouts or specific images — it is memorising a colour-to-class mapping. Any pixel whose colour falls in a region of colour space densely populated by, say, nucleus training pixels will be predicted as nucleus, regardless of which image it came from or where in that image it sits. The model’s generalisation ability therefore depends entirely on whether the training colour distribution is representative of the test colour distribution.
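A toy illustration of that colour-to-class mapping (hand-picked illustrative colours, not the chapter's sampled pixels): a 1-NN fitted on three labelled colours assigns the same label to a given colour wherever, and in whichever image, it appears.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# three hypothetical training colours, one per class
X_toy = np.array([[0.90, 0.90, 0.90],   # 0 — background (bright)
                  [0.50, 0.60, 0.70],   # 1 — cytoplasm (mid-tone)
                  [0.20, 0.10, 0.40]])  # 2 — nucleus (dark purple)
y_toy = np.array([0, 1, 2])
clf = KNeighborsClassifier(n_neighbors=1).fit(X_toy, y_toy)

# the same dark-purple pixel, whether it comes from a training image,
# a test image, or any position in the frame, always maps to "nucleus"
q = np.array([[0.21, 0.12, 0.39]])
print(clf.predict(q))  # → [2]
```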
5.7.2 Try it: 2-D k-NN in action
The widget below lets you watch the k-NN vote happen for a single query pixel. We drop the blue channel and work in 2-D — so the decision is easy to see — using ~50 training points per class sampled from image 7. Drag the yellow query pixel across the plane (or use the R/G sliders), and change k with its slider. Training points that end up among the \(k\) nearest are highlighted with a ring, and thin lines trace the distances from the query to each neighbour. The shaded background shows the decision region for the current \(k\): the class the classifier would assign at every point on the plane.
Code
# ── Serialise 2-D (R, G) training data for the widget below ───────────────────
_widget_rng = np.random.default_rng(0)
widget_pts = []
for cls in range(3):
    samp = pixels7[labels7 == cls]
    n = min(50, len(samp))
    sel = _widget_rng.choice(len(samp), n, replace=False)
    s = samp[sel, :2]  # keep R, G only
    widget_pts.append({
        'r': s[:, 0].round(4).tolist(),
        'g': s[:, 1].round(4).tolist(),
        'class': cls,
        'color': cls_colors[cls],
        'name': cls_names[cls],
    })
display(HTML(f'<script>window._ch5KnnWidget={json.dumps(widget_pts)};</script>'))
print(f"Widget training data: {sum(len(p['r']) for p in widget_pts)} pixels "
      f"from image {IDX} (R, G only).")
Widget training data: 150 pixels from image 7 (R, G only).
5.8 Test-Set Dice Scores and Comparison to Thresholding
We now run k-NN on all test images and compute the mean Dice score across cytoplasm and nucleus — the two clinically meaningful classes. We compare against the hand-picked grayscale threshold from Chapter 4.
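Before reading the numbers, it helps to recall what a single per-class Dice value means. Here is the dice_score helper from the setup cell applied to a toy 2×2 mask:

```python
import numpy as np

# same definition as the setup cell's dice_score helper
def dice_score(pred, target, cls):
    pred_mask = (pred == cls)
    target_mask = (target == cls)
    intersection = (pred_mask & target_mask).sum()
    denom = pred_mask.sum() + target_mask.sum()
    return 1.0 if denom == 0 else 2 * intersection / denom

pred   = np.array([[2, 2], [0, 1]])   # prediction: two nucleus px, one cytoplasm
target = np.array([[2, 0], [0, 1]])   # ground truth: one nucleus px, one cytoplasm

print(dice_score(pred, target, 2))  # 2·1 / (2 + 1) ≈ 0.667 — one false nucleus call
print(dice_score(pred, target, 1))  # 1.0 — cytoplasm matches exactly
```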
Code
# ── k-NN Dice on full test set ────────────────────────────────────────────────
print("Running k-NN inference on all test images (may take a minute)...")
knn_scores = []
for i in test_idx:
    pred = predict_mask_knn(X[i], knn)
    d = (dice_score(pred, y[i], 1) + dice_score(pred, y[i], 2)) / 2
    knn_scores.append(d)
mean_knn = np.mean(knn_scores)
print(f"k-NN mean Dice (cyto+nuc): {mean_knn:.4f} ± {np.std(knn_scores):.4f}")

# ── Reference: simple grayscale threshold (Lab 01 / Chapter 4 style) ──────────
def segment_threshold(img, t_nucleus=0.3, t_cytoplasm_max=0.7):
    gray = img.mean(axis=0)  # (3, H, W) → (H, W), values already in [0, 1]
    pred = np.zeros(gray.shape, dtype=np.int64)
    pred[gray < t_nucleus] = 2
    pred[(gray >= t_nucleus) & (gray <= t_cytoplasm_max)] = 1
    return pred

thresh_scores = []
for i in test_idx:
    pred = segment_threshold(X[i])
    d = (dice_score(pred, y[i], 1) + dice_score(pred, y[i], 2)) / 2
    thresh_scores.append(d)
mean_thresh = np.mean(thresh_scores)

# ── Summary table ─────────────────────────────────────────────────────────────
print("\nCumulative Dice comparison (test set, mean of cytoplasm + nucleus):")
print("=" * 55)
print(f"{'Chapter 4 — grayscale threshold':<40}{mean_thresh:.4f}")
print(f"{'Chapter 5 — k-NN (RGB colour only)':<40}{mean_knn:.4f}")
print("=" * 55)
Running k-NN inference on all test images (may take a minute)...
k-NN mean Dice (cyto+nuc): 0.7602 ± 0.1061
Cumulative Dice comparison (test set, mean of cytoplasm + nucleus):
=======================================================
Chapter 4 — grayscale threshold 0.5224
Chapter 5 — k-NN (RGB colour only) 0.7602
=======================================================
5.9 Where k-NN Fails: Spatial Blindness
k-NN makes decisions based only on pixel colour. It has no idea whether a pixel is near the centre of the cell, at a boundary, or in the background. This leads to two characteristic failure modes:
Speckle noise: isolated pixels are assigned a different class from all their neighbours because their colour happens to be closer to another class in the training data.
Boundary errors: pixels with ambiguous colours (where two regions meet) are decided purely by which training pixels are closest in colour space — with no smoothing from spatial context.
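The speckle failure mode hints at a cheap partial fix that we will not pursue here: smooth the predicted mask with a local majority vote. A sketch using scipy.ndimage.generic_filter (an illustrative post-processing step, not part of this chapter's pipeline):

```python
import numpy as np
from scipy.ndimage import generic_filter

def majority_vote(window):
    # most common label in the 3×3 neighbourhood
    return np.bincount(window.astype(int)).argmax()

noisy = np.zeros((5, 5), dtype=int)
noisy[2, 2] = 2  # a lone "nucleus" speckle surrounded by background

smoothed = generic_filter(noisy, majority_vote, size=3)
print(smoothed.sum())  # → 0: the isolated speckle is voted away
```

Note that this only papers over the symptom: boundary errors, where the colours themselves are ambiguous, are untouched by spatial smoothing.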
The error maps below make this concrete.
Code
fig, axes = plt.subplots(2, 3, figsize=(14, 8))
samples = test_idx[:2]
for row, idx in enumerate(samples):
    pred = predict_mask_knn(X[idx], knn)
    error = (pred != y[idx])  # True where prediction is wrong
    axes[row, 0].imshow(X[idx].transpose(1, 2, 0))
    axes[row, 0].set_title(f"Image {idx} — RGB"); axes[row, 0].axis('off')
    axes[row, 1].imshow(pred, cmap=mask_cmap, vmin=0, vmax=2, interpolation='nearest')
    axes[row, 1].set_title("k-NN prediction"); axes[row, 1].axis('off')
    axes[row, 2].imshow(error, cmap='Reds', interpolation='nearest')
    axes[row, 2].set_title(f"Errors (red) — {error.mean()*100:.1f}% wrong")
    axes[row, 2].axis('off')
plt.suptitle("Error Maps: Where k-NN Gets It Wrong", fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("Are errors clustered at boundaries, or scattered throughout?")
print("Scattered errors → spatial context would help. Boundary errors → need better features.")
Are errors clustered at boundaries, or scattered throughout?
Scattered errors → spatial context would help. Boundary errors → need better features.
Note: Spatial context — the missing ingredient
k-NN with colour features treats every pixel as independent. A nucleus pixel at the very edge of the nucleus has the same colour as one in the centre — but they sit in very different neighbourhoods. In the next two chapters we will introduce models that do look at neighbourhoods: convolutional neural networks read a patch of pixels at a time, giving them the spatial awareness k-NN lacks.
5.10 Summary
| Aspect | Grayscale threshold (Ch. 4) | k-NN colour (Ch. 5) |
|---|---|---|
| Features used | 1 (luminance) | 3 (R, G, B) |
| Training data needed | None | Labelled pixels |
| Spatial context | None | None |
| Speed | Very fast | Slow (at prediction time) |
| Test-set Dice (this chapter) | ~0.52 | ~0.76 |
Key takeaways:
Using all three colour channels instead of grayscale gives the classifier richer information, which typically improves segmentation.
k-NN is easy to implement and requires no hyperparameter tuning beyond k, but it is slow at prediction time and ignores spatial context entirely.
Errors cluster at class boundaries — where two classes share similar colours — confirming that colour alone is not always sufficient.
The next step toward better segmentation is to give the model spatial awareness. Convolutional networks (Chapters 6–7) do exactly that.
5.11 Exercises
Exercise 5.1 — Varying k
The number of neighbours k is a hyperparameter. Too small (k=1) → noisy; too large → over-smoothed boundaries.
Test k ∈ {1, 10, 25}. For each value, retrain the classifier and compute the mean Dice on the first 5 test images. Plot the error map for one image.
Which value of k gives the best Dice?
Does the error map look more or less speckly as k increases? Why?
Exercise 5.2 — Feature normalisation
Our dataset is already scaled to [0, 1], but raw images are often stored as 0–255 integers. Rescale the pixel features to 0–255 before training and predicting. Does the Dice score change? Why might normalisation matter for distance-based classifiers, and why might it matter less when all features share the same scale?
Exercise 5.3 — Adding spatial features
Modify the pixel feature vector to include the pixel’s (row, col) position alongside its RGB values, giving a 5-D feature vector. Retrain and compare Dice. Do the segmentations look more spatially coherent? What is the downside of including position?
Exercise 5.4 — Class balance check
Inspect the class balance in y_train_px. If one class dominates, the classifier may be biased toward it. You might be tempted to pass class_weight='balanced', but does KNeighborsClassifier accept that parameter? Check the scikit-learn docs and suggest an alternative strategy (hint: we already used one in Section 5.4).