对MoonBit中矩阵运算提供jaxpr接入的提案

JAX是什么,有什么优势

JAX是一个由Google主导用于高性能数值计算和机器学习的Python库,其中大部分矩阵运算子使用方式与Numpy相同,并添加了自动微分,JAX 遵循函数式编程哲学,例如所有矩阵均是不可变的。在AI领域,JAX目前已广泛使用,Google内部的新项目几乎均使用JAX,除Google外知名的大语言模型如OpenLlama和马斯克的Grok的训练均使用JAX,且如Stable Diffusion等模型的推理均可使用JAX。
当一个函数遵循纯函数式写法以及特定规范(JAX The Sharp Bits),使用@jax.jit装饰器后,JAX将Python代码转换成为Jaxpr并生成MLIR StableHLO dialect(Jaxpr及Tensorflow转换到StableHLO dialect非常完备,为官方默认实现。与之对应pytorch的转换路径非常不稳定,大部分pytorch所编写的模型无法直接转换到任何MLIR dialect中。Mojo目前并不开源且其路径明显不同,无法兼容此生态的各种优化,这也导致了Mojo现在以及未来一段时间内不会有GPU的支持)
除Google AMD外国内大厂如字节跳动阿里巴巴也在使用StableHLO dialect作为上层接入IR的主力。目前我所在公司的业务团队也正进行基于Google和AMD主导的下一代StableHLO dialect编译器IREE适配各国产信创硬件加速器及RISCV CPU平台的工作(IREE为目前SPIRV支持最完备的框架,已支持Stable Diffusion和Vision Transformer在SPIRV设备上运行,目前正在推进WebGPU的支持,同时可完全基于LLVM IR生成优化后的矩阵运算而无须使用特定平台手写的asm kernel)。且JAX中原生提供了AOT导出的方法(jax.export.export docs)
在原生的Jaxlib中,使用了XLA作为StableHLO编译器生成优化后的CPU或GPU代码。
CPU上,JAX的性能明显优于底层使用MKL的Numpy和原生Pytorch乃至Julia,这吸引了许多除AI之外原先Julia的用户转向使用JAX编写科学计算程序,对于AI与科学计算交叉的AI4Science项目有着极大的吸引力。
对于GPU TPU及分布式设备,StableHLO不可变矩阵的特性方便了sharding,使得在分布式设备中存在更大的优化潜力。在大部分情况下其性能至少与Pytorch相当,且在对于除Nvidia CUDA以外的平台支持提供了巨大的便利。

Jaxpr是什么

Jaxpr是一种由Python DSL生成的纯函数式的IR(Understanding Jaxprs) 与Python除JAX以外的使用无关。其形式如下

Jaxpr ::= { lambda Var* ; Var+. let
              Eqn*
            in  [Expr+] }

在JAX中的分支函数例如 jax.lax.switch生成的Jaxpr

>>> from jax import lax
>>>
>>> def one_of_three(index, arg):
...   return lax.switch(index, [lambda x: x + 1.,
...                             lambda x: x - 2.,
...                             lambda x: x + 3.],
...                     arg)
...
>>> print(make_jaxpr(one_of_three)(1, 5.))
{ lambda ; a:i32[] b:f32[]. let
    c:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    d:i32[] = clamp 0 c 2
    e:f32[] = cond[
      branches=(
        { lambda ; f:f32[]. let g:f32[] = add f 1.0 in (g,) }
        { lambda ; h:f32[]. let i:f32[] = sub h 2.0 in (i,) }
        { lambda ; j:f32[]. let k:f32[] = add j 3.0 in (k,) }
      )
    ] d b
  in (e,) }

由Jaxpr生成MLIR StableHLO dialect

>>> import jax

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<i32> {mhlo.layout_mode = "default"}, %arg1: tensor<i32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %c = stablehlo.constant dense<2> : tensor<i32>
    %0 = stablehlo.multiply %c, %arg0 : tensor<i32>
    %1 = stablehlo.add %0, %arg1 : tensor<i32>
    return %1 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
Array(10, dtype=int32, weak_type=True)

JAX有什么缺陷

由于Python本身设计并非为矩阵计算的DSL。在类型系统上Python无法良好的表达矩阵的Shape polymorphism,只有当调用jaxpr编译器后才能提示存在的问题。且由于Python内置的分支语句与函数式用法的冲突,在JAX中需要调用jax.lax.cond jax.lax.fori_loop等才可以在生成Jaxpr中使用分支(JAX The Sharp Bits)。由于缺乏对模块的管理,Google又设计了NNX。这些不足使得用户在使用JAX的时候获得了不好的体验。

对MoonBit中矩阵运算提供jaxpr接入的提案

综上,我认为将MoonBit中矩阵运算提供Jaxpr接入是解决上述问题的一个好的方案,经过多年的设计后,Jaxpr的除细粒度优化外的基础语法已经稳定。在摆脱了Python DSL的缺陷后,接入Jaxpr可以作为一个使用方便优雅的高性能矩阵运算编程语言。对于Moonbit WASM,可提供生成WebGPU或利用LLVMIR的途径。对于Moonbit native,XLA/IREE均提供了完备的C ffi方案。

3 个赞