vmap

21/06/2021, 17 min, Episode 34

Episode Synopsis

What is vmap? How is it implemented? How does our implementation compare to JAX's? What is a good way of understanding what vmap does? What's up with random numbers? Why does the vmap that PyTorch currently ships have some issues?

Further reading:
- Tracking issue for vmap support: https://github.com/pytorch/pytorch/issues/42368
- BatchedTensor source code: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/BatchedTensorImpl.h
- Logical-physical transformation helper code (well documented, worth a read): https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/VmapTransforms.h
- functorch, the better, more JAX-y implementation of vmap: https://github.com/facebookresearch/functorch
- Autodidax, which contains a super simple vmap implementation that is a good model for the internal implementation that PyTorch has: https://jax.readthedocs.io/en/latest/autodidax.html
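One common way to understand vmap is by its reference semantics: vmap(f)(xs) should give the same result as calling f on each slice of xs and stacking the outputs. A real implementation avoids that loop by giving each primitive a batching rule that rewrites the per-example operation into a single batched one. The sketch below shows both views in pure Python. It is only an illustration and assumes nested lists stand in for tensors and that the batch dimension is always dim 0. `Batched`, `dot_op`, `naive_vmap` and `vmap_with_rules` are hypothetical names, loosely in the spirit of Autodidax's batching rules. They are not PyTorch's or JAX's API.

```python
from dataclasses import dataclass
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    # Per-example primitive: inner product of two vectors.
    return sum(x * y for x, y in zip(a, b))


def matvec(m: List[Vector], v: Vector) -> Vector:
    # Batched primitive: one dot product per row of m.
    return [dot(row, v) for row in m]


@dataclass
class Batched:
    # Hypothetical wrapper: holds a whole batch, batch dim assumed to be dim 0.
    physical: list


def dot_op(a, b):
    # Dispatch: use the batching rule if either argument is Batched.
    if isinstance(a, Batched) and isinstance(b, Batched):
        return Batched([dot(x, y) for x, y in zip(a.physical, b.physical)])
    if isinstance(a, Batched):
        return Batched(matvec(a.physical, b))  # batch of a's against one b
    if isinstance(b, Batched):
        return Batched(matvec(b.physical, a))  # one a against a batch of b's
    return dot(a, b)  # plain, unbatched call


def naive_vmap(f: Callable) -> Callable:
    # Reference semantics: run f on each slice along dim 0 and stack.
    return lambda xs: [f(x) for x in xs]


def vmap_with_rules(f: Callable) -> Callable:
    # Run f once on a Batched input; batching rules handle the batch dim.
    def batched(xs):
        out = f(Batched(xs))
        # If the output never touched the input, broadcast it across the batch.
        return out.physical if isinstance(out, Batched) else [out] * len(xs)
    return batched


w = [1.0, 2.0, 3.0]


def f(v):
    # Written for a single example: computes w . v
    return dot_op(w, v)


xs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 1.0]]
print(naive_vmap(f)(xs))       # [1.0, 2.0, 6.0]
print(vmap_with_rules(f)(xs))  # [1.0, 2.0, 6.0]
```

Both calls agree, but the second one calls f only once and turns the per-example `dot` into a single `matvec`, which is the kind of rewrite batching rules exist to perform. Real implementations also have to track which dimension is the batch dimension and move it as needed, which this sketch sidesteps by fixing it at dim 0.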