Practical 8
Layer-wise Model Parallelism in PyTorch
Objective
In this practical session, you will:
- Implement layer-wise model parallelism
- Distribute a neural network across multiple GPUs
- Understand forward data movement between devices
- Observe GPU utilization imbalance
- Measure performance limitations
Background
In Lecture 8, we introduced the following:
👉 Model Parallelism = Split the model across GPUs, not the data
- GPU 0 → first layers
- GPU 1 → later layers
- Data moves between GPUs
- Execution is sequential
- Not all GPUs work at the same time
In this session, you will observe these properties of model parallelism experimentally.
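This practical assumes a machine with at least two visible GPUs. Before starting, it is worth verifying that PyTorch can actually see them; a minimal check:

import torch

# Fail early if fewer than two CUDA devices are visible
assert torch.cuda.device_count() >= 2, "This practical needs two GPUs"
print(torch.cuda.get_device_name(0))
print(torch.cuda.get_device_name(1))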
Part 1—Simple Implementation (Core Learning)
import torch
import torch.nn as nn

# Define devices
device0 = torch.device("cuda:0")
device1 = torch.device("cuda:1")

class SimpleMPModel(nn.Module):
    def __init__(self):
        super().__init__()
        # First part → GPU 0
        self.layer1 = nn.Linear(1000, 2000).to(device0)
        self.relu1 = nn.ReLU().to(device0)
        # Second part → GPU 1
        self.layer2 = nn.Linear(2000, 1000).to(device1)
        self.relu2 = nn.ReLU().to(device1)

    def forward(self, x):
        # Move input to GPU 0
        x = x.to(device0)
        # Forward on GPU 0
        x = self.layer1(x)
        x = self.relu1(x)
        # Move to GPU 1 ← KEY STEP
        x = x.to(device1)
        # Forward on GPU 1
        x = self.layer2(x)
        x = self.relu2(x)
        return x

# Test
model = SimpleMPModel()
x = torch.randn(32, 1000)
output = model(x)
print(output.device)
Run and check:
- Where does computation start?
- When does .to(device1) happen?
- Which GPU holds the final output?
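To answer the first question, you can inspect where each parameter lives; a quick sketch using named_parameters:

for name, param in model.named_parameters():
    print(name, "→", param.device)
# Expected: layer1 parameters on cuda:0, layer2 parameters on cuda:1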
Part 2—Add Training
import torch.optim as optim

model = SimpleMPModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()

for step in range(5):
    x = torch.randn(32, 1000)
    y = torch.randn(32, 1000).to(device1)  # target on last GPU
    output = model(x)
    loss = loss_fn(output, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    print(f"Step {step}, Loss: {loss.item():.4f}")
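Note that a single optimizer works here: model.parameters() yields parameters from both devices, and during backward, autograd copies gradients back across the same .to(device1) boundary. You can confirm that the model's memory is actually split by checking each GPU's allocation (a sketch; the numbers will vary by hardware):

for d in (device0, device1):
    mb = torch.cuda.memory_allocated(d) / 1024**2
    print(f"{d}: {mb:.1f} MB allocated")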
Part 3—Track Data Movement
Modify the forward method to print the tensor's device before and after the transfer:
print("Before move:", x.device)
x = x.to(device1)
print("After move:", x.device)
👉 This shows the GPU-to-GPU transfer
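For reference, the modified forward inside SimpleMPModel might look like this (a sketch; the prints are for this exercise only and should be removed afterwards):

def forward(self, x):
    x = x.to(device0)
    print("Input on:", x.device)       # cuda:0
    x = self.relu1(self.layer1(x))
    print("Before move:", x.device)    # cuda:0
    x = x.to(device1)                  # GPU-to-GPU transfer
    print("After move:", x.device)     # cuda:1
    return self.relu2(self.layer2(x))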
Part 4—Measure Time
import time

x = torch.randn(32, 1000)

# Warm-up: the first call pays one-off CUDA initialization costs
output = model(x)
torch.cuda.synchronize(device1)

start = time.time()
output = model(x)
torch.cuda.synchronize(device1)  # wait for the final stage on GPU 1
end = time.time()
print("Forward time:", end - start)
Key Observations:
- GPU 0 works → then GPU 1 works
- No overlap (not parallel in time)
- .to(device) introduces communication cost
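To see the sequential execution directly, you can time each stage separately by synchronizing between them (a sketch; synchronization adds a little overhead, so treat the numbers as indicative):

x = torch.randn(32, 1000, device=device0)

torch.cuda.synchronize(device0)
t0 = time.time()
h = model.relu1(model.layer1(x))      # stage 1 on GPU 0
torch.cuda.synchronize(device0)
t1 = time.time()

h = h.to(device1)                     # GPU-to-GPU transfer
torch.cuda.synchronize(device1)
t2 = time.time()

out = model.relu2(model.layer2(h))    # stage 2 on GPU 1
torch.cuda.synchronize(device1)
t3 = time.time()

print(f"GPU 0 compute: {t1 - t0:.6f}s")
print(f"Transfer:      {t2 - t1:.6f}s")
print(f"GPU 1 compute: {t3 - t2:.6f}s")

While stage 2 runs, GPU 0 has nothing to do, and vice versa.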
Key Insight (Very Important)
👉 Layer-wise model parallelism:
- ✔ Solves the memory problem (each GPU holds only part of the model)
- ✖ Does NOT solve the idle-GPU problem (while one GPU computes, the other waits)
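To make the second point concrete, compare against the same network placed entirely on one GPU (a sketch, assuming both layers fit on cuda:0). If the model fits on one device, the single-GPU version is typically at least as fast, because model parallelism here adds a transfer without adding any overlap:

# Same architecture, entirely on GPU 0
single = nn.Sequential(
    nn.Linear(1000, 2000), nn.ReLU(),
    nn.Linear(2000, 1000), nn.ReLU(),
).to(device0)

x = torch.randn(32, 1000, device=device0)
single(x)                          # warm-up
torch.cuda.synchronize(device0)

start = time.time()
out = single(x)
torch.cuda.synchronize(device0)
print("Single-GPU forward time:", time.time() - start)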