JAX is essentially NumPy with two key enhancements: (1) it runs on hardware accelerators such as GPUs and TPUs; (2) it supports automatic differentiation.
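A minimal sketch of the second point: `jax.grad` transforms an ordinary Python function into one that computes its derivative (the function and values here are chosen only for illustration).

```python
from jax import grad

# grad(f) returns a new function that evaluates df/dx.
square = lambda x: x ** 2
dsquare = grad(square)

print(dsquare(3.0))  # 6.0
```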

JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.

import jax.numpy as jnp  # aliased as jnp to avoid shadowing plain NumPy
from jax import grad, jit, vmap

def predict(params, inputs):
  # Forward pass of a fully connected network with tanh activations;
  # params is a list of (weight matrix, bias vector) pairs, one per layer.
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  # Sum-of-squares error over the batch (a Gaussian log-likelihood up to
  # sign and constants).
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads
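The snippet above can be exercised end to end. The layer sizes, batch size, and random initialization below are hypothetical choices made only so the sketch runs standalone; the definitions are repeated so the block is self-contained.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap, random

def predict(params, inputs):
  # Same network as above: a stack of affine layers with tanh activations.
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))

# Hypothetical sizes: a 2-layer net mapping 3 features to 1 output, batch of 8.
key = random.PRNGKey(0)
k1, k2, k3 = random.split(key, 3)
params = [(random.normal(k1, (3, 4)), jnp.zeros(4)),
          (random.normal(k2, (4, 1)), jnp.zeros(1))]
inputs = random.normal(k3, (8, 3))
targets = jnp.zeros((8, 1))

# One gradient summed over the whole batch: same pytree structure as params.
batch_grads = grad_fun(params, inputs, targets)
# One gradient per example: vmap adds a leading axis of size 8 to every leaf.
per_example = perex_grads(params, inputs, targets)
print(per_example[0][0].shape)  # (8, 3, 4)
```

Note the `in_axes=(None, 0, 0)` argument: `params` is shared across examples (`None`), while `inputs` and `targets` are mapped over their leading batch axis.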
