dl_08 — UNet / UNet3+ Feature Map Visualization¶

This notebook opens up a trained wetland-classification model and shows, in pictures, what happens inside it as a single aerial image patch passes through. It's written to be readable even if you've never worked with deep learning.

What is this model doing?¶

The model is a U-Net (here, the UNet3+ variant) — a neural network for semantic segmentation, i.e. assigning every pixel of an image a class. In this project the classes are wetland types (EMW emergent marsh, FSW forested/shrub wetland, SSW scrub-shrub wetland) plus UPL = upland (non-wetland). The model reads a stack of georeferenced input layers — elevation, slope, aerial photos, canopy height, etc. — and outputs a class for each pixel.

A U-Net has three parts, and this notebook visualizes each one:

  • Encoder — repeatedly shrinks the image while pulling out features, moving from raw cues (edges, colors, textures) toward abstract concepts ("this region looks like marsh"). Think of it as zooming out to understand context.
  • Bottleneck — the smallest, most abstract representation, in the middle of the network.
  • Decoder — rebuilds full resolution, combining that abstract understanding with the fine detail saved from the encoder, to draw sharp class boundaries. Think of it as zooming back in to place the labels precisely.

The "U" is literal — data flows down the left side, across the bottom, and up the right. For this project's data (a 256 m × 256 m patch at 1 m resolution, depth-4 model):

in: 256×256 patch, 26 layers                out: 256×256 map, 4 classes
  enc0 (256×256) ── skip ─────────────────────→ dec0 (256×256) → prediction
    enc1 (128×128) ── skip ──────────────→ dec1 (128×128)
      enc2 (64×64) ── skip ─────────→ dec2 (64×64)
        enc3 (32×32) ── skip ──→ dec3 (32×32)
                 bottleneck (16×16)

The horizontal skip connections carry fine detail straight across the U so the decoder doesn't have to reconstruct it from the blurry bottleneck. (In UNet3+, every decoder node actually receives skips from all encoder levels at once, not just its own — that's the "3+".) For a written walk-through, see UNet_Architecture_Overview.md in this folder.

A feature map is one channel of a layer's output — a grayscale image where bright pixels mean "this learned pattern fired strongly here." Early layers have feature maps for simple things (edges, brightness); deep layers have feature maps for complex, wetland-specific patterns. The plots below walk through these stage by stage, ending in the final prediction.

How it works¶

The notebook runs one patch through the model and taps each stage with PyTorch forward hooks (read-only: the model itself is never changed). Each stage can be viewed in two ways:

  • Top-variance channels (default): the 1–3 most active channels — the ones actually carrying signal.
  • Mean activation map: all channels averaged into one image — a clean "what does this level respond to" summary.

The bottleneck shows both side by side, and there's a whole-network overview. Compute is tiny: one forward pass on a single patch (runs fine on CPU).

1. Imports & paths¶

Load the libraries and locate the project folders — Models/ (trained checkpoints) and Data/Training_Data/R_Patches/ (the 256×256 image patches). Nothing model-specific happens yet; this just wires things up.

Project root: /Users/Anthony/Data and Analysis Local/NYS_Wetlands_DL
Models dir:   /Users/Anthony/Data and Analysis Local/NYS_Wetlands_DL/Models
Patches dir:  /Users/Anthony/Data and Analysis Local/NYS_Wetlands_DL/Data/Training_Data/R_Patches

2. Choose a checkpoint¶

The cell below lists every .safetensors / .ckpt in Models/. Set CHECKPOINT_NAME to whichever you want — prefer the .safetensors (self-describing, no architecture flags needed).
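A listing like the one below could be produced with a few lines of `pathlib`. This is only a sketch: the helper names `list_checkpoints` and `prefer_safetensors` are illustrative and not taken from the notebook, and it assumes the checkpoints sit directly inside one folder.

```python
from pathlib import Path

CHECKPOINT_SUFFIXES = {".safetensors", ".ckpt"}


def list_checkpoints(models_dir):
    """Return the checkpoint file names in models_dir, sorted alphabetically."""
    models_dir = Path(models_dir)
    return sorted(
        p.name
        for p in models_dir.iterdir()
        if p.is_file() and p.suffix in CHECKPOINT_SUFFIXES
    )


def prefer_safetensors(names):
    """Drop a .ckpt entry when a .safetensors file with the same stem exists."""
    stems = {Path(n).stem for n in names if n.endswith(".safetensors")}
    return [
        n for n in names
        if not (n.endswith(".ckpt") and Path(n).stem in stems)
    ]
```

Calling `prefer_safetensors(list_checkpoints("Models"))` would leave only one entry per training run wherever both formats were saved.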

Available checkpoints in Models/:
  [0] best_binary_bf64_d4_20260511_1527.ckpt
  [1] best_binary_bf64_d4_20260511_1527.safetensors
  [2] best_binary_unet3plus_bf64_d4_20260605_1548.ckpt
  [3] best_binary_unet3plus_bf64_d4_20260605_1548.safetensors
  [4] best_multiclass_bf128_d4_20260420_1924.safetensors
  [5] best_multiclass_bf128_d4_20260420_2139.safetensors
  [6] best_multiclass_bf128_d5_20260420_2152.safetensors
  [7] best_multiclass_bf32_d4_20260510_2014.ckpt
  [8] best_multiclass_bf64_d4_20260420_0134.safetensors
  [9] best_multiclass_bf64_d4_20260420_0209.safetensors
  [10] best_multiclass_bf64_d4_20260420_0230.safetensors
  [11] best_multiclass_bf64_d4_20260420_0245.safetensors
  [12] best_multiclass_bf64_d4_20260420_0304.safetensors
  [13] best_multiclass_bf64_d4_20260420_2239.safetensors
  [14] best_multiclass_bf64_d4_20260511_1543.ckpt
  [15] best_multiclass_bf64_d4_20260511_1543.safetensors
  [16] best_multiclass_bf64_d5_20260420_2226.safetensors
  [17] best_multiclass_unet3plus_bf64_d4_20260605_1522.ckpt
  [18] best_multiclass_unet3plus_bf64_d4_20260605_1522.safetensors
  [19] best_multiclass_unet3plus_bf64_d4_20260605_1532.ckpt
  [20] best_multiclass_unet3plus_bf64_d4_20260605_1537.ckpt
  [21] best_multiclass_unet3plus_bf64_d4_20260605_1537.safetensors

Selected: best_multiclass_unet3plus_bf64_d4_20260605_1537.safetensors

3. Visualization settings¶

The knobs for everything below: how many channels to show per stage (CHANNELS_PER_LEVEL), the heatmap color scheme (FEATURE_CMAP), how see-through the interactive map overlays are (OVERLAY_OPACITY), and which hardware to run on (DEVICE — CPU is plenty for a single patch; MPS/GPU also work).
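For the DEVICE setting, one common pattern is to take the best backend that is available and fall back to CPU. The sketch below assumes that pattern; `pick_device` is a hypothetical helper, and the notebook may simply hard-code its choice.

```python
import torch


def pick_device(prefer=None):
    """Return a torch.device: an explicit choice, else CUDA, else MPS, else CPU."""
    if prefer is not None:
        return torch.device(prefer)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # torch.backends.mps only exists in newer PyTorch builds, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


print(f"Device: {pick_device()}")
```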

Device: mps

4. Load the model¶

load_model() reads the architecture straight from the .safetensors sidecar (.meta.json) or the .ckpt hyperparameters, so no architecture flags are needed.

Loaded model from /Users/Anthony/Data and Analysis Local/NYS_Wetlands_DL/Models/best_multiclass_unet3plus_bf64_d4_20260605_1537.safetensors (safetensors)
  Architecture: unet3plus(in=26, bf=64, depth=4)
  Epoch: 29

Architecture: UNet3Plus, depth=4, num_classes=4
Classes: ['EMW', 'FSW', 'SSW', 'UPL']  (mode=multiclass)

5. Choose a patch & normalize it¶

Pick one 256 m × 256 m training patch to send through the model. Before the model sees it, every band is normalized — rescaled to roughly 0–1 using the same statistics as during training — so a band measured in big units (elevation in meters) can't drown out one that lives between 0 and 1 (a vegetation fraction). We reuse the exact WetlandPatchDataset code from training, so the model gets inputs in the form it learned on; the stats file is matched to the checkpoint's classification mode.
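To make the rescaling concrete, here is a minimal sketch of per-band normalization with fixed, training-time bounds. The layout of the stats dictionary (a `lo` and `hi` value per band) is an assumption made for illustration; the real stats file and WetlandPatchDataset may use a different scheme, such as mean and standard deviation.

```python
import numpy as np


def normalize_bands(patch, stats, band_names):
    """Rescale each band of a (bands, H, W) array to roughly 0-1.

    stats is assumed to look like {band_name: {"lo": float, "hi": float}};
    that layout is hypothetical, not read from the project's stats file.
    """
    out = np.empty(patch.shape, dtype=np.float32)
    for i, name in enumerate(band_names):
        lo, hi = stats[name]["lo"], stats[name]["hi"]
        # Guard against a zero range so a constant band cannot divide by zero.
        out[i] = (patch[i] - lo) / max(hi - lo, 1e-6)
    return out
```

Because the bounds come from the training data and the result is not clipped, a value outside them lands slightly below 0 or above 1, which is why the rescaling is only "roughly" 0–1.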

One count that can cause confusion: the patch has 18 raster bands (17 predictors + the MOD_CLASS label), but the model input below has 26 channels. That's because the categorical landform band (Geomorph_local) is one-hot encoded, i.e. split into ten separate 0/1 layers, one per landform type, since landform category #7 isn't "more than" #3; the categories have no numeric order a network could safely do math on. That leaves 16 continuous predictors plus 10 landform layers, or 26 channels.
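The one-hot step and the resulting channel count can be sketched as follows. The category codes are assumed to run from 0 to 9 here purely for illustration; the codes actually stored in Geomorph_local may be numbered differently.

```python
import numpy as np


def one_hot_band(band, num_categories):
    """Turn an (H, W) integer band into (num_categories, H, W) layers of 0/1."""
    codes = np.arange(num_categories).reshape(-1, 1, 1)
    # Broadcasting compares every pixel against every category code at once.
    return (band[None, :, :] == codes).astype(np.float32)


N_PREDICTORS = 17   # raster bands excluding the MOD_CLASS label
N_LANDFORMS = 10    # landform categories in Geomorph_local

# The single categorical band is replaced by ten 0/1 layers.
model_channels = (N_PREDICTORS - 1) + N_LANDFORMS
print(model_channels)
```

Each pixel ends up with exactly one of the ten layers set to 1, so no category is treated as numerically larger than another.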

Stats: multiclass_normalization_stats.json
474 patches available. First few:
  ADK_cluster_11_huc_042900030103_patch_10_256m.tif
  ADK_cluster_11_huc_042900030103_patch_11_256m.tif
  ADK_cluster_11_huc_042900030103_patch_12_256m.tif
Selected patch: NEW_cluster_225_huc_043001060202_patch_24_256m.tif
Input tensor: (1, 26, 256, 256), dtype=torch.float32
Label unique: [0, 1, 2, 3]  (255 = unlabeled)
Bands in patch: ['DEM', 'slope_local', 'Geomorph_local', 'flowacc', 'twi', 'CHM', 'r', 'g', 'b', 'nir', 'r_lo', 'g_lo', 'b_lo', 'nir_lo', 'pct_below_1m', 'pct_1m_to_5m', 'pct_above_5m', 'MOD_CLASS']

6. Register forward hooks & run the patch¶

A forward hook is a small listener attached to a layer: when the patch flows through the network, the hook quietly copies that layer's output so we can plot it afterward — the model itself is never modified. We attach one to every encoder level, the bottleneck, every decoder node, and the output head, then run a single forward pass.

  • encoders[i] → returns (pooled, skip); we keep the skip (the full feature map at level i)
  • bottleneck → the coarsest, most abstract feature map
  • fuse_se[idx] → a fused full-scale decoder node (UNet3+); for the plain UNet we fall back to its decoder blocks
  • head → the logits (raw, pre-softmax class scores)
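The hook mechanism described above can be sketched on a toy network. Everything here is a stand-in: the tiny `nn.Sequential` model and the names `conv_in` and `head` are illustrative, not the project's UNet3+, but `register_forward_hook` is the standard PyTorch call.

```python
import torch
from torch import nn

captured = {}


def save_output(name):
    """Build a hook that stores a copy of a layer's output under `name`."""
    def hook(module, inputs, output):
        # Some blocks return a tuple such as (pooled, skip); keep the last item.
        tensor = output[-1] if isinstance(output, tuple) else output
        captured[name] = tensor.detach().cpu().clone()
    return hook


# Toy stand-in for the real model: 26 input channels in, 4 class scores out.
toy = nn.Sequential(
    nn.Conv2d(26, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(8, 4, kernel_size=1),
)

handles = [
    toy[0].register_forward_hook(save_output("conv_in")),
    toy[3].register_forward_hook(save_output("head")),
]

toy.eval()
with torch.no_grad():
    toy(torch.randn(1, 26, 64, 64))  # one forward pass fills `captured`

for handle in handles:
    handle.remove()  # detach the listeners once the maps are saved

for name, fmap in captured.items():
    print(name, tuple(fmap.shape))
```

Removing the handles afterward leaves the model exactly as it was loaded, which is the sense in which hooks are read-only.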

In the printout below, notice the trade the encoder makes: each level halves the image size but doubles the channel count (64 → 128 → 256 → 512 → 1024), giving up "where" detail to store more kinds of "what" information. The decoder then runs the size back up to 256×256, ending at the head's 4 channels: one score map per class (three wetland types plus upland).

Captured feature maps (name: channels x H x W):
  enc0           64 x 256 x 256
  enc1          128 x 128 x 128
  enc2          256 x  64 x  64
  enc3          512 x  32 x  32
  bottleneck   1024 x  16 x  16
  dec3          320 x  32 x  32
  dec2          320 x  64 x  64
  dec1          320 x 128 x 128
  dec0          320 x 256 x 256
  logits          4 x 256 x 256
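The encoder and bottleneck rows of this printout follow from simple arithmetic: halve the side length and double the channels at every level. A small sketch, with `encoder_shapes` as an illustrative helper name:

```python
def encoder_shapes(patch_size=256, base_filters=64, depth=4):
    """Return (name, channels, side_length) for each encoder level and the bottleneck."""
    rows = []
    for level in range(depth + 1):
        name = f"enc{level}" if level < depth else "bottleneck"
        rows.append((name, base_filters * 2**level, patch_size // 2**level))
    return rows


for name, channels, size in encoder_shapes():
    print(f"{name:<12}{channels:>5} x {size:>3} x {size:>3}")
```

The decoder rows are left out on purpose. Their 320 channels would be consistent with five 64-channel branches being concatenated at each node, but that is an inference from the numbers, not something the printout confirms.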

7. Plot helpers¶

Small reusable functions for the plots that follow: norm01 rescales a feature map to the 0–1 range so it displays well, top_variance_channels picks the most active channels, and plot_levels lays out the per-stage grids. (Note: because each map is rescaled on its own, brightness is only comparable within a single panel, not across panels.)
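For readers who want to see what such helpers might look like, here is a minimal NumPy sketch of norm01 and top_variance_channels. The bodies are assumptions reconstructed from the description above, not the notebook's actual code.

```python
import numpy as np


def norm01(fmap, eps=1e-8):
    """Rescale one feature map to the 0-1 range; a constant map becomes all zeros."""
    fmap = np.asarray(fmap, dtype=np.float32)
    lo, hi = fmap.min(), fmap.max()
    if hi - lo < eps:
        return np.zeros_like(fmap)
    return (fmap - lo) / (hi - lo)


def top_variance_channels(features, k=3):
    """Return the indices of the k channels of a (C, H, W) array with the
    largest spatial variance, highest variance first."""
    features = np.asarray(features)
    variances = features.reshape(features.shape[0], -1).var(axis=1)
    order = np.argsort(variances)[::-1]
    return [int(i) for i in order[:k]]
```

Ranking by variance favors channels whose response changes across the patch; a channel that outputs nearly the same value everywhere scores low and is skipped.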

8. Input reference¶

Before looking inside the model, here's the context: the model's inputs and the answer key. The interactive map below lets you flip through every input layer — the NAIP true-color imagery (what a person would see), each individual predictor band, and the ground-truth label map (MOD_CLASS) a human annotator assigned. Keep these in mind — every feature map further down is the model working its way from those inputs toward that label map.

8b. Interactive Leaflet map (folium)¶

Every input layer rendered as a georeferenced overlay on an Esri satellite basemap. This is a real web map — pan, zoom (scroll), and read coordinates just like in a GIS:

  • The radio control (top-right) flips through layers one at a time: the NAIP RGB (leaf-on) and Leaf-off RGB true-color composites, the MOD_CLASS label (class colors, legend bottom-left), and each individual predictor band.
  • "None (basemap only)" switches every overlay off so you can compare a band against the real ground in the satellite image.
  • The red outline marks the patch footprint (256 m × 256 m) — zoom out and you can see exactly where in New York this patch sits.
  • A scale bar (bottom-left) and the cursor's lat/lon (bottom-right) help relate what you see to ground distances and locations.

Continuous bands get a percentile (2–98%) contrast stretch; nodata/unlabeled pixels are transparent. Geomorph_local is the exception — it's categorical (ten landform types, not a measured quantity), so it's drawn with the standard geomorphon colors and its own legend pops up (bottom-right) while that layer is selected. Geomorphons describe local terrain shape from the DEM; watch how the water-collecting landforms — valleys, hollows, footslopes, pits, flats — trace the drainage network where wetlands tend to form.
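The 2–98% contrast stretch can be sketched as below. The function name, the nodata handling, and the choice to return a separate alpha mask are assumptions for illustration; the notebook's own rendering code may differ.

```python
import numpy as np


def percentile_stretch(band, nodata=None, low=2, high=98):
    """Stretch a band to 0-1 between its low and high percentiles.

    Returns (stretched, alpha), where alpha is 0 for nodata or non-finite
    pixels (drawn transparent) and 1 everywhere else.
    """
    band = np.asarray(band, dtype=np.float32)
    valid = np.isfinite(band)
    if nodata is not None:
        valid &= band != nodata
    if not valid.any():
        return np.zeros_like(band), np.zeros_like(band)
    # Percentiles are computed from valid pixels only, so nodata cannot skew them.
    lo, hi = np.percentile(band[valid], [low, high])
    stretched = np.clip((band - lo) / max(hi - lo, 1e-6), 0.0, 1.0)
    stretched[~valid] = 0.0
    return stretched, valid.astype(np.float32)
```

Clipping at the 2nd and 98th percentiles keeps a handful of extreme pixels from compressing the rest of the band into a narrow, low-contrast range.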

Overlays are embedded in the notebook (work offline); only the basemap tiles need internet.

Built Leaflet map with 20 layers (+ 'None' toggle): ['NAIP RGB (leaf-on)', 'Leaf-off RGB', 'MOD_CLASS (label)', 'DEM', 'slope_local', 'Geomorph_local', 'flowacc', 'twi', 'CHM', 'r', 'g', 'b', 'nir', 'r_lo', 'g_lo', 'b_lo', 'nir_lo', 'pct_below_1m', 'pct_1m_to_5m', 'pct_above_5m']

9. Encoder feature maps (top-variance channels)¶

The encoder as it zooms out. The top row (enc0, full resolution) responds to low-level cues — edges, texture, color/brightness contrast that still look a lot like the inputs. Each row down is half the size and a step more abstract, trading spatial detail for "meaning." Rows go fine → coarse. (We show the few most active channels at each level; there are dozens to hundreds in total.)

Tip: flip the map above to NAIP RGB and compare — in enc0/enc1 you can usually recognize what a channel locked onto (a stream channel, a field edge, canopy texture). By enc3 that's much harder; the maps are responding to broader patterns, not objects.


10. Bottleneck — BOTH views¶

The deepest, most compressed stage — the model's most abstract "summary" of the patch. It's small and blurry by design: at this point the network captures roughly what is here rather than exactly where. At depth 4 each bottleneck pixel summarizes a 16 m × 16 m area of ground — enough to say "wet meadow around here," but not where its edge runs; re-drawing those edges is the decoder's job. Because this is the widest level (the most channels), we show it two ways: the top-variance channels and the mean of all channels, side by side.
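The 16 m figure is just the input resolution multiplied by the accumulated downsampling. A short sketch, assuming each level halves the grid once (`ground_footprint_m` is an illustrative name):

```python
def ground_footprint_m(level, resolution_m=1.0):
    """Side length, in meters, of the grid cell behind one pixel after
    `level` rounds of 2x downsampling."""
    return resolution_m * 2**level


for level, name in enumerate(["enc0", "enc1", "enc2", "enc3", "bottleneck"]):
    side = ground_footprint_m(level)
    print(f"{name:<12}{side:>5.0f} m per pixel")
```

This is the size of the grid cell only. The area of ground that actually influences a bottleneck pixel is larger, because every convolution along the way also looks at neighboring pixels.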


11. Decoder feature maps (top-variance channels)¶

The decoder zooming back in. Each stage merges the abstract bottleneck summary with the matching fine detail from the encoder (the "skip connections"), progressively sharpening edges and re-localizing the wetland boundaries. Rows go coarse → fine; dec0 is the full-resolution node that feeds the final output head.


12. Whole-network mean-activation overview¶

One mean-activation image per stage, in order: encoder → bottleneck → decoder. This is the "signal flow" at a glance — watch the resolution shrink through the encoder, bottom out at the bottleneck, then grow back through the decoder, getting reorganized around the wetland features along the way.
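A mean-activation image is simply the average over the channel axis. The sketch below uses small made-up arrays in place of the captured feature maps; the dictionary name `toy_captured` is hypothetical.

```python
import numpy as np


def mean_activation(features):
    """Collapse a (C, H, W) stack of feature maps into one (H, W) image."""
    return np.asarray(features, dtype=np.float32).mean(axis=0)


# Stand-ins for captured stages: fewer channels and smaller grids than the real ones.
toy_captured = {
    "enc0": np.ones((4, 8, 8)),
    "bottleneck": np.ones((16, 2, 2)),
    "dec0": np.ones((6, 8, 8)),
}

overview = {name: mean_activation(fmap) for name, fmap in toy_captured.items()}
for name, image in overview.items():
    print(name, image.shape)
```

However many channels a stage has, the result is one image at that stage's resolution, which is what lets the whole network fit in a single row of panels.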


13. Logits → predicted classes¶

The payoff. The output head turns the final decoder feature map into a raw score (a logit) for each class at each pixel. Softmax converts those scores into probabilities that sum to 1, and argmax picks the highest-probability class as the prediction. Shown here: a probability heatmap per class, the predicted class map, and the ground truth for comparison.
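The softmax and argmax steps look like this on a tiny made-up example of two pixels. The class order is assumed to match the list printed when the model was loaded, and the logit values are invented for illustration.

```python
import numpy as np

CLASS_NAMES = ["EMW", "FSW", "SSW", "UPL"]  # assumed to match the model's class order


def softmax(logits, axis=0):
    """Convert raw scores to probabilities that sum to 1 along `axis`."""
    # Subtracting the max first avoids overflow without changing the result.
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Made-up logits with shape (classes, H, W) = (4, 1, 2).
logits = np.array([[[2.0, 0.0]], [[0.0, 0.0]], [[0.0, 0.0]], [[0.0, 3.0]]])

probs = softmax(logits)          # one probability map per class
pred = probs.argmax(axis=0)      # predicted class index per pixel
confidence = probs.max(axis=0)   # probability of the winning class

print([[CLASS_NAMES[i] for i in row] for row in pred])
```

Since softmax preserves the ranking of the scores, taking argmax of the logits directly would give the same class map; the probabilities are needed only for the heatmaps and the confidence layer.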


14. Deep-supervision side predictions (per-stage, coarse → fine)¶

Deep supervision means that during training the model wasn't graded only on its final full-resolution map — every decoder stage and the bottleneck got its own tiny prediction head and was graded on a "rough draft" of the wetland map at its own scale. That pressure pushes every level of the network to learn wetland-relevant features, not just the last one.

At inference those side heads are normally ignored (eval mode returns only the main, finest prediction), but they were saved with the model — so here we apply each one to the feature maps we captured in section 6 and watch the prediction sharpen from the blurry 16×16 bottleneck draft to the full-resolution final map. Top row = predicted class; bottom row = model confidence (the winning class's probability). The printout at the end quantifies how much each draft already agrees with the final answer.

Per-stage agreement with the main head's prediction:
  bottleneck      93.4% of pixels match
  dec3            95.3% of pixels match
  dec2            96.2% of pixels match
  dec1            96.9% of pixels match
  dec0 (main)    100.0% of pixels match
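An agreement percentage of this kind can be computed by bringing a coarse draft up to full resolution and counting matching pixels. The sketch below assumes nearest-neighbor upsampling of the class map; the notebook may instead upsample the logits or use a different interpolation, which would change the numbers slightly.

```python
import numpy as np


def upsample_nearest(class_map, factor):
    """Repeat every pixel `factor` times along both axes."""
    return np.repeat(np.repeat(class_map, factor, axis=0), factor, axis=1)


def agreement_pct(side_pred, main_pred):
    """Percent of full-resolution pixels where a coarser draft matches the main map."""
    factor = main_pred.shape[0] // side_pred.shape[0]
    upsampled = upsample_nearest(side_pred, factor)
    return 100.0 * float((upsampled == main_pred).mean())


# Toy example: a 2x2 draft compared with a 4x4 final map that differs in one pixel.
main = np.array([[0, 0, 1, 1],
                 [0, 0, 1, 1],
                 [2, 2, 3, 3],
                 [2, 2, 3, 0]])
draft = np.array([[0, 1],
                  [2, 3]])
print(f"{agreement_pct(draft, main):.1f}% of pixels match")
```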

15. Combined interactive map — predictions + inputs (folium)¶

The same Esri-satellite Leaflet map as section 8b, now with the model's outputs added — flip between what the model saw and what it predicted, all draped over real imagery. The layer control (top-right) now has two independent radio groups, so you can show one model output and one input layer at the same time — the output always draws on top (try prediction confidence over NAIP RGB to see what the model hesitates about):

  • + Predicted class (argmax) vs. + Ground truth (MOD_CLASS) — flip back and forth to spot errors
  • + Prediction confidence (max probability) and a + P(class) layer per class — where the model is sure vs. unsure
  • all the NAIP / leaf-off composites and individual input bands from section 8b

Each group has a "None" option — pick None in both to see the bare satellite image. This is the one map to share if you want someone to explore the model interactively.

Built combined map: 7 model-output layers + 20 input layers, in two radio groups (each with its own 'None' option).

16. Side-by-side error check — prediction vs. ground truth (synced maps)¶

Two maps locked together: left = the model's prediction, right = the analyst's ground truth (unlabeled pixels are transparent, letting the imagery show through). Pan or zoom either side and the other follows, so your eyes can stay on one spot while you compare. This is the fastest way to see where the model goes wrong — a marsh edge drawn too wide, an upland inclusion it missed. Same class colors as the legend (bottom-left).
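The same comparison can be put into numbers. This sketch marks every labeled pixel where prediction and ground truth disagree, skipping pixels with the unlabeled value 255 noted in section 5; `error_summary` is an illustrative helper, not part of the notebook.

```python
import numpy as np

UNLABELED = 255  # value used for unlabeled pixels in the label band


def error_summary(pred, truth, unlabeled=UNLABELED):
    """Return (wrong_mask, accuracy), both computed over labeled pixels only."""
    labeled = truth != unlabeled
    wrong = labeled & (pred != truth)
    if labeled.any():
        accuracy = float((pred[labeled] == truth[labeled]).mean())
    else:
        accuracy = float("nan")
    return wrong, accuracy


pred = np.array([[0, 1], [3, 3]])
truth = np.array([[0, 2], [255, 3]])
wrong, accuracy = error_summary(pred, truth)
print(wrong)
print(f"accuracy on labeled pixels: {accuracy:.3f}")
```

Leaving unlabeled pixels out of both the mask and the accuracy matters: counting them as errors would penalize the model for areas the analyst never classified.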


To explore another patch or model: change CHECKPOINT_NAME (Section 2) or PATCH_NAME (Section 5), then re-run from that cell down. CHANNELS_PER_LEVEL, FEATURE_CMAP, and OVERLAY_OPACITY (Section 3) control how the plots and the interactive maps (Sections 8b, 15, 16) are drawn.