Transfer Learning and Pre-Trained Models¶
Summary¶
This chapter introduces transfer learning, a powerful technique that enables practitioners to leverage knowledge from pre-trained models to solve new tasks with limited data. Students will learn how to use pre-trained models from model zoos, understand the difference between feature extraction (using a frozen pre-trained model) and fine-tuning (adapting model weights to new data), and explore domain adaptation strategies for transferring knowledge across different but related domains. The chapter demonstrates practical transfer learning workflows using models pre-trained on ImageNet, showing how to effectively reuse learned representations while managing learning rates and layer freezing to achieve strong performance with minimal training data.
Concepts Covered¶
This chapter covers the following 10 concepts from the learning graph:
- Transfer Learning
- Pre-Trained Model
- Fine-Tuning
- Feature Extraction
- Domain Adaptation
- Model Zoo
- Validation Error
- Online Learning
- Optimizer
- Momentum
Prerequisites¶
This chapter builds on concepts from:
- Chapter 9: Neural Networks Fundamentals
- Chapter 10: Convolutional Neural Networks for Computer Vision
Introduction to Transfer Learning¶
Imagine you're learning to play the piano. Would you start completely from scratch, not using any knowledge from previous musical experiences? Of course not! You'd leverage what you already know about reading music, rhythm, and hand coordination. Transfer learning in machine learning works the same way—it allows us to take knowledge learned from one task and apply it to a new, related task.
In deep learning, training large neural networks from scratch requires enormous amounts of labeled data and computational resources. A ResNet-50, for example, has roughly 25 million parameters and was trained on ImageNet's 1.2 million labeled images across 1,000 categories. Most practitioners don't have access to datasets of this scale or the computational budget to train such models from scratch.
Transfer learning solves this problem by allowing us to start with a pre-trained model—a neural network that has already learned useful feature representations from a large dataset—and adapt it to our specific task. This approach is particularly powerful when working with small datasets, where training from scratch would lead to severe overfitting.
The fundamental insight behind transfer learning is that the features learned by deep neural networks on large-scale tasks (like ImageNet classification) are often transferable to other computer vision tasks. Early layers in CNNs learn general features like edges, textures, and simple shapes, while later layers learn increasingly task-specific features. By reusing the general features and only retraining the task-specific layers, we can achieve excellent performance with far less data and computation.
The Model Zoo: Pre-Trained Models¶
A model zoo is a collection of pre-trained neural network models that have been trained on large-scale benchmark datasets and made publicly available for reuse. The best-known pre-trained models are networks trained on ImageNet, a dataset containing 1.2 million training images across 1,000 object categories.
Popular pre-trained models available in PyTorch's model zoo include:
- ResNet (18, 34, 50, 101, 152 layers): Deep residual networks with skip connections
- VGG (16, 19 layers): Very deep convolutional networks with small filters
- AlexNet: The breakthrough architecture that won ImageNet 2012
- DenseNet: Networks with dense connections between layers
- MobileNet: Lightweight networks optimized for mobile devices
- EfficientNet: Networks optimized for both accuracy and efficiency
These models achieve impressive accuracy on ImageNet classification, with top-5 error rates (the percentage of test images where the correct label is not among the model's top 5 predictions) often below 5%. The learned feature representations from these models transfer remarkably well to other computer vision tasks.
Choosing a Pre-Trained Model
When selecting a pre-trained model, consider the tradeoff between accuracy and computational cost. ResNet-18 is faster and requires less memory than ResNet-152, but may achieve slightly lower accuracy. For most transfer learning tasks on moderate-sized datasets, ResNet-18 or ResNet-50 provide an excellent balance.
Loading Pre-Trained Models in PyTorch¶
PyTorch makes it straightforward to load pre-trained models from the torchvision.models module. Here's how to load a ResNet-18 model with weights pre-trained on ImageNet:
import torch
import torchvision
from torchvision import models
import torch.nn as nn
# Load a pre-trained ResNet-18 model
model_ft = models.resnet18(weights='IMAGENET1K_V1')
# Examine the architecture
print(model_ft)
The weights='IMAGENET1K_V1' parameter loads the official ImageNet pre-trained weights. The model architecture includes:
- An initial convolutional layer with batch normalization and ReLU activation
- Four residual layer groups with increasing channel depths (64, 128, 256, 512)
- An adaptive average pooling layer
- A final fully connected layer with 1,000 outputs (one for each ImageNet class)
For our custom task, we'll need to modify this architecture—specifically, we'll replace the final fully connected layer to match our number of output classes.
Data Preprocessing for Pre-Trained Models¶
When using pre-trained models, it's crucial to preprocess input images in the same way the model was trained. All ImageNet pre-trained models expect input images to be normalized with specific mean and standard deviation values:
from torchvision import transforms
# Data augmentation and normalization for training
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
Let's break down these transformations:
Training transformations:
- RandomResizedCrop(224): Randomly crops the image to 224×224 pixels with random scale and aspect ratio; this data augmentation helps the model generalize
- RandomHorizontalFlip(): Randomly flips images horizontally with 50% probability, another augmentation technique
- ToTensor(): Converts PIL images or NumPy arrays to PyTorch tensors with values in [0, 1]
- Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]): Normalizes each channel using ImageNet statistics
Validation transformations:
- Resize(256): Resizes the smaller edge to 256 pixels while maintaining the aspect ratio
- CenterCrop(224): Takes a 224×224 center crop
- No augmentation is applied during validation; we want consistent evaluation
The normalization values [0.485, 0.456, 0.406] (mean) and [0.229, 0.224, 0.225] (standard deviation) are the channel-wise statistics computed across the entire ImageNet training set. Using these same values ensures that the pre-trained model receives inputs in the distribution it expects.
Normalization Formula¶
For each pixel value \(x\) in channel \(c\), the normalized value \(x'\) is:
\(x' = \frac{x - \mu_c}{\sigma_c}\)
where:
- \(x\) is the original pixel value in range [0, 1]
- \(\mu_c\) is the mean for channel \(c\) (e.g., 0.485 for red channel)
- \(\sigma_c\) is the standard deviation for channel \(c\) (e.g., 0.229 for red channel)
- \(x'\) is the normalized pixel value
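For example, a red-channel value of 0.5 normalizes to (0.5 − 0.485) / 0.229 ≈ 0.066. Here is a quick sketch checking the formula against transforms.Normalize (the constant-valued image is purely illustrative):

```python
import torch
from torchvision import transforms

# A 3-channel "image" where every pixel is 0.5, shape (C, H, W)
x = torch.full((3, 2, 2), 0.5)
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

manual = (0.5 - 0.485) / 0.229           # red channel, by the formula above
print(manual)                            # ~0.0655
print(normalize(x)[0, 0, 0].item())      # matches the manual computation
```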
Feature Extraction: Freezing the Base Network¶
Feature extraction is the simplest form of transfer learning. In this approach, we treat the pre-trained model as a fixed feature extractor, freezing all convolutional layers and only training a new classification head for our specific task.
The rationale is straightforward: the convolutional layers of a network trained on ImageNet have already learned to extract useful visual features (edges, textures, object parts, etc.). These features are generally applicable to many computer vision tasks. By freezing these layers, we prevent their weights from updating during training, dramatically reducing the number of parameters we need to learn.
Here's how to implement feature extraction with a frozen ResNet-18:
import torch.optim as optim
from torch.optim import lr_scheduler
# Load pre-trained ResNet-18
model_conv = torchvision.models.resnet18(weights='IMAGENET1K_V1')
# Freeze all parameters in the base network
for param in model_conv.parameters():
    param.requires_grad = False
# Replace the final fully connected layer
# New layers have requires_grad=True by default
num_ftrs = model_conv.fc.in_features # 512 for ResNet-18
model_conv.fc = nn.Linear(num_ftrs, 2) # 2 classes: ants and bees
# Move model to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model_conv = model_conv.to(device)
# Only optimize parameters of the final layer
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)
# Learning rate scheduler: decay LR by 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)
# Loss function
criterion = nn.CrossEntropyLoss()
The key steps are:
- Freeze base network: Setting param.requires_grad = False for all parameters prevents gradient computation and weight updates for those layers
- Replace classification head: The original fc layer outputs 1,000 classes (ImageNet); we replace it with a new layer outputting 2 classes (ants vs. bees)
- Optimize only new layers: We pass only model_conv.fc.parameters() to the optimizer, so only the new classification head is trained
This approach is extremely efficient when you have a small dataset. Training only the final layer requires far less computation and memory than training the entire network, and it's less prone to overfitting since we're only learning a linear classifier on top of robust features.
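To see just how small the trainable portion is, you can count parameters directly (a short sketch assuming the model_conv setup above):

```python
# Count trainable vs. total parameters after freezing the backbone
trainable = sum(p.numel() for p in model_conv.parameters() if p.requires_grad)
total = sum(p.numel() for p in model_conv.parameters())
print(f"Trainable: {trainable:,} of {total:,} parameters")
# For ResNet-18 with a 2-class head: 1,026 trainable out of roughly 11 million
```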
Fine-Tuning: Adapting All Layers¶
While feature extraction works well for many tasks, fine-tuning can achieve even better performance by allowing the entire network to adapt to the new task. In fine-tuning, we initialize the network with pre-trained weights, then train all layers (or a subset of later layers) with a small learning rate.
The intuition is that while early layers learn general features applicable to many tasks, later layers learn increasingly task-specific features. By fine-tuning, we allow the network to adjust these features for our specific domain while still benefiting from the initialization provided by pre-training.
Here's how to implement fine-tuning:
# Load pre-trained ResNet-18
model_ft = models.resnet18(weights='IMAGENET1K_V1')
# Replace the final fully connected layer
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
# Optimize ALL parameters (not just the final layer)
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
# Decay learning rate by 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# Loss function
criterion = nn.CrossEntropyLoss()
The key difference from feature extraction is that we optimize model_ft.parameters() (all parameters) rather than just model_ft.fc.parameters() (final layer only).
Learning Rate for Fine-Tuning
When fine-tuning, use a smaller learning rate than you would for training from scratch. The pre-trained weights are already in a good region of the parameter space, so aggressive updates can destroy the learned features. Learning rates like 0.001 or 0.0001 are typical for fine-tuning.
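One common refinement, not shown in the tutorial code above, is to give the pre-trained backbone a smaller learning rate than the freshly initialized head using optimizer parameter groups. A sketch, with learning rates chosen as plausible defaults rather than prescribed values:

```python
# Separate parameter groups: lower LR for the backbone, higher for the new head
head_params = list(model_ft.fc.parameters())
backbone_params = [p for name, p in model_ft.named_parameters()
                   if not name.startswith('fc.')]

optimizer_ft = optim.SGD([
    {'params': backbone_params, 'lr': 0.0001},  # gentle updates to pre-trained weights
    {'params': head_params, 'lr': 0.001},       # larger updates to the new layer
], momentum=0.9)
```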
Training Loop for Transfer Learning¶
Both feature extraction and fine-tuning use the same training loop structure. Here's a complete training function that handles both training and validation:
import time
import copy
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data batches
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # Zero the parameter gradients
                optimizer.zero_grad()

                # Forward pass
                # Track gradients only in training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    # CrossEntropyLoss expects raw logits and applies
                    # log-softmax internally, so no softmax is needed here
                    loss = criterion(outputs, labels)

                    # Backward pass and optimization only in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # Accumulate statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            # Step the learning rate scheduler after each training epoch
            if phase == 'train':
                scheduler.step()

            # Calculate epoch statistics
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            # Save the best model based on validation accuracy
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model
This training function implements several important practices:
Model modes:
- model.train(): Enables training mode, where dropout and batch normalization behave appropriately for training
- model.eval(): Enables evaluation mode, where dropout is disabled and batch normalization uses running statistics

Gradient tracking:
- torch.set_grad_enabled(phase == 'train'): Disables gradient computation during validation, saving memory and computation

Learning rate scheduling:
- scheduler.step(): Reduces the learning rate according to the schedule (e.g., multiply by 0.1 every 7 epochs)
Model checkpointing:
- The function saves the model weights that achieve the best validation accuracy, preventing overfitting to the training set
Validation Error and Generalization¶
Validation error measures how well your model performs on data it hasn't seen during training. In transfer learning, monitoring validation error is crucial for several reasons:
- Early stopping: If validation error starts increasing while training error continues decreasing, the model is overfitting
- Hyperparameter tuning: Validation error helps you compare different learning rates, architectures, or augmentation strategies
- Model selection: The checkpoint with lowest validation error is typically the best model to deploy
The relationship between training and validation error reveals important information about your model:
| Training Error | Validation Error | Diagnosis | Solution |
|---|---|---|---|
| High | High | Underfitting | More training, larger model, less regularization |
| Low | High | Overfitting | More data, data augmentation, regularization, early stopping |
| Low | Low | Good fit | Model is ready for deployment |
In our ants vs. bees example with only 120 training images per class, a model trained from scratch would likely show low training error but high validation error (overfitting). Transfer learning achieves both low training and validation error by leveraging pre-trained features.
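Early stopping is straightforward to layer on top of a training loop. A minimal sketch, where train_one_epoch() and validate() are hypothetical helpers and the patience value is an assumption:

```python
best_val_loss = float('inf')
patience, wait = 3, 0

for epoch in range(num_epochs):
    train_one_epoch(model)               # hypothetical training step
    val_loss = validate(model)           # hypothetical helper returning validation loss
    if val_loss < best_val_loss:
        best_val_loss, wait = val_loss, 0
        torch.save(model.state_dict(), 'best_model.pt')  # checkpoint the best model
    else:
        wait += 1
        if wait >= patience:             # no improvement for `patience` epochs
            break
```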
Case Study: Ants vs. Bees Classification¶
Let's walk through a complete transfer learning example classifying images of ants and bees. This dataset is deliberately small (only 120 training images per class and 75 validation images per class) to demonstrate transfer learning's effectiveness with limited data.
First, we set up the dataset and dataloaders:
import os
from torchvision import datasets
data_dir = 'hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=4,
                                              shuffle=True,
                                              num_workers=2)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
The ImageFolder class automatically creates a dataset from a directory structure where each subdirectory name is a class label:
hymenoptera_data/
├── train/
│ ├── ants/
│ │ ├── image1.jpg
│ │ ├── image2.jpg
│ │ └── ...
│ └── bees/
│ ├── image1.jpg
│ └── ...
└── val/
├── ants/
└── bees/
Now we train the model using fine-tuning (all layers trainable):
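Assuming the fine-tuning setup from earlier, the training call looks like this (num_epochs=5 matches the five epochs shown below):

```python
model_ft = train_model(model_ft, criterion, optimizer_ft,
                       exp_lr_scheduler, num_epochs=5)
```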
After just 5 epochs of fine-tuning, we achieve remarkable results:
Epoch 0/4
----------
train Loss: 0.5150 Acc: 0.7213
val Loss: 0.1700 Acc: 0.9346
Epoch 1/4
----------
train Loss: 0.6623 Acc: 0.7172
val Loss: 0.1956 Acc: 0.9150
Epoch 2/4
----------
train Loss: 0.6573 Acc: 0.7705
val Loss: 0.4305 Acc: 0.8431
Epoch 3/4
----------
train Loss: 0.4789 Acc: 0.8033
val Loss: 0.2175 Acc: 0.9085
Epoch 4/4
----------
train Loss: 0.4125 Acc: 0.8197
val Loss: 0.1653 Acc: 0.9477
Training complete in 7m 38s
Best val Acc: 0.947712
The model achieves 94.77% validation accuracy with only 240 total training images! This demonstrates the power of transfer learning—training a model from scratch on such a small dataset would likely achieve only 60-70% accuracy.
Compare this with feature extraction (frozen base network, only final layer trained):
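Using the feature-extraction setup defined earlier, the call is analogous:

```python
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=5)
```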
Results after 5 epochs:
Epoch 0/4
----------
train Loss: 0.6429 Acc: 0.5943
val Loss: 0.3860 Acc: 0.8301
Epoch 4/4
----------
train Loss: 0.5453 Acc: 0.7787
val Loss: 0.4029 Acc: 0.8497
Training complete in 3m 46s
Best val Acc: 0.947712
Feature extraction also achieves 94.77% validation accuracy, but with significantly faster training (3m 46s vs. 7m 38s) since we're only updating the final layer. For this particular task, the frozen features are sufficiently powerful that fine-tuning provides no additional benefit.
Optimizers and Momentum¶
An optimizer is an algorithm that updates model parameters to minimize the loss function. The most common optimizer for transfer learning is Stochastic Gradient Descent (SGD) with momentum.
Standard SGD updates parameters using the gradient:
SGD Update Rule¶
\(\theta_{t+1} = \theta_t - \eta \nabla L(\theta_t)\)
where:
- \(\theta_t\) represents the parameters at iteration \(t\)
- \(\eta\) is the learning rate
- \(\nabla L(\theta_t)\) is the gradient of the loss with respect to parameters
Momentum improves upon basic SGD by accumulating a velocity vector that smooths out the gradient updates:
SGD with Momentum¶
\(v_{t+1} = \beta v_t + \nabla L(\theta_t)\)
\(\theta_{t+1} = \theta_t - \eta v_{t+1}\)
where:
- \(v_t\) is the velocity (momentum) vector at iteration \(t\)
- \(\beta\) is the momentum coefficient (typically 0.9)
- \(\eta\) is the learning rate
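To make the update rule concrete, here is a toy sketch on a one-dimensional quadratic loss (the loss function, learning rate, and iteration count are illustrative choices):

```python
# SGD with momentum on L(theta) = 0.5 * theta**2, whose gradient is theta
beta, lr = 0.9, 0.1
theta, v = 5.0, 0.0

for t in range(100):
    g = theta                     # gradient of the toy loss at theta_t
    v = beta * v + g              # v_{t+1} = beta * v_t + grad L(theta_t)
    theta = theta - lr * v        # theta_{t+1} = theta_t - lr * v_{t+1}

print(theta)  # ~0: oscillates toward, then settles near, the minimum
```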
Momentum has several benefits:
- Accelerates convergence: In directions where gradients consistently point the same way, momentum builds up speed
- Dampens oscillations: In directions where gradients fluctuate, momentum averages them out
- Escapes local minima: Accumulated momentum can carry the optimization through shallow local minima
In PyTorch, we specify momentum when creating the optimizer:
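```python
# SGD with momentum, as used throughout this chapter
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
```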
The momentum value of 0.9 is a common default that works well for most transfer learning tasks.
Alternative Optimizers
While SGD with momentum is popular for fine-tuning, Adam and AdamW are also effective choices. Adam adapts learning rates per parameter and includes momentum-like behavior. For transfer learning, SGD with momentum often generalizes slightly better, but Adam can converge faster.
Understanding momentum's effect on gradient descent convergence:

Gradient Descent Comparison
- Without momentum (standard SGD): oscillates perpendicular to the gradient, makes slow progress along valleys, and is sensitive to the learning rate
- With momentum (β = 0.9): accelerates in consistent gradient directions, dampens oscillations, converges faster, and remains stable at larger learning rates
Domain Adaptation¶
Domain adaptation addresses the challenge of transferring knowledge when the source domain (where the model was trained) differs from the target domain (where we want to apply it). Even when pre-trained on large datasets like ImageNet, models may struggle when the target domain has different characteristics.
Examples of domain shift:
- Dataset bias: ImageNet contains mostly high-quality photos, but your application involves sketches or medical images
- Environmental differences: A model trained on daytime images applied to nighttime images
- Sensor differences: A model trained on camera images applied to satellite imagery
Several domain adaptation strategies help bridge this gap:
- Gradual fine-tuning: Start with very small learning rates and gradually increase them
- Layer-wise unfreezing: Freeze early layers longer, unfreeze later layers first, then progressively unfreeze earlier layers
- Domain-specific data augmentation: Apply augmentations that simulate domain shift (e.g., color jittering, blur)
- Adversarial training: Train a domain classifier to make features domain-invariant
- Self-supervised pre-training: Pre-train on unlabeled target domain data before fine-tuning
For moderate domain shift, fine-tuning with appropriate data augmentation often suffices. For severe domain shift, more sophisticated techniques may be necessary.
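As a concrete illustration of the layer-wise unfreezing strategy listed above, here is a minimal sketch for ResNet-18 (the staging is an assumption, and the training calls between stages are omitted):

```python
# Sketch: progressive layer-wise unfreezing (assumes the imports from earlier)
def set_trainable(module, flag):
    for p in module.parameters():
        p.requires_grad = flag

model = models.resnet18(weights='IMAGENET1K_V1')
model.fc = nn.Linear(model.fc.in_features, 2)

set_trainable(model, False)          # freeze everything...
set_trainable(model.fc, True)        # ...except the new head
# Stage 1: train only the head for a few epochs, then:
set_trainable(model.layer4, True)    # unfreeze the last residual block
# Stage 2: train head + layer4, then finally:
set_trainable(model, True)           # unfreeze the whole network
```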
Online Learning and Continual Adaptation¶
Online learning refers to updating models incrementally as new data arrives, rather than retraining from scratch. This is particularly relevant for transfer learning in production systems where:
- User behavior changes over time
- New classes or categories emerge
- The data distribution shifts gradually (concept drift)
In online learning scenarios, we can periodically fine-tune our transferred model on recent data:
# Initial transfer learning
model = train_model(model, criterion, optimizer, scheduler, num_epochs=10)
# Deploy model to production
# ... time passes, collect new labeled examples ...
# Online update: fine-tune on recent data
new_dataloaders = create_dataloaders(recent_data)
model = train_model(model, criterion, optimizer, scheduler, num_epochs=2)
Key considerations for online learning:
- Catastrophic forgetting: Fine-tuning on new data may cause the model to forget previous knowledge. Mitigation strategies include mixing old and new data or using regularization techniques like Elastic Weight Consolidation (EWC)
- Learning rate scheduling: Use smaller learning rates for online updates to preserve existing knowledge
- Monitoring performance: Track performance on both old and new data to detect forgetting
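One simple mitigation for catastrophic forgetting is to mix a replay sample of old data into each online update. A sketch, where old_dataset, new_dataset, and the replay size are hypothetical:

```python
from torch.utils.data import ConcatDataset, DataLoader, Subset

# Replay a random sample of 500 old examples alongside the new data
replay_idx = torch.randperm(len(old_dataset))[:500].tolist()
replay = Subset(old_dataset, replay_idx)
mixed_loader = DataLoader(ConcatDataset([replay, new_dataset]),
                          batch_size=4, shuffle=True)
```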
Online Learning Workflow¶
flowchart TD
Start(("Pre-trained Model<br/>(ImageNet ResNet-50)"))
Transfer["Initial Transfer Learning<br/>(fine-tune on 10K images)"]
Deploy["Deploy Model to Production<br/>(serve predictions)"]
Collect["Collect User Feedback<br/>(gather labeled examples)"]
CheckData{"Sufficient<br/>New Data?<br/>(500+ samples)"}
FineTune["Online Fine-Tuning<br/>(2-3 epochs, LR=0.0001)"]
Validate["Validate Performance<br/>(test on old + new data)"]
CheckPerf{"Performance<br/>Acceptable?"}
Investigate["Investigate Issues<br/>(check for forgetting)"]
UpdateDeploy["Deploy Updated Model<br/>(replace in production)"]
Start --> Transfer
Transfer --> Deploy
Deploy --> Collect
Collect --> CheckData
CheckData -->|No| Collect
CheckData -->|Yes| FineTune
FineTune --> Validate
Validate --> CheckPerf
CheckPerf -->|No| Investigate
Investigate --> FineTune
CheckPerf -->|Yes| UpdateDeploy
UpdateDeploy --> Collect
classDef setupNode fill:#4299e1,stroke:#2c5282,stroke-width:2px,color:#fff,font-size:14px
classDef collectNode fill:#48bb78,stroke:#2f855a,stroke-width:2px,color:#fff,font-size:14px
classDef updateNode fill:#9f7aea,stroke:#6b46c1,stroke-width:2px,color:#fff,font-size:14px
classDef decisionNode fill:#ecc94b,stroke:#b7791f,stroke-width:2px,color:#333,font-size:14px
classDef issueNode fill:#ed8936,stroke:#c05621,stroke-width:2px,color:#fff,font-size:14px
class Start,Transfer,Deploy setupNode
class Collect,Validate collectNode
class FineTune,UpdateDeploy updateNode
class CheckData,CheckPerf decisionNode
class Investigate issueNode
linkStyle default stroke:#666,stroke-width:2px,font-size:12px

Continuous Improvement Loop: This workflow shows how deployed models can be continually updated with new production data, maintaining performance as data distributions shift over time.
Practical Tips for Transfer Learning Success¶
Based on extensive experimentation, here are proven strategies for successful transfer learning:
1. Start with feature extraction, then fine-tune if needed
Feature extraction is faster and less prone to overfitting. If accuracy is insufficient, try fine-tuning with a small learning rate.
2. Use appropriate learning rates
- Feature extraction (training only final layer): 0.001 to 0.01
- Fine-tuning (training all layers): 0.0001 to 0.001
3. Apply aggressive data augmentation
With small datasets, augmentation is crucial. Use flips, crops, rotations, color jittering, and other transformations to increase effective dataset size.
4. Monitor both training and validation metrics
Watch for overfitting by comparing training and validation loss. Early stopping based on validation performance prevents overfitting.
5. Experiment with different pre-trained models
Different architectures may perform better for different tasks. Try both shallower (ResNet-18) and deeper (ResNet-50) models.
6. Consider the domain gap
If your task is very different from ImageNet (e.g., medical imaging, satellite imagery), you may need domain-specific augmentation or even domain-specific pre-trained models.
7. Use learning rate schedules
Reducing learning rate during training (e.g., multiply by 0.1 every 7 epochs) helps convergence.
8. Leverage model ensembles
Training multiple models with different random seeds and averaging their predictions often improves performance.
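A sketch of prediction averaging for such an ensemble (assumes models_list holds several independently trained classifiers):

```python
import torch.nn.functional as F

def ensemble_predict(models_list, inputs):
    # Average softmax probabilities across ensemble members, then take argmax
    for m in models_list:
        m.eval()
    with torch.no_grad():
        probs = torch.stack([F.softmax(m(inputs), dim=1) for m in models_list])
    return probs.mean(dim=0).argmax(dim=1)
```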
Transfer Learning Beyond Image Classification¶
While this chapter focuses on image classification with CNNs, transfer learning applies broadly across machine learning:
Object detection: Models like Faster R-CNN use pre-trained CNN backbones
Semantic segmentation: FCN and U-Net architectures build on pre-trained encoders
Natural language processing: BERT, GPT, and other language models use transfer learning extensively
Speech recognition: Models pre-trained on large speech corpora transfer to specific accents or languages
Reinforcement learning: Policies trained in simulation transfer to real-world robots
The fundamental principle remains the same: leverage knowledge from large-scale pre-training to achieve better performance with less data on downstream tasks.
Transfer Learning Strategy Comparison¶
| Strategy | Frozen Layers | Trainable Params | Training Time | Typical Accuracy | Best For |
|---|---|---|---|---|---|
| Feature Extraction | All except final layer | ~2,000 | Fast (minutes) | 80-85% | Very limited data (<100 samples/class) |
| Partial Fine-Tuning | Early layers only | ~2-5M | Medium (hours) | 88-93% | Moderate data (100-1000 samples/class) |
| Full Fine-Tuning | None | ~11-25M | Slow (hours-days) | 92-97% | Substantial data (1000+ samples/class) |
Key Insights:
- Feature Extraction: Fastest approach; leverages pre-trained features as fixed representations. Best when the target task is similar to the source task (e.g., both are image classification).
- Partial Fine-Tuning: Balances adaptation with preservation of learned features. Early layers learn general patterns (edges, textures) that transfer well; late layers adapt to task-specific patterns.
- Full Fine-Tuning: Maximum flexibility and potential performance, but requires more data and compute. Risk of overfitting with limited data.

Choosing a Strategy:
1. Start with feature extraction if you have <100 examples per class
2. Use partial fine-tuning for 100-1000 examples per class
3. Try full fine-tuning only if you have 1000+ examples per class and sufficient compute
4. Always use data augmentation and monitor validation curves for overfitting
Summary and Key Takeaways¶
Transfer learning enables practitioners to leverage knowledge from large-scale pre-trained models to achieve excellent performance on new tasks with limited data. The key insights from this chapter:
Core concepts:
- Pre-trained models from model zoos (ImageNet, etc.) have learned general visual features applicable to many tasks
- Feature extraction treats the pre-trained model as a fixed feature extractor, training only a new classification head
- Fine-tuning adapts all layers of the pre-trained model to the new task, typically achieving higher accuracy
- Domain adaptation addresses distribution shift between source and target domains
Best practices:
- Always preprocess inputs to match the pre-trained model's expected distribution
- Start with feature extraction; upgrade to fine-tuning if accuracy is insufficient
- Use small learning rates (0.0001-0.001) for fine-tuning to preserve learned features
- Apply data augmentation aggressively when working with small datasets
- Monitor validation error to detect overfitting and determine when to stop training
- Use SGD with momentum (0.9) for stable convergence
Practical impact:
Transfer learning has democratized deep learning by making it accessible to practitioners without access to massive datasets or computational resources. Tasks that once required millions of labeled examples can now be solved with hundreds or thousands, enabling applications across medicine, agriculture, manufacturing, and countless other domains.
Further Reading¶
For deeper exploration of transfer learning concepts:
- PyTorch Transfer Learning Tutorial - Official PyTorch documentation and examples
- Yosinski et al. (2014) - "How transferable are features in deep neural networks?" - Seminal paper analyzing feature transferability across layers
- Deep Learning Book, Chapter 15.2 - Goodfellow, Bengio, Courville on transfer learning theory
- Kornblith et al. (2019) - "Do Better ImageNet Models Transfer Better?" - Analysis of which architectures transfer best
- Raghu et al. (2019) - "Transfusion: Understanding Transfer Learning with Applications to Medical Imaging" - Domain adaptation for medical tasks
Exercises¶
Exercise 1: Feature Extraction vs. Fine-Tuning
Using the ants vs. bees dataset, implement both feature extraction and fine-tuning approaches. Train each for 10 epochs and compare:
- Final validation accuracy
- Training time
- Number of trainable parameters
Which approach would you choose for this task and why?
Exercise 2: Learning Rate Sensitivity
Train a transfer learning model with different learning rates: [0.0001, 0.001, 0.01, 0.1]. Plot the training and validation loss curves for each. What happens when the learning rate is too large? Too small?
Exercise 3: Data Augmentation Impact
Train two models on a small subset of the data (50 images per class):
1. With data augmentation (random crops, flips, color jitter)
2. Without data augmentation
Compare the validation accuracy and the gap between training and validation accuracy. What does this tell you about overfitting?
Exercise 4: Custom Dataset Transfer
Adapt the transfer learning code to a new dataset of your choice (e.g., Kaggle's Dogs vs. Cats, Sign Language MNIST, or Food-101). Experiment with different pre-trained models (ResNet-18, ResNet-50, MobileNet) and report which achieves the best performance.
Exercise 5: Layer Freezing Strategy
Implement a progressive unfreezing strategy:
1. Train only the final layer for 3 epochs
2. Unfreeze the last convolutional block and train for 3 more epochs
3. Unfreeze all layers and train for 3 final epochs
Compare this approach to full fine-tuning from the start. Does progressive unfreezing improve final accuracy?
Exercise 6: Domain Adaptation
Create a domain shift by applying strong transformations (e.g., converting images to grayscale, adding noise, extreme color shifts) to the validation set while keeping the training set normal. How does this affect validation accuracy? What strategies might help improve performance under domain shift?