High Performance code with Jax - Tensorflow Colombia





Fecha 04 Mayo 2021

Aca el enlace de la Conferencia

https://www.youtube.com/watch?v=ZzthyO4SYEc

JAX es un paquete de Python que combina una API similar a NumPy con un conjunto de potentes transformaciones componibles para diferenciación automática, vectorización, paralelización y compilación JIT. Su código puede ejecutarse en CPU, GPU o TPU

JAX es una nueva biblioteca de aprendizaje automático de Google diseñada para computación numérica de alto rendimiento. La biblioteca de Autograd tiene la capacidad de diferenciarse a través de cada código nativo de python y NumPy.

JAX se define como "Transformaciones componibles de programas Python+NumPy: diferenciar, vectorizar, JIT a GPU/TPU y más". La biblioteca utiliza la transformación de la función de graduación para convertir una función en una función que devuelve el gradiente de la función original. Jax también ofrece un JIT de transformación de funciones para la compilación justo a tiempo de funciones existentes y vmap y pmap para vectorización y paralelización, respectivamente.

El cambio de PyTorch o Tensorflow 2 a JAX es nada menos que tectónico. PyTorch crea un gráfico durante el paso hacia adelante y degrada durante el paso hacia atrás. JAX, por otro lado, permite al usuario expresar su cálculo como una función de Python, y al transformarlo con grad() proporciona la función de gradiente que se puede evaluar como la función de cálculo, pero en lugar de la salida, proporciona el gradiente de la salida para el primer parámetro que la función tomó como entrada.

Comentarios

Entradas populares de este blog

Support Vector Machine (SVM) in 2 minutes