Pixel-Level Regression with Deep Learning for Remote Sensing¶
This notebook demonstrates how to train a pixel-level regression model using encoder-decoder architectures (U-Net, DeepLabV3+, etc.) for remote sensing applications.
Use Case: NDVI Prediction from Satellite Imagery¶
We'll train a deep learning model to predict NDVI (Normalized Difference Vegetation Index) at the pixel level from satellite imagery. The number of input bands is auto-detected from the raster.
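NDVI itself is a simple band ratio, (NIR − Red) / (NIR + Red), which is what the target rasters encode. A quick illustration with hypothetical reflectance values (the numbers are made up for demonstration):

```python
import numpy as np

# Hypothetical reflectances for two pixels: dense vegetation vs. bare soil
nir = np.array([0.50, 0.30])
red = np.array([0.08, 0.25])

# NDVI = (NIR - Red) / (NIR + Red), bounded to [-1, 1] by construction
ndvi = (nir - red) / (nir + red)
print(ndvi)  # vegetation ~0.72, bare soil ~0.09
```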
# %pip install geoai-py segmentation-models-pytorch lightning
import torch
torch.set_float32_matmul_precision("medium")
import geoai
from geoai.timm_regress import (
create_regression_tiles,
train_pixel_regressor,
predict_raster,
plot_regression_comparison,
plot_scatter,
plot_training_history,
visualize_prediction,
)
from sklearn.model_selection import train_test_split
Download Data¶
The dataset contains satellite imagery and corresponding NDVI rasters for Knoxville, Tennessee.
train_raster = geoai.download_file(
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/knoxville_landsat_2022.tif"
)
train_target = geoai.download_file(
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/knoxville_ndvi_2022.tif"
)
test_raster = geoai.download_file(
"https://huggingface.co/datasets/giswqs/geospatial/resolve/main/knoxville_landsat_2023.tif"
)
Create Training Tiles¶
Important: NDVI values should fall within the range [-1, 1]. Outlier values are clipped to this range during tile creation to ensure clean training data.
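The clipping controlled by target_min/target_max can be pictured with np.clip (a simplified sketch of the idea, not the library's internal code):

```python
import numpy as np

# A target tile with a few out-of-range artifacts (e.g. from sensor noise)
tile = np.array([[-1.5, 0.2],
                 [0.9, 1.3]])

# Values outside [-1, 1] are clamped to the nearest boundary
clipped = np.clip(tile, -1.0, 1.0)
print(clipped)  # [[-1.   0.2] [ 0.9  1. ]]
```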
# Create training tiles
image_paths, target_paths = create_regression_tiles(
input_raster=train_raster,
target_raster=train_target,
output_dir="ndvi_tiles",
tile_size=256,
stride=128,
target_band=1,
min_valid_ratio=0.9,
target_min=-1.0, # Clip NDVI values to valid range
target_max=1.0,
)
print(f"Created {len(image_paths)} tiles")
# Auto-detect number of input bands
import rasterio
with rasterio.open(train_raster) as src:
in_channels = src.count
print(f"Input bands: {in_channels}")
# Split into train/validation
train_imgs, val_imgs, train_tgts, val_tgts = train_test_split(
image_paths, target_paths, test_size=0.2, random_state=42
)
print(f"Training: {len(train_imgs)}, Validation: {len(val_imgs)}")
Train Model¶
Using U-Net with ResNet34 encoder for pixel-level regression.
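To illustrate what "pixel-level regression" means architecturally, here is a toy encoder-decoder in plain PyTorch. This is only a minimal sketch (not the U-Net trained below): the encoder downsamples, the decoder upsamples back to full resolution, and the head emits one regression value per pixel. The in_channels=7 here is just an example band count.

```python
import torch
import torch.nn as nn

class TinyPixelRegressor(nn.Module):
    """Minimal encoder-decoder: downsample 2x, upsample 2x, 1 output channel."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=2, padding=1),  # H/2 x W/2
            nn.ReLU(),
        )
        self.decoder = nn.ConvTranspose2d(16, 1, kernel_size=2, stride=2)  # back to H x W

    def forward(self, x):
        return self.decoder(self.encoder(x))

model = TinyPixelRegressor(in_channels=7)
x = torch.randn(2, 7, 256, 256)  # batch of two hypothetical 7-band tiles
y = model(x)
print(y.shape)  # torch.Size([2, 1, 256, 256]) -- one value per pixel
```

Real architectures like U-Net add skip connections between encoder and decoder, which is what preserves the fine spatial detail needed for per-pixel targets like NDVI.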
model = train_pixel_regressor(
train_image_paths=train_imgs,
train_target_paths=train_tgts,
val_image_paths=val_imgs,
val_target_paths=val_tgts,
encoder_name="resnet34",
architecture="unet",
in_channels=in_channels,
output_dir="ndvi_model",
batch_size=8,
num_epochs=100,
learning_rate=1e-3,
num_workers=0,
loss_type="mse",
patience=20,
devices=1,
verbose=False,
)
Training History¶
Plot the training and validation loss/R² curves over epochs.
# Plot training curves (loss and R² over epochs)
fig, history_df = plot_training_history(
log_dir="ndvi_model",
metrics=["loss", "r2"],
)
Run Inference¶
# train_pixel_regressor reloads the best checkpoint when available
print(f"Using checkpoint: {getattr(model, 'best_model_path', 'last epoch')}")
# Predict with clipping to valid NDVI range
predict_raster(
model=model,
input_raster=train_raster,
output_raster="ndvi_model/predicted_ndvi_2022.tif",
tile_size=256,
overlap=64,
batch_size=8,
clip_range=(-1.0, 1.0), # Clip predictions to valid NDVI range
)
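The overlap=64 parameter exists because predictions near tile edges are less reliable; overlapping tile predictions can be blended by averaging. A simplified 1-D sketch of the idea (illustrative only, not predict_raster's actual implementation):

```python
import numpy as np

# Two overlapping tile predictions over a 1-D strip of 12 pixels:
# tile size 8, overlap 4, so the second tile starts at pixel 4
pred = np.zeros(12)
weight = np.zeros(12)

tile_a = np.full(8, 0.2)  # prediction from the first tile
tile_b = np.full(8, 0.4)  # prediction from the second tile

for start, tile in [(0, tile_a), (4, tile_b)]:
    pred[start:start + 8] += tile
    weight[start:start + 8] += 1

blended = pred / weight  # average wherever tiles overlap
print(blended)  # 0.2 outside, 0.3 in the 4-pixel overlap, 0.4 at the end
```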
Evaluate Results¶
Important: We use valid_range=(-1, 1) to exclude outlier pixels from evaluation.
# Compare with ground truth (filtering outliers)
fig, metrics = plot_regression_comparison(
true_raster=train_target,
pred_raster="ndvi_model/predicted_ndvi_2022.tif",
title="NDVI Prediction Results",
cmap="RdYlGn",
vmin=-0.2,
vmax=0.8,
valid_range=(-1.0, 1.0), # Filter outliers for fair evaluation
)
# Scatter plot with trend line
fig, metrics = plot_scatter(
true_raster=train_target,
pred_raster="ndvi_model/predicted_ndvi_2022.tif",
sample_size=50000,
valid_range=(-1.0, 1.0), # Filter outliers
fit_line=True, # Show linear regression trend line
)
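The valid_range filtering used in the evaluation cells above can be pictured as masking out-of-range ground-truth pixels before computing any metric. A numpy sketch (simplified relative to whatever the plotting helpers do internally):

```python
import numpy as np

y_true = np.array([0.3, 5.0, -0.1, 0.7])   # 5.0 is an outlier / nodata artifact
y_pred = np.array([0.25, 0.9, -0.05, 0.65])

# Keep only pixels whose ground-truth value lies in the valid NDVI range
mask = (y_true >= -1.0) & (y_true <= 1.0)
rmse = np.sqrt(np.mean((y_true[mask] - y_pred[mask]) ** 2))
print(mask.sum(), "valid pixels, RMSE =", rmse)  # 3 valid pixels, RMSE = 0.05
```

Without the mask, the single outlier at 5.0 would dominate the error and make the metrics meaningless.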
Predict on New Data (2023)¶
predict_raster(
model=model,
input_raster=test_raster,
output_raster="ndvi_model/predicted_ndvi_2023.tif",
tile_size=256,
overlap=64,
batch_size=8,
clip_range=(-1.0, 1.0),
)
visualize_prediction(
input_raster=test_raster,
pred_raster="ndvi_model/predicted_ndvi_2023.tif",
cmap="RdYlGn",
vmin=-0.2,
vmax=0.8,
)
Summary¶
Key Parameters¶
- target_min, target_max: Filter training tiles with out-of-range values
- clip_range: Clip predictions to valid range during inference
- valid_range: Filter outliers when evaluating metrics
- architecture: 'unet', 'unetplusplus', 'deeplabv3plus', 'fpn'
- encoder_name: 'resnet18', 'resnet34', 'resnet50', 'efficientnet_b0'
Tips for Better Results¶
- Data quality: Ensure target values fall within the expected range
- More epochs: Train for 50+ epochs to allow convergence
- Larger tiles: 256x256 tiles capture more spatial context
- Overlap in inference: Use 64+ pixels of overlap for smooth blending