PyTorch vs TensorFlow: Key Differences
- Execution Model: PyTorch uses eager execution by default (define-by-run). TensorFlow 2.x also defaults to eager but supports @tf.function for graph compilation (see the sketch after this list).
- API Style: PyTorch is more Pythonic and intuitive. TensorFlow has Keras as its high-level API, but lower-level ops can be verbose.
- Debugging: PyTorch integrates naturally with Python debuggers (pdb). TensorFlow graphs are harder to debug.
- Deployment: TensorFlow has better production tooling (TF Serving, TFLite, TF.js). PyTorch is catching up with TorchServe and ONNX export.
- Research vs Production: PyTorch dominates research papers. TensorFlow is stronger in enterprise production.
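The execution-model difference is easiest to see side by side. A minimal sketch, assuming both libraries are installed; the toy double function is only for illustration:

```python
import torch
import tensorflow as tf

# PyTorch: define-by-run -- each operation executes immediately
x = torch.tensor([1.0, 2.0, 3.0])
print(x * 2)                                  # tensor([2., 4., 6.]), computed right away

# TensorFlow 2.x: eager by default, but @tf.function traces and compiles a graph
@tf.function
def double(t):
    return t * 2

print(double(tf.constant([1.0, 2.0, 3.0])))   # runs as a compiled graph after tracing
```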
When to Choose
- PyTorch: Research, prototyping, NLP (Hugging Face), dynamic architectures, when team prefers Pythonic code
- TensorFlow: Mobile/edge deployment, existing TF infrastructure, TPU training, production-first projects
model.eval() and torch.no_grad() in PyTorch
model.eval()
Sets the model to evaluation mode. This affects layers that behave differently during training vs inference:
- Dropout: Disabled (no random zeroing)
- BatchNorm: Uses running mean/variance instead of batch statistics
- Does NOT disable gradient computation (see the sketch below)
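A minimal sketch of that last point; the tiny Linear + Dropout model is only illustrative:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
model.eval()                      # dropout disabled, but autograd is still active

out = model(torch.randn(2, 4))
print(out.requires_grad)          # True: eval() alone does not stop gradient tracking
print(out.grad_fn)                # an autograd node is still recorded
```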
torch.no_grad()
Context manager that disables gradient computation:
- Saves memory: No need to store intermediate activations for backward pass
- Faster inference: Skip gradient tape operations
- Does NOT affect layer behavior
```python
# Correct inference pattern - use BOTH
model.eval()                  # Change layer behavior
with torch.no_grad():         # Disable gradients
    outputs = model(inputs)

# Don't forget to switch back for training
model.train()
```
What is JAX?
JAX is Google's library for high-performance numerical computing. It's NumPy + automatic differentiation + XLA compilation + vectorization.
Key Features
- Functional paradigm: Pure functions, no hidden state, explicit random keys
- Transformations: grad (autodiff), jit (compilation), vmap (auto-batching), pmap (parallelization)
- XLA compilation: Optimized kernels for GPU/TPU
- Composable: Transformations can be combined freely
```python
import jax.numpy as jnp
from jax import grad, jit, vmap

# Define a loss function
def loss_fn(params, x, y):
    pred = jnp.dot(x, params)
    return jnp.mean((pred - y) ** 2)

# Single-example prediction
def predict_fn(params, x):
    return jnp.dot(x, params)

# Get gradient function (automatic!)
grad_fn = grad(loss_fn)

# JIT compile for speed
fast_grad = jit(grad_fn)

# Vectorize over the batch dimension (params shared, x batched)
batched_pred = vmap(predict_fn, in_axes=(None, 0))
```
When to Use JAX
- Research requiring custom autodiff: Higher-order gradients, Hessians (see the sketch after this list)
- TPU-first projects: JAX has excellent TPU support
- Scientific computing: Physics simulations, differential equations
- When you need vmap: Auto-vectorization is powerful
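A minimal sketch of the higher-order autodiff point; the cubic f below is just a toy function:

```python
import jax.numpy as jnp
from jax import grad, hessian

def f(x):
    return jnp.sum(x ** 3)

x = jnp.array([1.0, 2.0])
print(grad(f)(x))        # first-order gradient 3 * x**2 -> [3., 12.]
print(hessian(f)(x))     # second-order derivatives, diag(6 * x) -> [[6., 0.], [0., 12.]]
```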
When NOT to Use JAX
- Need mature ecosystem (Hugging Face, torchvision)
- Team unfamiliar with functional programming
- Standard deep learning tasks where PyTorch/TF suffice
Running Out of GPU Memory: Immediate Fixes
- Reduce batch size: Most direct solution, but affects convergence
- Gradient accumulation: Simulate larger batches without memory cost
- Mixed precision (FP16/BF16): Halves memory, often faster too
```python
# Gradient Accumulation in PyTorch
accumulation_steps = 4
optimizer.zero_grad()

for i, (inputs, labels) in enumerate(dataloader):
    outputs = model(inputs)
    # Scale the loss so the accumulated gradients match one large batch
    loss = criterion(outputs, labels) / accumulation_steps
    loss.backward()

    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
Advanced Techniques
- Gradient checkpointing: Trade compute for memory by recomputing activations (sketched at the end of this section)
- Model parallelism: Split model across GPUs (for very large models)
- Offloading: Move optimizer states to CPU (DeepSpeed ZeRO)
- 8-bit optimizers: bitsandbytes library reduces optimizer memory
```python
# Mixed Precision Training with PyTorch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, labels in dataloader:
    optimizer.zero_grad()
    with autocast():                   # FP16 forward pass
        outputs = model(inputs)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()      # Scaled backward pass
    scaler.step(optimizer)
    scaler.update()
```
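Gradient checkpointing, mentioned in the advanced list above, can be sketched with torch.utils.checkpoint; the 512-unit toy block is illustrative, and use_reentrant=False assumes a reasonably recent PyTorch release:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy block whose intermediate activations we don't want to keep in memory
block = nn.Sequential(nn.Linear(512, 512), nn.ReLU(), nn.Linear(512, 512))
x = torch.randn(8, 512, requires_grad=True)

# Activations inside `block` are recomputed during the backward pass instead of stored
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()
```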
DataLoader Basics
DataLoader wraps a Dataset and provides batching, shuffling, and parallel data loading.
```python
from torch.utils.data import DataLoader, Dataset

dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,              # Shuffle for training
    num_workers=4,             # Parallel data loading
    pin_memory=True,           # Faster GPU transfer
    prefetch_factor=2,         # Batches to prefetch per worker
    persistent_workers=True,   # Keep workers alive between epochs
)
```
Optimization Strategies
- num_workers: Start with 4, increase until CPU-bound. Too many causes overhead.
- pin_memory=True: Allocates page-locked (pinned) host memory so CPU→GPU transfers are faster
- persistent_workers=True: Avoids worker restart overhead between epochs
- prefetch_factor: Load next batches while GPU is computing
Common Issues
- Slow first epoch: Workers initializing. Use persistent_workers.
- Memory leak: Caching large objects on the Dataset or in __getitem__(); process data lazily instead (see the sketch after this list).
- Bottleneck detection: If GPU utilization is low, data loading is likely the bottleneck.
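A minimal sketch of a lazily loading Dataset; the class name and the per-sample .pt files are hypothetical:

```python
import torch
from torch.utils.data import Dataset

class LazyTensorDataset(Dataset):
    """Keeps only lightweight metadata in memory; loads each sample on demand."""

    def __init__(self, file_paths, labels):
        self.file_paths = file_paths   # list of paths, not the tensors themselves
        self.labels = labels

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        # Load lazily inside the worker process, one sample at a time
        sample = torch.load(self.file_paths[idx])
        return sample, self.labels[idx]
```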