High Performance code with Jax - Tensorflow Colombia
Fecha 04 Mayo 2021
Aca el enlace de la Conferencia
https://www.youtube.com/watch?v=ZzthyO4SYEc
JAX es un paquete de Python que combina una API similar a NumPy con un conjunto de potentes transformaciones componibles para diferenciación automática, vectorización, paralelización y compilación JIT. Su código puede ejecutarse en CPU, GPU o TPU
JAX es una nueva biblioteca de aprendizaje automático de Google diseñada para computación numérica de alto rendimiento. La biblioteca de Autograd tiene la capacidad de diferenciarse a través de cada código nativo de python y NumPy.
JAX se define como "Transformaciones componibles de programas Python+NumPy: diferenciar, vectorizar, JIT a GPU/TPU y más". La biblioteca utiliza la transformación de la función de graduación para convertir una función en una función que devuelve el gradiente de la función original. Jax también ofrece un JIT de transformación de funciones para la compilación justo a tiempo de funciones existentes y vmap y pmap para vectorización y paralelización, respectivamente.
El cambio de PyTorch o Tensorflow 2 a JAX es nada menos que tectónico. PyTorch crea un gráfico durante el paso hacia adelante y degrada durante el paso hacia atrás. JAX, por otro lado, permite al usuario expresar su cálculo como una función de Python, y al transformarlo con grad() proporciona la función de gradiente que se puede evaluar como la función de cálculo, pero en lugar de la salida, proporciona el gradiente de la salida para el primer parámetro que la función tomó como entrada.
Comentarios
Publicar un comentario