Python, from the ground up Lesson 56 / 60

PyTorch: the modern default

Tensors, autograd, the nn module, and the Python-feel that made PyTorch win.

In 2016 Facebook (now Meta) released PyTorch and almost nobody was using it. The default deep learning framework was TensorFlow, which had Google’s marketing budget behind it and a feature called “static graphs” that the marketing said was a benefit. By 2019 the academic ML world had defected to PyTorch en masse. By 2022 industry had followed. By 2026 the situation is settled: PyTorch is the default. New research papers ship PyTorch reference implementations. Hugging Face’s entire ecosystem is PyTorch-first. The major cloud providers all have first-class PyTorch support. JAX has a niche among Google’s internal teams and among researchers who care obsessively about compiler-level performance. TensorFlow lives on in Google products and a long tail of legacy projects, but if you are starting a new deep learning project today and you don’t have a specific reason to choose otherwise, you choose PyTorch.

This lesson is about the PyTorch you’ll actually use day-to-day. Tensors. Autograd. The nn.Module API. Optimizers. The DataLoader. Then we’ll define a small classification network and dissect every piece.

Why PyTorch won

The TensorFlow 1.x experience, in 2017, was: define a static computational graph, then run it inside a session and feed data into it. Debugging meant inserting print ops into the graph as side effects just to see a tensor’s value or shape. Control flow meant special TensorFlow ops like tf.cond. The whole thing felt like writing in a different language that happened to share Python’s syntax.

PyTorch’s bet was: just be a Python library. Tensors are objects you can print, slice, and pass around. The graph builds itself as you call operations, and gets thrown away when you’re done. If you want a for loop in your network, you write a for loop. If you want to print intermediate values, you call print. The framework felt like NumPy with extra features. Researchers immediately preferred it. TensorFlow eventually caught up with eager execution in 2.x, but by then the war was over.

The lesson here is not narrowly about deep learning frameworks. It’s about library design: make the easy thing easy, even at some performance cost, and the world will choose you. JAX, Google’s framework that arrived after PyTorch, took a different bet: functional purity and JIT compilation, which give it speed advantages on certain workloads. That bet has earned JAX a real but limited niche.

Tensors: NumPy plus three things

A PyTorch tensor is essentially the NumPy ndarray you’ve been using since lesson 43, with three additions: GPU support, automatic differentiation, and a slightly different API. You can move a tensor to a GPU and operations on it will run on the GPU. You can ask any floating-point tensor to track gradients, and PyTorch will record every operation you do with it.

import torch

# Creation, mostly mirroring NumPy
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.zeros((3, 4))
c = torch.ones((2, 2))
d = torch.randn((100, 10))    # standard normal
e = torch.arange(10)
f = torch.eye(5)              # identity matrix

# Element-wise math, broadcasting, indexing — all NumPy-shaped
print(a + 1)
print(d.shape, d.dtype)       # torch.Size([100, 10]) torch.float32
print(d[0, :3])
print(d.mean(dim=0))          # mean along axis 0 → shape (10,)

The dtype defaults are different from NumPy. PyTorch defaults to float32 for floats (NumPy defaults to float64). For deep learning, float32 is the right default: float64 takes twice the memory, runs at least twice as slow (far slower on most consumer GPUs), and almost never gives you a meaningfully better model. Stay with the defaults.
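
One place the defaults quietly stop applying is when data arrives from NumPy. The sketch below is a small illustration, not part of the lesson's running example, and the variable names are my own. It shows that torch.from_numpy keeps the array's float64 dtype, so an explicit cast is needed to get back to float32.

```python
import numpy as np
import torch

np_arr = np.array([1.0, 2.0, 3.0])         # NumPy default for floats: float64
t_native = torch.tensor([1.0, 2.0, 3.0])   # PyTorch default for floats: float32

# from_numpy keeps the NumPy dtype, so float64 sneaks in here
t_from_np = torch.from_numpy(np_arr)

# Cast explicitly to return to the deep learning default
t_cast = t_from_np.float()

print(np_arr.dtype, t_native.dtype, t_from_np.dtype, t_cast.dtype)
```

If a model later complains about a Double/Float mismatch, an uncast NumPy array is a likely suspect.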

Devices: CPU, CUDA, MPS

Tensors live on a device. By default, CPU. To use a GPU, move the tensor:

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

x = torch.randn(1000, 1000).to(device)
y = torch.randn(1000, 1000).to(device)
z = x @ y                     # matmul on the device

The boilerplate above is the right pattern in 2026. cuda is NVIDIA GPUs (the production case). mps is Apple Silicon’s Metal Performance Shaders — works on M-series Macs, fast enough for development and small models, not yet competitive with NVIDIA for serious training. The cpu fallback exists for environments without a GPU.

Operations require all tensors to be on the same device. A tensor on CPU plus a tensor on GPU is an error. The most common debugging session of any new PyTorch user is forgetting to call .to(device) on either the model or the data.
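
Here is a minimal sketch of the pattern that avoids that debugging session. The tiny nn.Linear model and the batch shape are placeholders I made up for illustration. The detail worth noticing is that .to(device) on a module moves its parameters in place, while .to(device) on a tensor returns a new tensor that you must reassign.

```python
import torch
import torch.nn as nn

device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)

model = nn.Linear(4, 2).to(device)   # module: parameters are moved in place
batch = torch.randn(8, 4)            # tensors are created on the CPU by default
batch = batch.to(device)             # tensor: .to() returns a copy, so reassign

# Compare device *types*; "cuda" and "cuda:0" are not equal as torch.device objects
param_device = next(model.parameters()).device
devices_match = param_device.type == batch.device.type

out = model(batch)                   # safe: model and data share a device
print(devices_match, out.shape, out.device)
```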

Autograd: the gradient machine

Any tensor with requires_grad=True tracks every operation it participates in. When you call .backward() on a scalar that depends on it, PyTorch walks the recorded operations in reverse and computes gradients of the scalar with respect to every tensor that had requires_grad=True.

x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x + 1        # y = (x+1)^2
y.backward()                  # compute dy/dx
print(x.grad)                 # 2*(x+1) = 8.0

That’s the whole machinery of training. You don’t write derivatives. You write the forward pass, you compute a loss (a scalar), you call loss.backward(), and PyTorch fills in the .grad attribute of every parameter. The optimizer then uses those gradients to update the parameters.
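
To make "you don't write derivatives" concrete, the sketch below fits a one-weight linear model to three points and compares the gradients autograd produces against the ones you would derive by hand for a mean squared error. The data values and variable names are assumptions chosen so the arithmetic is easy to follow.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([2.0, 4.0, 6.0])

w = torch.tensor(0.5, requires_grad=True)
b = torch.tensor(0.0, requires_grad=True)

# Forward pass and scalar loss: mean squared error
pred = w * x + b
loss = ((pred - y) ** 2).mean()
loss.backward()                      # fills in w.grad and b.grad

# The same gradients, derived by hand:
#   d(loss)/dw = mean(2 * (w*x + b - y) * x)
#   d(loss)/db = mean(2 * (w*x + b - y))
with torch.no_grad():
    residual = w * x + b - y
    grad_w_manual = (2 * residual * x).mean()
    grad_b_manual = (2 * residual).mean()

print(w.grad, grad_w_manual)         # both -14.0
print(b.grad, grad_b_manual)         # both -6.0
```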

A subtle but important detail: gradients accumulate. If you call .backward() twice without zeroing the gradients in between, you get the sum of the two gradients in .grad. That’s why every training loop calls optimizer.zero_grad() at the top. It’s a footgun for beginners and a feature for advanced use (gradient accumulation across micro-batches).
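
The accumulation behaviour is easy to see on the single-variable example from above. This sketch resets the gradient by hand with x.grad = None, which stands in for what optimizer.zero_grad() does for every parameter the optimizer manages.

```python
import torch

x = torch.tensor(3.0, requires_grad=True)

(x ** 2 + 2 * x + 1).backward()
first = x.grad.clone()               # 8.0

# Second backward pass without resetting: the new gradient is added on top
(x ** 2 + 2 * x + 1).backward()
accumulated = x.grad.clone()         # 16.0

# Reset, then run backward again: back to a single gradient
x.grad = None
(x ** 2 + 2 * x + 1).backward()
fresh = x.grad.clone()               # 8.0

print(first, accumulated, fresh)
```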

The nn.Module: where you define networks

You could define a network as a free function and a bag of tensors. You shouldn’t. The nn.Module base class gives you parameter management, device movement, and serialization for free. Two methods: __init__ declares the layers, forward defines the computation.

import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, n_classes: int):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, n_classes)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = F.relu(self.fc2(x))
        x = self.dropout(x)
        return self.fc3(x)         # logits, not probabilities

model = MLP(in_dim=20, hidden_dim=128, n_classes=3).to(device)
print(model)
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

A few things worth flagging. nn.Linear(in, out) computes y = x @ W.T + b with W of shape (out, in), your standard fully-connected layer. The activation functions live in torch.nn.functional (imported as F by convention), or as modules such as nn.ReLU() if you prefer. nn.Dropout is a regularizer that randomly zeroes activations during training and does nothing during evaluation. The forward method returns logits: raw scores, not probabilities. nn.CrossEntropyLoss expects logits and applies log-softmax internally; applying softmax yourself in forward and then handing the result to that loss is a classic bug.
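
Two of those claims are cheap to check directly. The sketch below uses small, arbitrary layer sizes of my own choosing. It compares nn.Linear against the explicit matrix formula, then runs nn.Dropout in training and evaluation mode. One detail beyond the prose above: in training mode the surviving activations are scaled by 1 / (1 - p), so the expected value of each activation is unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# nn.Linear versus the explicit formula
layer = nn.Linear(4, 3)
x = torch.randn(5, 4)
manual = x @ layer.weight.T + layer.bias
linear_matches = torch.allclose(layer(x), manual, atol=1e-6)
print(layer.weight.shape, linear_matches)   # weight is (out, in) = (3, 4)

# nn.Dropout: active in train mode, identity in eval mode
drop = nn.Dropout(p=0.2)
ones = torch.ones(1000)

drop.train()
train_out = drop(ones)   # roughly 20% zeros; survivors scaled to 1 / 0.8 = 1.25

drop.eval()
eval_out = drop(ones)    # unchanged

print((train_out == 0).float().mean(), eval_out.sum())
```

Calling model.train() or model.eval() on a whole network flips this mode for every dropout layer inside it.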

For simple stacks where you don’t need any custom logic, nn.Sequential is shorter:

model = nn.Sequential(
    nn.Linear(20, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 3),
).to(device)

I use nn.Module for anything I expect to grow. nn.Sequential for throwaway prototypes.
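
As a sanity check on either version, you can work out the parameter count by hand: each nn.Linear(n_in, n_out) holds n_in * n_out weights plus n_out biases. The helper linear_params below is a hypothetical name of my own, not a PyTorch function.

```python
import torch.nn as nn

def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of one fully-connected layer (hypothetical helper)."""
    return n_in * n_out + n_out

# 20 -> 128 -> 128 -> 3, the same shape as the networks above
expected = (
    linear_params(20, 128)     # 2,688
    + linear_params(128, 128)  # 16,512
    + linear_params(128, 3)    # 387
)

model = nn.Sequential(
    nn.Linear(20, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 128),
    nn.ReLU(),
    nn.Dropout(0.2),
    nn.Linear(128, 3),
)
actual = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(expected, actual)        # 19587 19587
```

ReLU and Dropout contribute nothing to the total because they have no learnable parameters.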

Optimizers and losses

The optimizer holds the model’s parameters and knows how to update them given gradients. The standard choices in 2026:

  • torch.optim.SGD with momentum: classic, still used for computer vision.
  • torch.optim.Adam: the default for most things. Robust, fast convergence, forgiving of bad learning rates.
  • torch.optim.AdamW: Adam with proper weight decay. The standard for transformers and most modern training.

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
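
To see what "update them given gradients" means, here is one optimizer step on a single parameter. I am using plain SGD without momentum here, rather than AdamW, purely because its update rule (parameter minus learning rate times gradient) can be checked by hand. The zero_grad / backward / step sequence is the same for every optimizer.

```python
import torch

# One learnable scalar, starting at 1.0
p = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([p], lr=0.1)

opt.zero_grad()        # clear any stale gradient
loss = p ** 2          # toy loss with its minimum at p = 0
loss.backward()        # p.grad = 2 * p = 2.0
opt.step()             # p <- 1.0 - 0.1 * 2.0 = 0.8

print(p.item())
```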

The loss function is a callable that takes (predictions, targets) and returns a scalar. The three you’ll use most of the time:

  • nn.CrossEntropyLoss for multi-class classification. Takes logits of shape (batch, n_classes) and integer class labels of shape (batch,).
  • nn.MSELoss for regression.
  • nn.BCEWithLogitsLoss for binary classification (binary cross-entropy that takes logits, numerically stable).

criterion = nn.CrossEntropyLoss()
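
The sketch below shows what nn.CrossEntropyLoss does with logits, and what goes wrong if you feed it probabilities. The logits and targets are made-up values in which every row strongly favours the correct class. The loss should match log-softmax followed by negative log-likelihood, and the double-softmax version should come out noticeably larger, because squashing the scores into [0, 1] throws away the model's confidence.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Four examples, three classes; each row's largest logit is the correct class
logits = torch.tensor([
    [2.0, 0.5, -1.0],
    [0.1, 0.2, 3.0],
    [-1.0, 2.5, 0.3],
    [0.0, 0.0, 4.0],
])
targets = torch.tensor([0, 2, 1, 2])   # integer class labels, shape (batch,)

criterion = nn.CrossEntropyLoss()

loss = criterion(logits, targets)

# Equivalent two-step computation: log-softmax, then negative log-likelihood
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# The classic bug: softmax applied before a loss that applies it again
double = criterion(F.softmax(logits, dim=1), targets)

print(loss.item(), manual.item(), double.item())
```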

DataLoaders: getting batches of data in

Deep learning trains on batches. You need an object that hands you (inputs, targets) of a fixed batch size, optionally shuffled, optionally on multiple worker processes. PyTorch’s Dataset and DataLoader give you that.

from torch.utils.data import TensorDataset, DataLoader

# Imagine we already have these as tensors
X_train = torch.randn(10000, 20)
y_train = torch.randint(0, 3, (10000,))

train_ds = TensorDataset(X_train, y_train)
train_loader = DataLoader(
    train_ds,
    batch_size=64,
    shuffle=True,
    num_workers=2,    # parallel data loading processes
    pin_memory=True,  # faster CPU→GPU transfer
)

for inputs, targets in train_loader:
    inputs, targets = inputs.to(device), targets.to(device)
    # ... training step ...
    break

For real datasets — images on disk, text from a file, anything that needs preprocessing — you subclass Dataset and define __len__ and __getitem__. The DataLoader wraps it. We’ll use TensorDataset for the toy examples in this module; lesson 57’s training loop assumes you’ve got a working DataLoader.
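
A minimal sketch of that subclassing pattern is below. SquaresDataset is a hypothetical toy class of my own that computes each item on demand, standing in for whatever loading and preprocessing a real dataset would do inside __getitem__.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class SquaresDataset(Dataset):
    """Hypothetical toy dataset: input n, target n squared, built on demand."""

    def __init__(self, n_items: int):
        self.n_items = n_items

    def __len__(self) -> int:
        # The DataLoader uses this to know how many indices exist
        return self.n_items

    def __getitem__(self, idx: int):
        # A real dataset would read a file and preprocess it here
        x = torch.tensor([float(idx)])
        y = torch.tensor([float(idx) ** 2])
        return x, y

ds = SquaresDataset(10)
loader = DataLoader(ds, batch_size=4, shuffle=False)  # num_workers defaults to 0

# The DataLoader stacks individual items into batches
batch_shapes = [(xb.shape, yb.shape) for xb, yb in loader]
print(batch_shapes)   # 10 items at batch_size 4: batches of 4, 4, and 2
```

The last batch is smaller than the rest; pass drop_last=True to the DataLoader if your training step needs a fixed batch size.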

torch.compile: the 2.x speedup

PyTorch 2.0, released in 2023, introduced torch.compile. You wrap your model and PyTorch JIT-compiles the computation graph for your hardware:

model = torch.compile(model)   # one line

Speedups depend heavily on the model and the hardware. The PyTorch 2.0 release benchmarks reported training running roughly 40% faster on average on NVIDIA A100 GPUs, and some workloads see 2x or more. There are caveats: the first iteration is slow because of compilation, and dynamic shapes are still rough. But for stable production training runs, torch.compile is close to free performance. By 2026 it’s the default in most serious projects.

When JAX makes sense instead

Don’t choose JAX as your first deep learning framework. PyTorch is the right default. But you’ll hear about JAX, so the short version:

JAX is functional. Models are pure functions; parameters are explicit pytrees passed in and out. JAX compiles aggressively via XLA, which gives it serious speed advantages on TPUs and on huge batched workloads with static shapes. Google uses JAX heavily internally. Some research groups prefer it. The tradeoff is that the functional style and the compilation-first model are less ergonomic for messy research code with dynamic control flow. Stick with PyTorch unless you have a specific reason and the team to support the choice.

What we have so far

We have tensors. We have a model. We have an optimizer and a loss function. We have a DataLoader. We have everything except the loop that ties them together. Lesson 57 is that loop — the five lines that constitute the actual training step, the bookkeeping around them that makes a real training system, and the framework alternatives (Lightning, Hugging Face Trainer) that let you skip the boilerplate when you don’t need it.
