vmap
Episode Synopsis
What is vmap? How is it implemented? How does our implementation compare to JAX's? What is a good way of understanding what vmap does? What's up with random numbers? Why are there some issues with the vmap that PyTorch currently ships?

Further reading.
- Tracking issue for vmap support: https://github.com/pytorch/pytorch/issues/42368
- BatchedTensor source code: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/BatchedTensorImpl.h and the logical-physical transformation helper code: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/VmapTransforms.h (well documented, worth a read)
- functorch, the better, more JAX-y implementation of vmap: https://github.com/facebookresearch/functorch
- Autodidax: https://jax.readthedocs.io/en/latest/autodidax.html which contains a super simple vmap implementation that is a good model for the internal implementation that PyTorch has
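A small usage sketch may make the questions above concrete. This assumes a recent PyTorch release where vmap is exposed as torch.func.vmap (the functorch API linked above); earlier prototype releases exposed the same idea as torch.vmap. The function dot and the tensor shapes are illustrative, not taken from the episode.

```python
import torch
from torch.func import vmap

def dot(x, y):
    # Written as if for a single pair of 1-D vectors.
    return (x * y).sum()

xs = torch.randn(8, 3)  # a batch of 8 vectors
ys = torch.randn(8, 3)

# vmap "vectorizes" dot over the leading (batch) dimension: logically it is
# a loop over the batch, but it executes as batched tensor operations.
batched = vmap(dot)(xs, ys)

# Reference semantics: an explicit Python loop over the batch.
looped = torch.stack([dot(x, y) for x, y in zip(xs, ys)])
assert torch.allclose(batched, looped)
```

Internally, the prototype implementation discussed in the episode wraps inputs in a BatchedTensor (see BatchedTensorImpl.h above) whose logical shape hides the batch dimension, and the logical-to-physical transforms in VmapTransforms.h map each operator call back onto the full physical tensor.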
More episodes of the podcast PyTorch Developer Podcast
- Compiler collectives (04/08/2024)
- TORCH_TRACE and tlparse (29/04/2024)
- Higher order operators (21/04/2024)
- Inductor - Post-grad FX passes (12/04/2024)
- CUDA graph trees (24/03/2024)
- Min-cut partitioner (17/03/2024)
- AOTInductor (02/03/2024)
- Tensor subclasses and PT2 (24/02/2024)
- Compiled autograd (19/02/2024)
- PT2 extension points (05/02/2024)