이 글은 Triton 내부 시리즈의 3부로, MLIR 내부와 백엔드 컴파일 파이프라인을 다룹니다. Triton이 IR을 대상 하드웨어로 점진적으로 내리기 위해 활용하는 MLIR 패스들을 살펴보고, 예제 커널에서 생성된 TTGIR를 단계별로 워크스루합니다.
이 글은 Triton 내부 시리즈의 연속입니다. 이번 글에서는 Triton의 MLIR 내부를 파고듭니다. 백엔드 컴파일 과정으로 더 깊이 들어가 Triton이 IR을 대상 하드웨어로 점진적으로 내리기 위해 사용하는 MLIR 패스들을 살펴봅니다.
이전 글에서는 다음 주제를 다뤘습니다:
Triton은 매우 활발히 개발되는 프로젝트입니다. 이 글의 내용이 현재 시점과 맞는지 확인하기 위해 다음 커밋 해시를 기준으로 작성했습니다:
1
2
$ git rev-parse main
9baa051fa9dd00cd7255e750c71224153aecd3f0
이 글은 이전 글에서 다룬 핵심 개념에 익숙하다고 가정합니다. 예를 들어 Triton 컴파일러, MLIR(및 table-gen), Triton Python 바인딩에 대한 친숙함(깊은 이해는 아니어도)을 가정합니다.
이 시리즈를 시작할 때, ML 컴파일러와 가속기에서 MLIR이 어떻게 쓰이는지 이해하는 데 도움이 되었던 TensorFlow 영상이 하나 있었습니다. Triton에 특화된 내용은 아니지만 개념을 잘 개관해 줍니다.
이전 글에서 NVidia 백엔드를 살펴보기 시작했습니다. 컴파일의 여러 단계를 순회하며 진행 중인 IR을 생성하는 add_stages 함수를 마주쳤죠.
1
2
3
4
5
6
def add_stages(self, stages, options):
stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options)
stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.capability)
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.capability)
stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.capability)
stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.capability)
이 구현들은 triton/third_party/<vendor/provider>/backend/compiler.py 파일들에 있습니다. 제 환경에는 NVidia GPU만 있어서 이 글에서는 NVidia 백엔드만 보겠습니다.
코드 포인터:
make_ttirmake_ttgir이 중 첫 번째 함수인 make_ttir를 보면, 컴파일 과정의 첫 단계로 텐서 프로그램을 생성합니다. 이는 입력 모듈에 일련의 패스를 실행해 수행됩니다. 구현은 아래와 같습니다:
1
2
3
4
5
6
7
8
9
10
11
12
13
def make_ttir(mod, metadata, opt):
pm = ir.pass_manager(mod.context)
pm.enable_debug()
passes.common.add_inliner(pm)
passes.ttir.add_rewrite_tensor_pointer(pm)
passes.ttir.add_combine(pm)
passes.common.add_canonicalizer(pm)
passes.ttir.add_reorder_broadcast(pm)
passes.common.add_cse(pm)
passes.common.add_licm(pm)
passes.common.add_symbol_dce(pm)
pm.run(mod)
return mod
libtriton의 ir와 passes 모듈을 사용한다는 것을 확인할 수 있습니다.
이 모듈들은 triton/python/src/main.cc에서 pybind11로 정의되어 있으며, C++ 계층에 정의된 컴파일러 패스에 대한 파이썬 바인딩을 제공합니다.
1
2
3
4
5
6
7
8
9
10
11
12
PYBIND11_MODULE(libtriton, m) {
m.doc() = "Python bindings to the C++ Triton API";
init_triton_stacktrace_hook(m);
init_triton_env_vars(m);
### IR and PASSESS MODULES ###
init_triton_ir(m.def_submodule("ir"));
init_triton_passes(m.def_submodule("passes"));
### --------------------- ###
init_triton_interpreter(m.def_submodule("interpreter"));
init_triton_llvm(m.def_submodule("llvm"));
FOR_EACH_P(INIT_BACKEND, TRITON_BACKENDS_TUPLE)
}
여기부터 ir와 passes 모듈을 파고들겠습니다.
이전 글에서 ir 모듈을 간단히 다루며 MLIR PassManager 레퍼런스를 제공한다는 것을 확인했습니다.
높은 수준에서 PassManager는 패스들의 시퀀스를 담는 컨테이너입니다. 모듈에 대해 일련의 패스를 실행해 IR을 대상 하드웨어에 맞춰 점진적으로 내릴 수 있게 해줍니다.
PassManager 사용의 간단한 예시입니다:
1
2
3
4
5
6
7
8
mlir::ModuleOp module = mlir::ModuleOp::create(mlir::UnknownLoc::get());
mlir::PassManager pm(module);
// Add some passes
pm.addPass(mlir::createCSEPass());
pm.addPass(mlir::createDeadCodeEliminationPass());
pm.run(module);
이 "Hello world" 예제는 모듈 위에서 두 개의 패스를 실행하는 방법을 보여줍니다. 첫 번째 패스는 공통 부분식 제거(CSE), 두 번째는 죽은 코드 제거(DCE)입니다.
다음과 같은 장난감 언어의 함수를 생각해 보겠습니다:
1
2
3
4
5
def foo(b, c):
a = b + c
d = b + c
e = 2 * a
return d
CSE 패스를 실행한 후 코드는 다음처럼 최적화됩니다:
1
2
3
4
5
def foo(b, c):
a = b + c
d = a
e = 2 * a
return d
CSE 패스는 b + c가 두 번 계산됨을 파악하고 두 번째 계산을 a의 값으로 대체합니다.
다음 단계에서, DCE 패스는 함수에서 사용되지 않는 변수 e를 제거합니다. 최종 최적화 코드는 다음과 같습니다:
1
2
3
4
def foo(b, c):
a = b + c
d = a
return d
passes 모듈은 PassManager에 추가할 수 있는 패스 집합을 제공합니다. 공통, ttir, ttgir 등 다양한 유형의 패스에 대한 서브모듈이 있습니다.
예를 들어 passes/common은 CSE, LICM 등을 제공합니다. 몇 가지를 살펴보겠습니다:
공통 패스들은 모두 MLIR 라이브러리에 제공되는 일반적인 MLIR 패스들입니다. Triton은 이 패스들을 자체 API로 감쌉니다.
Triton은 공통 패스의 일환으로 상수 전파, 죽은 코드 제거, 인라이닝, 정규화(canonicalization), 공통 부분식 제거, 루프 불변 코드 이동(호이스팅)을 수행합니다.
1
2
3
4
5
6
7
8
9
void init_triton_passes_common(py::module &&m) {
using namespace mlir;
ADD_PASS_WRAPPER_0("add_sccp", createSCCPPass);
ADD_PASS_WRAPPER_0("add_symbol_dce", createSymbolDCEPass);
ADD_PASS_WRAPPER_0("add_inliner", createInlinerPass);
ADD_PASS_WRAPPER_0("add_canonicalizer", createCanonicalizerPass);
ADD_PASS_WRAPPER_0("add_cse", createCSEPass);
ADD_PASS_WRAPPER_0("add_licm", createLoopInvariantCodeMotionPass);
}
이들 패스에 대한 문서 참고: MLIR Passes
Triton IR 패스는 passes/ttir 모듈에 정의되어 있습니다. 이 패스들은 Triton에 특화되어 있으며 triton/lib/Dialect/Triton/Transforms에 정의되어 있습니다.
Triton은 연산을 결합하고, 브로드캐스트, splat을 앞으로 당기기 위해 연산 순서를 재배열하고, tensor pointer의 load/store를 다루는 등의 패스를 수행합니다.
참조 table-gen 파일들
컴파일 과정의 다음 단계는 Triton GPU IR을 생성하는 것입니다. 이는 make_ttgir 함수로 수행됩니다. 아래는 NVIDIA 버전의 참조입니다:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
def make_ttgir(mod, metadata, opt, capability):
...truncated...
# TTIR -> TTGIR
pm = ir.pass_manager(mod.context)
pm.enable_debug()
...truncated...
nvidia.passes.ttnvgpuir.add_plan_cta(pm, cluster_info)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_thread_locality(pm)
passes.ttgpuir.add_accelerate_matmul(pm)
passes.ttgpuir.add_remove_layout_conversions(pm)
passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80)
passes.common.add_cse(pm)
if capability // 10 >= 8:
passes.ttgpuir.add_combine_tensor_select_and_if(pm)
passes.ttgpuir.add_pipeline(pm, opt.num_stages)
...truncated...
pm.run(mod)
return mod
Triton GPU IR 생성은 common, nvidia, ttgpuir 패스들에 의존합니다. common 패스들은 이미 다뤘습니다. nvidia 및 ttgpuir 패스의 일부로, Triton 컴파일러는 앞서 생성된 IR에 대해 여러 GPU 특화 최적화 패스를 적용합니다. 잘 알려진 최적화 패스들은 다음과 같습니다:
GPU lowering을 위한 MLIR 패스를 다루기에는 이 정도 수준이 적절하다고 생각합니다. 더 들어가면 구현 세부 사항으로 너무 깊어집니다. 독자는 Triton 코드베이스를 직접 살펴보고 특정 패스의 코드를 읽어 더 배우길 권합니다.
예: Accelerate Matmul. 더 많은 변환은 디렉터리를 따라가 보세요.
LLIR 및 PTX 생성은 1부에서 다뤘으므로 여기서는 반복하지 않겠습니다. 다만 과정은 유사하며, GPU IR에 패스를 적용해 LLVM IR로 내린 뒤 PTX를 생성합니다.
여러 패스 중 하나를 거친 후 생성된 IR을 워크스루해 봅시다. Triton GPU IR은 GPU 연산의 특성이 더 많이 드러나므로 좋은 선택입니다.
먼저 Triton 저장소의 튜토리얼 예제 중 하나(vector-add)를 컴파일해 보겠습니다.
1
2
3
4
5
$ python3 python/triton/tools/compile.py \
--kernel-name add_kernel \
--signature "*fp32,*fp32,*fp32,i32,64" \
--grid=1024,1024,1024 \
python/tutorials/01-vector-add.py
이는 컴파일된 커널과 IR 아티팩트를
~/.triton/cache/디렉터리에 저장합니다.
컴파일 후 캐시 디렉터리에는 다음 파일들이 보입니다:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
-> % ll -R ~/.triton/cache
/home/ksharma/.triton/cache:
total 16K
drwxrwxr-x 2 ksharma ksharma 4.0K Sep 5 22:03 ht7F5vagGWCbXzZnYSv8P9EWINzpLWcHYcdPh9m8Dvg
drwxrwxr-x 2 ksharma ksharma 4.0K Sep 5 22:04 KhrPjVaTuMNfu7c00XjCY5ZXpWaGyn97op4fqj6nD_Q
drwxrwxr-x 2 ksharma ksharma 4.0K Sep 5 22:03 P4-eJXPkRvvD0Z7CcR_QHVX4oqH1l6K0oPt8Posthe0
drwxrwxr-x 2 ksharma ksharma 4.0K Sep 5 22:03 RK5K7n7w7g3VToYM9EYn47bO2r6HoiisdZiDAbimv2A
/home/ksharma/.triton/cache/ht7F5vagGWCbXzZnYSv8P9EWINzpLWcHYcdPh9m8Dvg:
total 36K
-rw-rw-r-- 1 ksharma ksharma 6.7K Sep 5 22:03 add_kernel.cubin
-rw-rw-r-- 1 ksharma ksharma 685 Sep 5 22:03 add_kernel.json
-rw-rw-r-- 1 ksharma ksharma 6.6K Sep 5 22:03 add_kernel.llir
-rw-rw-r-- 1 ksharma ksharma 3.9K Sep 5 22:03 add_kernel.ptx
-rw-rw-r-- 1 ksharma ksharma 3.4K Sep 5 22:03 add_kernel.ttgir
-rw-rw-r-- 1 ksharma ksharma 3.0K Sep 5 22:03 add_kernel.ttir
-rw-rw-r-- 1 ksharma ksharma 679 Sep 5 22:03 __grp__add_kernel.json
/home/ksharma/.triton/cache/KhrPjVaTuMNfu7c00XjCY5ZXpWaGyn97op4fqj6nD_Q:
total 36K
-rw-rw-r-- 1 ksharma ksharma 5.5K Sep 5 22:04 add_kernel.cubin
-rw-rw-r-- 1 ksharma ksharma 686 Sep 5 22:04 add_kernel.json
-rw-rw-r-- 1 ksharma ksharma 4.1K Sep 5 22:04 add_kernel.llir
-rw-rw-r-- 1 ksharma ksharma 2.9K Sep 5 22:04 add_kernel.ptx
-rw-rw-r-- 1 ksharma ksharma 3.3K Sep 5 22:04 add_kernel.ttgir
-rw-rw-r-- 1 ksharma ksharma 2.9K Sep 5 22:04 add_kernel.ttir
-rw-rw-r-- 1 ksharma ksharma 679 Sep 5 22:04 __grp__add_kernel.json
/home/ksharma/.triton/cache/P4-eJXPkRvvD0Z7CcR_QHVX4oqH1l6K0oPt8Posthe0:
total 28K
-rw-rw-r-- 1 ksharma ksharma 26K Sep 5 22:03 cuda_utils.so
/home/ksharma/.triton/cache/RK5K7n7w7g3VToYM9EYn47bO2r6HoiisdZiDAbimv2A:
total 20K
-rw-rw-r-- 1 ksharma ksharma 17K Sep 5 22:03 __triton_launcher.so
ht7F5vagGWCbXzZnYSv8P9EWINzpLWcHYcdPh9m8Dvg 디렉터리를 살펴보겠습니다. ~/.triton/cache/ht7F5vagGWCbXzZnYSv8P9EWINzpLWcHYcdPh9m8Dvg/add_kernel.ttgir의 내용을 gist에 올려두었습니다. 한 단계씩 살펴봅시다:
첫 섹션은 생성된 코드에 대한 GPU 메타데이터를 제공하는 모듈 속성을 정의합니다:
1
2
3
4
5
6
module attributes {
"triton_gpu.num-ctas" = 1 : i32,
"triton_gpu.num-warps" = 4 : i32,
triton_gpu.target = "cuda:89",
"triton_gpu.threads-per-warp" = 32 : i32
}
1
2
3
4
5
6
7
8
tt.func public @add_kernel(
%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32},
%arg3: i32 {tt.divisibility = 16 : i32}
) attributes {noinline = false} {
...
}
이는 파이썬 코드와 직접적으로 매핑되므로 익숙해 보일 것입니다:
1
2
3
4
5
6
7
def add_kernel(x_ptr, # 첫 번째 입력 벡터에 대한 포인터
y_ptr, # 두 번째 입력 벡터에 대한 포인터
output_ptr, # 출력 벡터에 대한 포인터
n_elements, # 벡터의 크기
BLOCK_SIZE: tl.constexpr, # 각 프로그램이 처리할 요소 수
# 참고: shape 값으로 사용하기 위해 `constexpr`
):
%arg0, %arg1, %arg2: 16으로 나누어떨어지는 정렬 제약(16바이트 정렬)을 갖는 float32 배열에 대한 포인터%arg3: 16으로 나누어떨어지는 제약을 갖는 정수BLOCK_SIZE는 보이지 않습니다. IR에서 상수로 정의됨을 뒤에서 보게 됩니다.코드는 상수(BLOCK_SIZE)를 만들고, 프로그램 ID를 가져오며, 0부터 1024까지의 인덱스 범위를 생성하면서 커널 실행을 초기화합니다.
1
2
3
4
%c1024_i32 = arith.constant 1024 : i32
%0 = tt.get_program_id x : i32
%1 = arith.muli %0, %c1024_i32 : i32
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
i32 값 %1을 (1024,) 형태의 텐서로 splat(브로드캐스트)하여 i32 텐서를 만들고, arith.addi로 범위 %2와 원소별 덧셈을 수행합니다.
arith는 산술 연산을 제공하는 또 다른 MLIR 방언(dialect)입니다. Arith Dialect
1
2
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked>
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked>
코드는 f32 값에 대한 포인터 텐서를 생성하고, tt.addptr로 오프셋을 더해, tt.load를 사용해 메모리에서 값을 로드합니다.
마지막으로 조건 %6에 따라 %15가 가리키는 메모리 위치에 arith.addf 연산 결과를 저장합니다. %15는 output_ptr입니다.
1
2
3
4
5
6
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr<f32>, #blocked>
...
tt.store %15, %13, %6 : tensor<1024x!tt.ptr<f32>, #blocked> loc(#loc13)
tt.return
각 명령은 파이썬(Triton DSL) 코드에도 대응시킬 수 있습니다. 아래는 IR의 위치 정보(loc)를 바탕으로 분해해 본 것입니다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
Section 1: unknown
%c1024_i32 = arith.constant 1024 : i32 loc(#loc1)
Section 2: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:37:24
%0 = tt.get_program_id x : i32 loc(#loc2)
Section 3: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:42:24
%1 = arith.muli %0, %c1024_i32 : i32 loc(#loc3)
Section 4: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:43:41
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> loc(#loc4)
Section 5: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:43:28
%3 = tt.splat %1 : i32 -> tensor<1024xi32, #blocked> loc(#loc5)
%4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> loc(#loc5)
Section 6: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:45:21
%5 = tt.splat %arg3 : i32 -> tensor<1024xi32, #blocked> loc(#loc6)
%6 = arith.cmpi slt, %4, %5 : tensor<1024xi32, #blocked> loc(#loc6)
Section 7: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:48:24
%7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> loc(#loc7)
%8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> loc(#loc7)
Section 8: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:48:16
%9 = tt.load %8, %6 : tensor<1024x!tt.ptr, #blocked> loc(#loc8)
Section 9: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:49:24
%10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> loc(#loc9)
%11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> loc(#loc9)
Section 10: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:49:16
%12 = tt.load %11, %6 : tensor<1024x!tt.ptr, #blocked> loc(#loc10)
Section 11: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:50:17
%13 = arith.addf %9, %12 : tensor<1024xf32, #blocked> loc(#loc11)
Section 12: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:52:26
%14 = tt.splat %arg2 : !tt.ptr -> tensor<1024x!tt.ptr, #blocked> loc(#loc12)
%15 = tt.addptr %14, %4 : tensor<1024x!tt.ptr, #blocked>, tensor<1024xi32, #blocked> loc(#loc12)
Section 13: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:52:35
tt.store %15, %13, %6 : tensor<1024x!tt.ptr, #blocked> loc(#loc13)
Section 14: /home/ksharma/dev/git/triton/python/tutorials/01-vector-add.py:52:4
tt.return loc(#loc14)
Triton에는 백엔드 로깅/정보를 확인하고 디버그하는 데 도움이 되는 여러 환경 변수가 있습니다. 그중 MLIR_ENABLE_DUMP=1은 Triton이 모든 커널에 대해 실행하는 모든 MLIR 패스 이전의 IR을 덤프합니다.
1
2
3
4
5
6
7
# 덤프가 이미 존재하면 Triton 캐시가 덤프에 영향을 줍니다.
# 그래서 캐시를 비우세요. 캐시를 비활성화하는 방법은 아직 찾지 못했습니다.
$ MLIR_ENABLE_TIMING=1 MLIR_ENABLE_DUMP=1 python3 python/triton/tools/compile.py \
--kernel-name add_kernel \
--signature "*fp32,*fp32,*fp32,i32,64" \
--grid=1024,1024,1024 \
python/tutorials/01-vector-add.py
이 덤프를 파싱해 더 읽기 쉬운 출력을 제공하는 도구를 해킹하고 있습니다. 준비되면 이 글을 업데이트하겠습니다. 일단, 특정 패스 전/후의 샘플 출력은 아래와 같습니다:
이번 글에서는 Triton이 IR을 대상 하드웨어로 점진적으로 내리기 위해 사용하는 MLIR 패스들을 살펴보았습니다. 또한 다수의 패스 중 하나를 거친 후 생성된 Triton GPU IR을 워크스루하면서 원래의 파이썬 코드에 어떻게 다시 매핑되는지 확인했습니다.