这是我参与8月更文挑战的第11天,活动详情查看:8月更文挑战
从根本上说,JAX 是一个库,提供 API 类似 NumPy,主要用于编写的数组操纵程序进行转换。甚至有人认为 JAX 可以看做 Numpy v2,不仅加快 Numpy 而且为 Numpy 提供自动求导(grad)功能,让我们仅凭借 JAX 就可以去实现一个机器学习框架。
接下来主要就是来解释一下为什么说 JAX 提供 API 类似 NumPy,。现在,你可以把 JAX 看作是在加速器上运行支持自动求导的 NumPy。
import jax
import jax.numpy as jnp
x = jnp.arange(10)
print(x)
复制代码
如果大家熟悉或用过 numpy 写过点东西,上面的代码应该不会陌生,这也就是 JAX的魅力,可以从 numpy 无缝过渡到 JAX 在于你不需要学习一个新的 API。可以将以前用用 numpy 实现的代码,可以用 jnp
代替 np
,程序也可以运行起来,当然也有不同之处,随后会介绍。在 jnp
是 DeviceArray 类型的变量,这也是 JAX 表示数组的方式。
我们现在将计算两个向量的点积,block_until_ready
在无需更改代码在 GPU 的设备运行代码,而不需要改变代码。使用%timeit
来检查性能。
技术细节:当一个 JAX 函数被调用时,相应的操作被派发到一个加速器上,通过是进行异步计算。因此,计算返回的数组不一定在函数返回时就被“填满”。因此,如果不需要立即得到结果,因为是异步计算,所以不会阻塞 Python 的执行。因此,除非设置 block_until_ready,否则我们将只为调度计时,而不是为实际计算计时。参见 JAX 文档中的异步调度
long_vector = jnp.arange(int(1e7))
%timeit jnp.dot(long_vector, long_vector).block_until_ready()
复制代码
The slowest run took 4.37 times longer than the fastest. This could mean that an intermediate result is being cached.
100 loops, best of 5: 6.37 ms per loop
复制代码
JAX 的第一次转换:grad
JAX的一个基本特征是允许转换函数。最常用的转换之一 是 jax.grad
,接收一个用 Python 编写的数值函数,并返回一个新的 Python 函数,计算原函数的梯度。定义一个函数sum_of_squares
,接收一个数组并返回对数组每个元素平方后求和。
def sum_of_squares(x):
return jnp.sum(x**2)
复制代码
对sum_of_squares
应用 jax.grad
将返回一个不同的函数,这个函数就是sum_of_squares
相对于其第一个参数 x 的梯度。
然后,将数组输入这个求导函数来返回相对于数组中每个元素的导数。
sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_of_squares(x))
print(sum_of_squares_dx(x))
复制代码
0.0
[2. 4. 6. 8.]
复制代码
你可以通过类比向量微积分中的 运算符为 jax.grad,如果函数 输入给了 jax.grad
,也就等同于返回 函数用于计算?梯度的函数。