Triton Puzzles
之前做tilelang puzzles的时候,发现readme里提到是仿照triton puzzles的,但当时感觉triton没有学的必要,就没做
最近发现triton的设计思想和tilelang差异很大,感觉可以开拓一下视野,就找到这个https://github.com/SiriusNEO/Triton-Puzzles-Lite项目看看,这是改进过的轻量版,不是原版triton puzzles,题目内容没变,只是减少了依赖,原本的可视化和jupyter notebook都去掉了,就在.py文件运行,并且附上了作者写的答案,可以对比学习。
环境需要,别的可能也行,但是作者建议这个,这个是肯定可以跑通的,高版本可能报错。
pipinstalltorch==2.5.0# Check triton version: triton==3.1.0安装时,如果能访问外网,实测最快的是直接用pytorch库,其他源都可能会把torch下载限速(因为下载的人太多了,还大,可能把带宽占满了),然后注意这里的cu124,根据显卡的驱动版本来安装,可以先检查cuda驱动版本。然后选一个低于驱动版本的torch whl,这种库都是可以向下兼容,但不能向上兼容。
python3-mpipinstalltorch==2.5.0 --index-url https://download.pytorch.org/whl/cu124 --no-cache-dir--isolated推荐使用<2.0.0的numpy,结果正确性验证时会用numpy,check脚本用了低版本接口,版本高了会出错。
python3-mpipinstallnumpy==1.26.4--isolated运行时设置环境变量,1表示用cpu模式py解释器运行,0则是gpu模式。gpu模式由于显卡版本不同可能出现各种bug,推荐先cpu模式跑通,这也是原始triton puzzles的推荐运行方式。当前仓库的答案,gpu模式下case 11会运行出错。
TRITON_INTERPRET=1python3 puzzles.py-a最后的参数部分
-a#运行全部puzzles-px#运行第x个puzzle-i#运行四个demo-h#显示帮助文档clone下来可以先跑以下指令,验证所有答案cpu模式下是不是都能跑通,能的话说明基础环境配置没问题。
TRITON_INTERPRET=1python3 puzzles_ans.py-aTriton简介
Triton 是由 OpenAI 开源的一种专为深度学习加速设计的编程语言和编译器。
如果你写过 CUDA,你可能会觉得它太底层、开发周期太长;如果你只用 PyTorch,你可能会发现很多自定义的算子(比如各种新型的 Attention 或量化算子)无法获得极致的性能。Triton 的诞生,正是为了在“开发效率”与“极致性能”之间取得完美的平衡。
1. Triton 解决的核心痛点
在传统的 GPU 算子开发中,通常面临两极分化:
高端玩家(写 CUDA C++): 可以手动控制线程块、共享内存(Shared Memory)和寄存器,性能毁天灭地,但开发极其痛苦,且代码很难跨硬件(比如从 NVIDIA 转到 AMD)复用。
普通玩家(写 PyTorch/TensorFlow): 拼凑现有的 API(如 torch.relu + torch.matmul),开发极快,但会在显存中产生大量中间变量,造成频繁的显存读写(Memory Bound),浪费算力。
Triton 的核心思想是:让没有 CUDA 经验的深度学习研究员,也能用类似 Python 的语法,写出性能媲美甚至超越专家级 CUDA 的硬件加速算子。
2. Triton 的核心设计理念:基于块(Block-based)的编程
这也是 Triton 与 CUDA 最本质的区别:
CUDA 是“基于线程(Thread-based)”的: 你需要精确计算每个 Thread 的 ID,去算它该读哪一个具体的显存地址,还要手动处理线程之间的同步(__syncthreads())和数据共享。
Triton 是“基于块(Block-based)”的: 它把张量块(Block)作为一等公民(First-class citizen)。你不需要操心单个线程,而是直接对一个分块进行加载(tl.load)、计算(tl.dot)和存储(tl.store)。
并且,triton除了是基于数据块的编程,还是声明式编程,而不是CUDA的过程式编程,也就是你只用写要对这个数据块做什么,而不需要写怎么做,编译器会把做什么转化成怎么做的机器码。
3. 编译器在幕后做了什么?
既然写起来像 Python 一样简单,那极致的性能是怎么来的?这全靠 Triton 编译器。它会把你的 Python 风格代码编译成高效的机器码(通过 LLVM IR 到 PTX/AMDGCN),自动帮你做好以下最头疼的硬件优化:
自动内存合并(Memory Coalescing): 自动优化全局显存(Global Memory)的访问模式,确保带宽跑满。
自动管理共享内存(Shared Memory Allocation): 你不需要像写 CUDA 那样手动声明shared数组,编译器会自己决定什么时候把数据缓存在片上高速存储里。
指令流水线与排程(Instruction Scheduling): 自动隐藏访存延迟,让计算单元(Tensor Cores)和访存单元能够高效并发。
注意这里和tilelang的设计思想不同,并不会先映射到CUDA代码,再编译。而是自定义了TTIR(Triton IR),生成TTIR后,下一步就会映射到PTX、SaaS代码了,不会经过CUDA,也就是triton可以被视为一个独立的语言,有自己的编译路径,而不是CUDA语法糖。
4. 谁在用 Triton?
如今 Triton 已经成为大模型时代基础设施的绝对主力:
PyTorch 2.0+ 的核心: PyTorch 2.0 引入的重磅编译功能 TorchInductor,其后端默认就是将 PyTorch 代码自动生成为 Triton 内核,这也是其实现图编译加速的秘密武器。
FlashAttention: 著名的闪电注意力机制,其后续的很多高效变体和工程实现(如 FlashAttention-3)都大量采用了 Triton 进行快速迭代。
大模型推理加速: 比如 vLLM、DeepSpeed 以及各类轻量级量化插件,里面普遍包含大量用 Triton 编写的定制化算子(如上面我们聊到的量化 GEMM)。
如果你想深入 AI 芯片底层硬件加速,或者想为自己的大模型设计专属的奇门遁甲算子,Triton 是目前投产比(ROI)最高、最值得学习的技术。
Demos
demo 1
数据搬运是GPU编程中最核心的概念,第一个示例主要熟悉tl.load搬运数据
tl.load(ptr, mask)参数是两个张量,ptr是一个指针数组,表示数据搬运源地址,数组内每个指针对应一个要搬运的元素。mask是一个掩码数组,数据类型是bool,用0/1表示ptr数组中传入的每个指针,是否搬运。
需要额外引入mask的原因是triton里的所有张量(数据块)的大小都是二的幂次,如果我们想灵活搬运一个大小不对齐的张量时,比如大小5,可以传入一个刚好大于这个张量大小的指针数组,长度对齐2的幂次,然后用mask来约束搬运范围,比如mask就是[1,1,1,1,1,0,0,0],表示前五个位置利用指针地址搬运,后三个位置不进行操作。
需要注意的是,
- 这里传入的x_ptr,已经不是torch tensor了,而是底层数据的首地址,类似c的数组首地址指针,这也是命名上带一个ptr的原因,因此我们传入指针ptr数组和mask,需要人为避免越界,如果x_ptr对应的tensor只有八个元素,那么就不能访问大于8的位置,否则会运行错误或者读到垃圾值。编译器不会阻止你,编译时的思路是类C的,允许你直接用指针寻址。
- 如果指针数组大小超过tensor了,但是mask限制了读取范围,不会出问题,因为mask为0的位置,不会真的去读内存,而是直接返回一个值表示不操作,可以在
tl.load(ptr, mask,0)操作时传入第三个参数,表示mask为0的位置填充什么值,如果不传入第三个参数,默认填充0
定义讲完了,来看这个算子的具体事项。range = tl.arange(0, 8)类似torch.arrange,生成一个公差为1的等差数列,左闭右开。
x = tl.load(x_ptr + range, range < 5, 0),这一行有很多看点。
x_ptr + range,这里的x_ptr本身是一个指针,也就是一个标量,但是range是刚才生成的数据块,两者相加,这里triton规定,遵循torch/numpy的广播规则,把标量广播到和张量一样的shape,再执行相加。也就是此时形成了一个
[x_ptr,x_ptr+1,x_ptr+2,...,x_ptr+7]的指针数组,接下来会去这个数组内的位置搬运数据。
range < 5类似,5是一个标量,会广播到和range一样大,然后<操作会返回一个bool数组,用这个方式就构造了一个[1 1 1 1 1 0 0 0]的maskx = tl.load(x_ptr + range, range < 5, 0)最后load返回的是一个triton数据块,需要把它复制给一个变量保存下来。
demo1[(1, 1, 1)](torch.ones(4, 3))最后是triton内核的启动方法,triton设计时DSL还没这么多,很多设计师对齐CUDA,比如这里(1, 1, 1)就是CUDA启动时传入的launch参数dim3,表示grid shape,或者说三个维度的block个数。
传递给函数的直接参数则在后面圆括号内,这里传入一个二维张量,(torch.ones(4, 3))。可能会好奇,这里传入的是二维张量,但kernel内看起来是把他当成一维数组用的?这也是类C设计带来的,CUDA编程时,多维数组不管几维,都是当成一维数组使用,用的时候再多次寻址实现多维数组的效果,triton继承了这一点,这个张量4*3=12个元素,在triton kernel内会看成一个长度12的连续内存。
r""" ## Introduction To begin with, we will only use `tl.load` and `tl.store` in order to build simple programs. """""" ### Demo 1 Here's an example of load. It takes an `arange` over the memory. By default the indexing of torch tensors with column, rows, depths or right-to-left. It also takes in a mask as the second argument. Mask is critically important because all shapes in Triton need to be powers of two. Expected Results: [0 1 2 3 4 5 6 7] [1. 1. 1. 1. 1. 0. 0. 0.] Explanation: tl.load(ptr, mask) tl.load use mask: [0 1 2 3 4 5 6 7] < 5 = [1 1 1 1 1 0 0 0] """@triton.jitdefdemo1(x_ptr):range=tl.arange(0,8)# print works in the interpreterprint(range)x=tl.load(x_ptr+range,range<5,0)print(x)defrun_demo1():print("Demo1 Output: ")demo1[(1,1,1)](torch.ones(4,3))print_end_line()"""demo 2
仍然是load,只是这次需要load一个复杂一点的二维区域i < 4 and j < 3
那么用一个range mask就有点难做到了,可以用两个。
首先构造两个等差数列,一个对应行,一个对应列。然后给他们升维,类似torch.unsqueeze,弄完之后两个mask的shape分别是(8,1)(1,4)
i_range=tl.arange(0,8)[:,None]j_range=tl.arange(0,4)[None,:]range = i_range * 4 + j_range让这两个mask做加法,遵循torch/numpy广播规则,会都先变成(8,4)再执行加法。并且加之前,先把行张量乘上每一行的元素个数,这样最后得到的结果,每个位置的值都等于,把这个张量展开到一维后这个位置的编号,可以用来构造mask数组了
(i_range < 4) & (j_range < 3)构造mask时可以把两个条件取and,这里重载了&的规则,不是py里的按位与,而是表示and。这样我们就限制了只拷贝i < 4 and j < 3的区域
""" ### Demo 2: You can also use this trick to read in a 2d array. Expected Results: [[ 0 1 2 3] [ 4 5 6 7] [ 8 9 10 11] [12 13 14 15] [16 17 18 19] [20 21 22 23] [24 25 26 27] [28 29 30 31]] [[1. 1. 1. 0.] [1. 1. 1. 0.] [1. 1. 1. 0.] [1. 1. 1. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.] [0. 0. 0. 0.]] Explanation: tl.load use mask: i < 4 and j < 3. """@triton.jitdefdemo2(x_ptr):i_range=tl.arange(0,8)[:,None]j_range=tl.arange(0,4)[None,:]range=i_range*4+j_range# print works in the interpreterprint(range)x=tl.load(x_ptr+range,(i_range<4)&(j_range<3),0)print(x)defrun_demo2():print("Demo2 Output: ")demo2[(1,1,1)](torch.ones(4,4))print_end_line()demo 3
这节主要是学习tl.store写入操作,和读取tl.load一起构成了完整的数据搬运。
tl.store(ptr, value, mask)参数和tl.load类似,也是传入一个指针数组,一个mask,只不过这是个无返回值的函数,所以ptr就是目的地址,源则是value。ptr类似前面load的规则,类C的指针数组,手动寻址。但value是类似py张量,可以传入一个标量进行广播,也可以传入一个前面load进来的张量,不能传入和ptr类似的指针数组,也就是源不是给传指针,寻址,而是直接给出值。
一般的范式是,读取到一个张量,做想做的操作,然后再写入,也就是读取写入之间一定有一个张量来倒手。
x=tl.load(x_ptr,mask)tl.store(y_ptr,x,mask)来看具体实现,z = tl.store(z_ptr + range, 10, range < 5)这里用z接受了返回值,其实是一个陷阱,tl.store无返回值,所以尝试print(z)会报错。想看结果,数据已经被写入z_ptr为首地址的张量了,在kernel内只有首地址指针,没有z_ptr对应的张量对象,看不了,必须从kernel里返回后host侧才能看。
""" ### Demo 3 The `tl.store` function is quite similar. It allows you to write to a tensor. Expected Results: tensor([[10., 10., 10.], [10., 10., 1.], [ 1., 1., 1.], [ 1., 1., 1.]]) Explanation: tl.store(ptr, value, mask) here range < 5 corresponds to the 2D-mask [[1. 1. 1.] [1. 1. 0.] [0. 0. 0.] [0. 0. 0.]] """@triton.jitdefdemo3(z_ptr):range=tl.arange(0,8)z=tl.store(z_ptr+range,10,range<5)defrun_demo3():print("Demo3 Output: ")z=torch.ones(4,3)demo3[(1,1,1)](z)print(z)print_end_line()"""demo 4
前三个都是单线程的,但作为GPU编程当然可以根据数据块编号不同,做不同的操作,这节来看如何利用tl.program_id确定所在块号,然后执行不同操作。
tl.program_id(0)这里的0,1,2分别是取出这个数据块的三个维度编号,三个维度是我们启动内核时传入的,比如这里就是demo4[(3, 1, 1)](x),表示0维度长度3,另外1,2维度长度1,也就是有3 * 1 * 1 = 3个block。
x = torch.ones(2, 4, 4)传入的张量展平后有32个元素,想要搬运前20个。均分给三个block实现,考虑到每次搬运操作的长度都是二的幂次,最少的搬运方式是,每个block搬8个元素,前两个block都全搬,最后一个block只用搬前四个,设一个mask实现这一点。
kernel内,range = tl.arange(0, 8) + pid * 8实现了每个block搬运的位置不同,也就是根据block id进行偏移。每个都搬长度为8的区间,所以生成一个长度8的等差数列,然后累加上块偏移,就是这个块负责的地址范围
range < 20为了只搬前20个,增加一个mask限制,这个限制只会让最后一个block的mask是前四个1,后四个0,对前两个block无影响。
""" ### Demo 4 You can only load in relatively small `blocks` at a time in Triton. To work with larger tensors you need to use a program id axis to run multiple blocks in parallel. Here is an example with one program axis with 3 blocks. Expected Results: Print for each [0] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [1] [1. 1. 1. 1. 1. 1. 1. 1.] Print for each [2] [1. 1. 1. 1. 0. 0. 0. 0.] Explanation: This program launch 3 blocks in parallel. For each block (pid=0, 1, 2), it loads 8 elements. Note that similar to demo3, multi-dimensional tensors are flattened when we use pointer (i.e. continuous in memory). """@triton.jitdefdemo4(x_ptr):pid=tl.program_id(0)range=tl.arange(0,8)+pid*8x=tl.load(x_ptr+range,range<20)print("Print for each",pid,x)defrun_demo4():print("Demo4 Output: ")x=torch.ones(2,4,4)demo4[(3,1,1)](x)print_end_line()