JAX是什么,有什么优势
JAX是一个由Google主导用于高性能数值计算和机器学习的Python库,其中大部分矩阵运算子使用方式与Numpy相同,并添加了自动微分,JAX 遵循函数式编程哲学,例如所有矩阵均是不可变的。在AI领域,JAX目前已广泛使用,Google内部的新项目几乎均使用JAX,除Google外知名的大语言模型如OpenLlama和马斯克的Grok的训练均使用JAX,且如Stable Diffusion等模型的推理均可使用JAX。
当一个函数遵循纯函数式写法以及特定规范(JAX The Sharp Bits),使用@jax.jit
装饰器后,JAX将Python代码转换成为Jaxpr并生成MLIR StableHLO dialect(Jaxpr及Tensorflow转换到StableHLO dialect非常完备,为官方默认实现。与之对应pytorch的转换路径非常不稳定,大部分pytorch所编写的模型无法直接转换到任何MLIR dialect中。Mojo目前并不开源且其路径明显不同,无法兼容此生态的各种优化,这也导致了Mojo现在以及未来一段时间内不会有GPU的支持)
除Google AMD外国内大厂如字节跳动和阿里巴巴也在使用StableHLO dialect作为上层接入IR的主力。目前我所在公司的业务团队也正进行基于Google和AMD主导的下一代StableHLO dialect编译器IREE适配各国产信创硬件加速器及RISCV CPU平台的工作(IREE为目前SPIRV支持最完备的框架,已支持Stable Diffusion和Vision Transformer在SPIRV设备上运行,目前正在推进WebGPU的支持,同时可完全基于LLVM IR生成优化后的矩阵运算而无须使用特定平台手写的asm kernel)。且JAX中原生提供了AOT导出的方法(jax.export.export
docs)
在原生的Jaxlib中,使用了XLA作为StableHLO编译器生成优化后的CPU或GPU代码。
CPU上,JAX的性能明显优于底层使用MKL的Numpy和原生Pytorch乃至Julia,这吸引了许多除AI之外原先Julia的用户转向使用JAX编写科学计算程序,对于AI与科学计算交叉的AI4Science项目有着极大的吸引力。
对于GPU TPU及分布式设备,StableHLO不可变矩阵的特性方便了sharding,使得在分布式设备中存在更大的优化潜力。在大部分情况下其性能至少与Pytorch相当,且在对于除Nvidia CUDA以外的平台支持提供了巨大的便利。
Jaxpr是什么
Jaxpr是一种由Python DSL生成的纯函数式的IR(Understanding Jaxprs) 与Python除JAX以外的使用无关。其形式如下
Jaxpr ::= { lambda Var* ; Var+. let
Eqn*
in [Expr+] }
在JAX中的分支函数例如 jax.lax.switch
生成的Jaxpr
>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
... return lax.switch(index, [lambda x: x + 1.,
... lambda x: x - 2.,
... lambda x: x + 3.],
... arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
d:i32[] = clamp 0 c 2
e:f32[] = cond[
branches=(
{ lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
{ lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
{ lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
)
] d b
in (e,) }
由Jaxpr生成MLIR StableHLO dialect
>>> import jax
>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4
>>> lowered = jax.jit(f).lower(x, y)
>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%c = stablehlo.constant dense<2> : tensor<i32>
%0 = stablehlo.multiply %c, %arg0 : tensor<i32>
%1 = stablehlo.add %0, %arg1 : tensor<i32>
return %1 : tensor<i32>
}
}
>>> compiled = lowered.compile()
>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0
>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)
JAX有什么缺陷
由于Python本身设计并非为矩阵计算的DSL。在类型系统上Python无法良好的表达矩阵的Shape polymorphism,只有当调用jaxpr编译器后才能提示存在的问题。且由于Python内置的分支语句与函数式用法的冲突,在JAX中需要调用jax.lax.cond
jax.lax.fori_loop
等才可以在生成Jaxpr中使用分支(JAX The Sharp Bits)。由于缺乏对模块的管理,Google又设计了NNX。这些不足使得用户在使用JAX的时候获得了不好的体验。
对MoonBit中矩阵运算提供jaxpr接入的提案
综上,我认为将MoonBit中矩阵运算提供Jaxpr接入是解决上述问题的一个好的方案,经过多年的设计后,Jaxpr的除细粒度优化外的基础语法已经稳定。在摆脱了Python DSL的缺陷后,接入Jaxpr可以作为一个使用方便优雅的高性能矩阵运算编程语言。对于Moonbit WASM,可提供生成WebGPU或利用LLVMIR的途径。对于Moonbit native,XLA/IREE均提供了完备的C ffi方案。