MLIR에서 표현된 고수준 텐서 연산을 NVIDIA GPU에서 병렬로 실행되는 저수준 코드(PTX/CUBIN)로 단계적으로 로워링하는 과정을 살펴보고, CUDA 생태계와 GPU 실행 모델, 그리고 Python에서 MLIR 파이프라인을 구성해 PTX를 생성·실행하는 최소 컴파일러 골격을 구축한다.

지난번에 이어, 이제 최적화해야 할 transformer 프리미티브가 준비되었으니 MLIR로 표현된 고수준 텐서 연산(softmax, attention 등)을 NVIDIA GPU에서 병렬로 실행 가능한 저수준 코드로 어떻게 번역할지 살펴보자. 성능을 끌어올리기 위해 작은 컴파일러를 만들기 시작할 것이다.
이제 NVIDIA GPU가 있는 머신으로 옮겨야 한다(미안하지만 MacOS는 안 된다). NVIDIA GPU가 장착된 Linux 머신을 켜거나, 여러 제공업체에서 임대해 사용하자.
CUDA Toolkit 설치는 악명 높게 고통스럽다. 클라우드 제공업체에서 GPU를 임대하면 보통 이미 설치되어 있다. 별도 인스턴스를 직접 구성해야 한다면, NVIDIA의 공식 docker 이미지 중 하나를 사용하는 것이 좋다.
docker pull nvidia/cuda:12.1.0-devel-ubuntu22.04
docker run -it --gpus all nvidia/cuda:12.1.0-devel-ubuntu22.04 bash
하지만 자신의 머신에 직접 설치해야 한다면 다음 지침을 참고하자. 다만 최신 지침은 온라인에 있는 여러 가이드를 확인하는 것이 좋다. CUDA Toolkit이 설치되어 있지 않은 머신이 있다면, 다음 단계로 설치할 수 있다.
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb
sudo dpkg -i cuda-keyring_1.0-1_all.deb
sudo apt-get update
sudo apt-get install build-essential
sudo apt-get -y install cuda-toolkit-12-x
nvcc 컴파일러가 설치되어 있는지 확인하자. GPU를 임대했다면 아마 이미 설치되어 있을 것이다.
$ nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2021 NVIDIA Corporation
Built on Thu_Nov_18_09:45:30_PST_2021
Cuda compilation tools, release 11.5, V11.5.119
Build cuda_11.5.r11.5/compiler.30672275_0
NVIDIA GPU가 설치되어 있는지 확인한다:
lspci | grep -i nvidia
그다음 CUDA 버전과 드라이버 버전을 확인한다:
$ nvidia-smi
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.183.01 Driver Version: 535.183.01 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
특정 CUDA 버전(이 경우 12.2)을 꼭 기록해 두자. 이 버전에 맞는 라이브러리를 설치해야 하기 때문에 중요하다.
보통이라면 CUDA runner를 활성화해 MLIR를 다시 빌드해야 하는데, 그건 정말 별로다. 그래서 CUDA runner가 활성화된 MLIR이 들어 있는 docker 이미지를 만들어 두었다. RunPod나 LambdaLabs에서 GPU를 임대했다면 이 베이스 이미지를 그대로 쓰면 된다. 로컬 머신이라면 NVIDIA Container Toolkit을 이용해 컨테이너에 GPU를 패스스루할 수 있다.
docker pull ghcr.io/sdiehl/docker-mlir-cuda:main
docker run -it ghcr.io/sdiehl/docker-mlir-cuda:main bash
그렇지 않다면 MLIR을 소스에서 다시 컴파일해야 한다. 단계는 다음과 같다.
git clone https://github.com/llvm/llvm-project
mkdir llvm-project/build
cd llvm-project/build
RUN cmake -G Ninja ../llvm \
-DCUDACXX=/usr/local/cuda/bin/nvcc \
-DCUDA_PATH=/usr/local/cuda \
-DCMAKE_CUDA_ARCHITECTURES="75;80;86;90" \ # Add your GPU architecture here
-DCMAKE_C_COMPILER=clang \
-DCMAKE_CXX_COMPILER=clang++ \
-DCMAKE_CUDA_COMPILER=/usr/local/cuda/bin/nvcc \
-DLLVM_ENABLE_PROJECTS=mlir \
-DLLVM_BUILD_EXAMPLES=ON \
-DLLVM_TARGETS_TO_BUILD="Native;NVPTX;AMDGPU" \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_ENABLE_ASSERTIONS=ON \
-DLLVM_CCACHE_BUILD=ON \
-DMLIR_ENABLE_CUDA_RUNNER=ON \
-DMLIR_ENABLE_CUDA_CONVERSIONS=ON \
-DMLIR_ENABLE_NVPTXCOMPILER=ON
RUN cmake --build . -t mlir-opt mlir-translate mlir-transform-opt mlir-runner
RUN cmake --build . -t install
주의: 이 글은 CUDA 프로그래밍을 다소 이상한 방식으로 ‘크래시 코스’처럼 훑는다. 목표는 MLIR에서 GPU를 타깃으로 하는 컴파일러를 만드는 것이므로, GPU 기본을 아주 빠르게 다룬 뒤 곧바로 GPU의 저수준 어셈블리 프로그래밍으로 들어간다. 일상 업무에서 PTX를 직접 작성하는 사람은(컴파일러를 다루지 않는 이상) 거의 없다. CUDA를 처음 작성한다면, 먼저 온라인의 CUDA 프로그래밍 가이드 중 하나를 읽고 GPU 프로그래밍을 더 자세히 익히는 것이 도움이 될 수 있다.
CUDA 컴파일 과정에는 GPU 프로그래밍을 위해 이해해야 할 몇 가지 핵심 구성요소와 중간 표현(IR)이 포함된다. 먼저 CUDA 프로그램이 어떻게 컴파일되고 실행되는지 살펴보자.
CUDA는 GPU를 프로그래밍할 수 있도록 C++을 확장한 것이다. C++ 위에 구축된 언어로, GPU 프로그래밍을 위한 기능이 추가되어 있다. 자체 컴파일러 nvcc를 사용해 NVIDIA GPU용 중간 어셈블리 언어인 PTX(Parallel Thread Execution)로 컴파일한다. PTX는 병렬 컴퓨팅을 위한 안정적인 프로그래밍 모델과 명령어 집합을 제공한다. 아키텍처 중립적이어서 여러 세대의 NVIDIA 하드웨어를 타깃으로 할 수 있다.
그다음 컴파일 과정은 PTX에서 장치별 바이너리 코드인 CUBIN(CUDA Binary)로 이동한다. CUBIN 파일에는 특정 GPU 아키텍처에서 실행되는 실제 머신 코드가 들어 있다. 이 파일들은 ELF 포맷이며 실행 코드뿐 아니라 심볼, 재배치 정보, 디버그 정보도 포함한다. 일반적으로 CUBIN은 기본값으로 호스트 실행 파일에 임베드되지만, nvcc의 -cubin 옵션으로 별도로 생성할 수도 있다.
nvcc는 이 복잡한 컴파일 과정을 관리한다. 여러 컴파일 단계를 조율하며, 필요에 따라 다양한 출력 포맷을 만들어 낼 수 있다.
컴파일된 프로그램을 호스트에서 실행하면 다음을 수행한다.
nvcc 컴파일러를 사용하면 이 모든 과정은 최종 사용자 입장에서 컴파일러가 처리해준다.
FATBIN 포맷은 하나의 실행 파일이 여러 GPU 아키텍처를 지원할 수 있게 해주므로 특히 중요하다. JIT 컴파일을 통한 향후 호환성을 위한 PTX 코드와, 알려진 아키텍처에서 최적 성능을 위한 사전 컴파일된 CUBIN 파일을 모두 담을 수 있다. 이제 CUDA를 사용하는 간단한 C++ 코드를 보자.
#include <stdio.h>
// CUDA kernel function
__global__ void helloKernel() {
printf("Hello, CUDA!\n");
}
int main() {
// Call the CUDA kernel
helloKernel<<<1, 1>>>();
// Wait for the kernel to finish
cudaDeviceReset(0);
return 0;
}
CUDA의 C++ 확장에는 커널을 런치하기 위한 특수 문법이 있다. helloKernel 함수가 GPU에서 실행하고 싶은 커널이다. <<<1, 1>>>은 커널 런치를 위한 CUDA 문법이다. 첫 번째 인자는 런치할 블록 개수, 두 번째 인자는 블록당 스레드 개수다.
kernel<<<blocks, threads>>>(...)
스레드와 블록은 뒤에서 더 다룬다. 이제 커널 내부에서는 blockDim, blockIdx, threadIdx 같은 특수 변수에 접근할 수 있다. 간단한 예로 배열의 각 원소를 제곱하는 커널을 보자. 여기에는 square 커널이 있으며 내부는 나중에 설명하겠지만, 지금은 배열의 각 원소를 제곱한다고만 보면 된다.
__global__ void square(int* array, int n) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < n)
array[tid] = array[tid] * array[tid];
}
CUBIN 코드를 얻기 위해 다음 명령을 사용할 수 있다.
nvcc -o square square.cu
호스트 로직은 CPU에서, 컴파일된 CUBIN 코드는 GPU에서 실행되는 바이너리를 실행할 수 있다.
./square
nvcc에서 PTX 코드를 얻으려면 다음 옵션을 사용한다.
nvcc -ptx square.cu -o square.ptx
PTX 코드는 다음처럼 보인다.
.visible .entry square(int*, int)(
.param .u64 square(int*, int)_param_0,
.param .u32 square(int*, int)_param_1
)
{
ld.param.u64 %rd1, [square(int*, int)_param_0];
ld.param.u32 %r2, [square(int*, int)_param_1];
mov.u32 %r3, %ntid.x;
mov.u32 %r4, %ctaid.x;
mov.u32 %r5, %tid.x;
mad.lo.s32 %r1, %r3, %r4, %r5;
setp.ge.s32 %p1, %r1, %r2;
@%p1 bra $L__BB0_2;
cvta.to.global.u64 %rd2, %rd1;
mul.wide.s32 %rd3, %r1, 4;
add.s64 %rd4, %rd2, %rd3;
ld.global.u32 %r6, [%rd4];
mul.lo.s32 %r7, %r6, %r6;
st.global.u32 [%rd4], %r7;
$L__BB0_2:
ret;
}
여기서 무슨 일이 일어나는지 이해하기 위해 PTX를 명령어 단위로 분해해 보자. 커널에는 두 개의 파라미터가 있다. 첫 번째는 입력 정수 배열을 가리키는 포인터 param_0이며 입력 정수 배열의 디바이스 주소다. 두 번째는 입력 정수 배열의 크기인 param_1이다.
| Instruction | Description |
|---|---|
ld.param.u64 %rd1, [square(int*, int)_param_0]; | 커널 파라미터에서 입력 정수 배열의 디바이스 주소를 로드한다. |
ld.param.u32 %r2, [square(int*, int)_param_1]; | 커널 파라미터에서 입력 정수 배열의 크기를 로드한다. |
mov.u32 %r3, %ntid.x; | x-차원에서 블록당 스레드 수를 가져온다. |
mov.u32 %r4, %ctaid.x; | x-차원에서 현재 스레드 블록의 ID를 가져온다. |
mov.u32 %r5, %tid.x; | x-차원에서 해당 블록 내 현재 스레드의 ID를 가져온다. |
mad.lo.s32 %r1, %r3, %r4, %r5; | 블록 ID와 스레드 ID를 바탕으로 현재 스레드의 선형(전역처럼 보이는) 인덱스를 계산한다. |
setp.ge.s32 %p1, %r1, %r2; | 계산된 인덱스가 배열 크기 이상이면(경계 검사) 프레디케이트 플래그 %p1을 true로 설정한다. |
@%p1 bra $L__BB0_2; | 계산된 인덱스가 범위를 벗어나면(%p1이 true) 리턴으로 분기해 아무 일도 하지 않는다. |
cvta.to.global.u64 %rd2, %rd1; | 배열의 제네릭 디바이스 주소를 글로벌 주소로 변환한다. |
mul.wide.s32 %rd3, %r1, 4; | 현재 스레드 원소의 배열 내 바이트 오프셋을 계산한다(int는 4바이트 가정). |
add.s64 %rd4, %rd2, %rd3; | 현재 스레드 인덱스에 해당하는 배열 원소의 절대 글로벌 메모리 주소를 계산한다. |
ld.global.u32 %r6, [%rd4]; | 계산된 글로벌 메모리 주소에서 32비트 정수를 레지스터 %r6로 로드한다. |
mul.lo.s32 %r7, %r6, %r6; | %r6의 값을 제곱하고 결과의 하위 32비트를 %r7에 저장한다. |
st.global.u32 [%rd4], %r7; | %r7의 제곱 값을 동일한 글로벌 메모리 위치에 저장해 원래 값을 덮어쓴다. |
$L__BB0_2: | 범위 밖 스레드가 분기해 오는 지점을 표시하는 라벨. |
ret; | 현재 스레드의 커널 실행에서 반환한다. |
cuobjdump -sass 명령으로 CUBIN 코드를 SASS로 디스어셈블할 수도 있다(SASS는 “Shader ASSembly”).
-arch=sm_90에서 우리의 커널은 다음과 같다.
square(int*, int):
LDC R1, c[0x0][0x28]
S2R R5, SR_CTAID.X
ULDC UR4, c[0x0][0x0]
S2R R0, SR_TID.X
IMAD R5, R5, UR4, R0
ULDC UR4, c[0x0][0x218]
ISETP.GE.AND P0, PT, R5, UR4, PT
@P0 EXIT
LDC.64 R2, c[0x0][0x210]
ULDC.64 UR4, c[0x0][0x208]
IMAD.WIDE R2, R5, 0x4, R2
LDG.E R0, desc[UR4][R2.64]
IMAD R5, R0, R0, RZ
STG.E desc[UR4][R2.64], R5
EXIT
.L_x_0:
BRA `(.L_x_0)
NOP
NOP
NOP
NOP
NOP
NOP
NOP
NOP
.L_x_1:
위 CUDA 커널은 여러 스레드에 걸쳐 런치될 때, 각 스레드마다 고유한 인덱스를 계산하고 그 인덱스가 입력 정수 배열의 경계 안에 있으면 글로벌 메모리에서 해당 인덱스의 정수를 로드해 제곱한 뒤 같은 메모리 위치에 다시 저장한다. 즉 배열 원소를 병렬로 in-place 제곱한다. 계산된 인덱스가 배열 범위를 벗어난 스레드는 어떤 메모리 연산도 수행하지 않는다.
sm_53 같은 GPU 코드는 일관되게 sm_ 접두사로 시작한다. 하나의 가상 GPU 아키텍처는 여러 실제 아키텍처를 포함할 수 있다. CUBIN은 같은 세대의 모든 GPU에서 실행 가능하지만, 이전 세대나 이후 세대 GPU와는 호환되지 않는다. 예를 들어 sm_52로 컴파일된 CUBIN은 sm_50, sm_52, sm_53 GPU와는 호환되지만 sm_60 GPU와는 호환되지 않는다. 또한 sm_52로 컴파일된 CUBIN을 sm_53 GPU에서 실행했을 때 성능은, 같은 아키텍처에서 sm_53로 특화 컴파일한 CUBIN보다 떨어질 수 있다.
# Compile for sm_52
nvcc square.cu -o square --gpu-architecture=compute_52 --gpu-code=sm_52
# Compile for sm_52, sm_53, sm_60
nvcc square.cu -o square --gpu-architecture=compute_52 --gpu-code=sm_52,sm_53,sm_60
가장 흔한 아키텍처는 다음과 같다.
| Architecture | SM Version | GPUs | Features |
|---|---|---|---|
| Maxwell | sm_50, sm_52, sm_53 | GTX 750 Ti, GTX 960, GTX 970, GTX 980 | Dynamic Parallelism, CUDA Dynamic Parallelism |
| Pascal | sm_60, sm_61, sm_62 | P100, GTX 1080, GTX 1080 Ti | HBM2 Memory, NVLink, Unified Memory |
| Volta | sm_70, sm_72 | V100, Titan V | First-gen Tensor Cores, Independent Thread Scheduling |
| Turing | sm_75 | T4, RTX 2080, RTX 2080 Ti | First-gen RT Cores, Second-gen Tensor Cores |
| Ampere | sm_80, sm_86, sm_87 | A100, A40, RTX 3090 | Third-gen Tensor Cores, TF32 precision, Structural Sparsity |
| Ada Lovelace | sm_89 | RTX 4090, L40, L40S | Third-gen RT Cores, Fourth-gen Tensor Cores, AV1 encode/decode |
| Hopper | sm_90, sm_90a | H100, H200 | Fourth-gen Tensor Cores, Transformer Engine, 900 GB/s NVLink |
| Blackwell | sm_100, sm_100a, sm_101, sm_101a, sm_120, sm_120a | B100, B200 | Fifth-gen Tensor Cores, Improved occupancy, Enhanced shared memory |
Python Poetry 프로젝트에서 NVIDIA CUDA 라이브러리를 사용하려면 먼저 NVIDIA 패키지 인덱스를 추가해야 한다. Poetry를 사용한다면 nvidia-pyindex 패키지를 추가하고 extra index url을 설정하면 된다.
poetry add nvidia-pyindex --source https://pypi.ngc.nvidia.com
[extra]
index-url = "https://pypi.ngc.nvidia.com"
이제 다음 라이브러리들을 사용할 수 있다. cu12 접미사는 CUDA 12.x 버전을 뜻하며, 자신의 CUDA 버전에 따라 조정해야 할 수 있다. 설치해야 할 핵심 라이브러리 하나는 PyPI에서 제공되는 cuda-python이다: available on PyPI.
cuda-python - CUDA Runtime 및 기타 핵심 기능에 대한 Pythonic 접근그다음 핵심 라이브러리는 다음과 같다.
nvidia-cuda-runtime-cu12 - 필수 런타임 기능을 제공하는 핵심 런타임 라이브러리nvidia-cublas-cu12 - GPU 가속 BLAS(기본 선형대수 서브루틴)용 cuBLAS 라이브러리nvidia-cudnn-cu12 - 딥러닝 프리미티브 라이브러리(cuDNN: scaled dot-product attention, convolution, matrix multiplication, softmax, pooling)nvidia-cudnn-frontend - cuDNN frontend API를 위한 Python 바인딩계측 및 컴파일러 유틸리티도 있다.
nvidia-cuda-nvrtc-cu12 - 런타임 컴파일 라이브러리(NVRTC)nvidia-nvml-dev-cu12 - GPU 모니터링/관리용 Management Library(NVML)nvidia-nvtx-cu12 - 커스텀 프로파일링/트레이싱 계측을 위한 Tools Extension(NVTX)nvidia-cuda-nvcc-cu12 - C/C++ 코드를 컴파일하는 컴파일러(NVCC)nvidia-nvjitlink-cu12 - 디바이스 코드 런타임 링크를 위한 JIT Linkernvidia-cuda-sanitizer-api-cu12 - 애플리케이션 메모리 오류 탐지를 위한 Memory Checkernvidia-cuda-cupti-cu12 - 성능 분석/프로파일링을 위한 Profiling Tools Interface(CUPTI)다른 라이브러리들은 과학 계산, 신호 처리 등 특정 작업에 더 특화되어 있다. 여기서는 자세히 다루지 않지만 존재한다.
nvidia-cufft-cu12 - FFT 계산을 위한 cuFFT 라이브러리nvidia-curand-cu12 - 난수 생성을 위한 cuRAND 라이브러리nvidia-cusolver-cu12 - 밀집/희소 직접 해법을 위한 cuSOLVER 라이브러리nvidia-cusparse-cu12 - 희소 행렬 연산을 위한 cuSPARSE 라이브러리nvidia-npp-cu12 - 이미지/비디오/신호 처리를 위한 Performance Primitives(NPP)nvidia-nvjpeg-cu12 - 하드웨어 가속 JPEG 인코딩/디코딩용 nvJPEG 라이브러리nvidia-opencl-cu12 - GPU 컴퓨팅을 위한 OpenCL 구현CUDA에서 **커널(kernel)**은 GPU에서 실행되도록 정의하는 함수다. 커널 런치를 시작할 때 단지 함수 하나를 호출하는 것이 아니라, 그 함수를 다양한 데이터 원소에 대해 동시에 실행하는 수백~수천 개의 병렬 스레드를 생성한다.
이 실행 모델은 보통 SIMT(Single-Instruction Multiple-Thread) 패러다임이라고 부르며, 하나의 프로그램을 많은 데이터 조각에 대해 실행하는 것을 강조한다. 예를 들어 배열 각 원소의 제곱을 계산하는 커널이 있다고 하자. blockDim.x와 blockIdx.x가 왜 마법처럼 보이는지는 걱정하지 말자. 뒤에서 더 설명한다.
__global__ void square(float* array, int n) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid < n)
array[tid] = array[tid] * array[tid];
}
CUDA 커널은 두 가지 주요 특성을 가진다. 첫째, 명시적으로 값을 반환할 수 없으므로 결과 데이터는 함수에 전달된 배열에 기록해야 한다(스칼라 계산이라면 보통 길이 1짜리 배열을 전달한다). 둘째, 커널 호출 시 스레드 계층을 명시적으로 선언해야 하며, 스레드 블록 수와 블록당 스레드 수를 지정해야 한다. 커널은 한 번 컴파일되지만, 블록 크기나 그리드 크기를 달리해서 여러 번 실행할 수 있다는 점도 중요하다.
이 코드를 뜯어보면, 각 스레드에 배열에서 작업할 고유 인덱스 tid를 부여하는 것뿐이다. 커널은 그 인덱스를 사용해 각 스레드가 A와 B에서 원소를 하나씩 더하게 만든다. __global__ 키워드는 이 함수가 GPU에서 실행되는 CUDA 커널이며 CPU 측에서 시작할 수 있음을 뜻한다. if (tid < n) 체크는 배열에 필요한 스레드보다 약간 더 많이 런치할 수 있기 때문에(보통 좋은 숫자로 올림) 범위를 벗어난 추가 스레드가 놀게 만들기 위해 존재한다.
메인 CPU 프로그램(“host code”)에서 커널을 실행하려면 사용할 스레드 수를 정의해야 한다. 커널은 GPU에서 돌지만, 커널을 설정하고 런치하는 CPU 코드가 필요하다. 예를 들어 실행 코드는 다음과 같을 수 있다.
int N = 10000;
int threadsPerBlock = 256;
int numberOfBlocks = (N + threadsPerBlock - 1) / threadsPerBlock;
square<<<numberOfBlocks, threadsPerBlock>>>(d_array, N);
개념을 풀어보면, CUDA는 스레드를 32개가 한 묶음으로 동시에 실행되는 워프(warp)로 조직하고, 이 워프들을 다시 블록(block)으로 조직한다. 각 블록은 레지스터와 shared memory 같은 제한된 자원을 가진 SM(Streaming Multiprocessor)에서 실행된다. 블록 크기는 이런 자원의 할당에 영향을 주며, 동시에 실행 가능한 워프 수(즉 occupancy)를 결정한다.
효율적인 자원 관리는 GPU 성능을 극대화하는 데 매우 중요하다. GPU 스케줄러는 사용 가능한 SM에 블록을 배치한다. 블록 수가 SM 수를 초과하면 초과 블록은 큐에 쌓이고, 자원이 생기는 대로 실행된다. 이 스케줄링 과정은 SM당 shared memory 양과 레지스터 파일 크기 같은 여러 요소에 영향을 받으며, 결국 동시에 실행 가능한 블록 수를 결정한다.
또 다른 큰 이슈는 warp divergence다. 이는 if문 같은 조건문 때문에 한 워프 안의 스레드들이 서로 다른 실행 경로로 갈라지는 상황을 말한다. 이상적으로는 워프 내 모든 스레드가 같은 명령을 동시에 실행해야 한다. 하지만 어떤 스레드는 한 경로를, 다른 스레드는 다른 경로를 선택하면 실행이 직렬화되어 비효율이 생긴다. GPU 하드웨어는 마스크 비트를 사용해 각 경로를 따라야 할 스레드를 관리하여 모든 스레드가 올바르게 일을 마치게 한다. 이 메커니즘은 정합성을 보장하지만, 분기 실행 중 놀고 있는 스레드가 연산 자원을 낭비하므로 성능에 악영향을 줄 수 있다. 특히 divergence가 심하면 GPU 전체 처리량이 떨어질 수 있다. 성능을 높이려면 워프 내부에서 조건 분기를 최소화해 실행 효율을 최적화해야 한다.
GPU에는 세 가지 메모리 유형이 있다.
Register Memory: GPU 칩 위에 직접 존재하는 가장 빠른 메모리. 개별 스레드만 접근할 수 있으며 수명도 가장 짧아 스레드 실행 동안만 유지된다.
Shared Memory: GPU 블록 내 모든 스레드가 공유하는 메모리. 글로벌 메모리보다 빠르지만 레지스터 메모리보다는 느리다. 스레드 간 통신 및 데이터 공유에 사용되며 블록 생명주기 동안 유지된다.
Global Memory: GPU에서 가장 큰 메모리 풀. 모든 블록/스레드에서 접근 가능하다. 레지스터/공유 메모리보다 느리지만 용량이 가장 크다. 여러 블록/스레드에 걸쳐 공유해야 하는 데이터를 저장하는 데 사용된다.
가장 중요한 성능 고려사항은 CPU(호스트)와 GPU(디바이스) 간 메모리 전송 오버헤드다. 호스트/디바이스 메모리 간 전송은(온칩 메모리 접근에 비해) 상대적으로 느리므로, 가능한 한 전송을 최소화하고 데이터를 GPU에 오래 유지하는 것이 유리하다. 즉 각 연산마다 데이터를 왕복시키기보다, 같은 데이터에 대해 여러 계산을 수행한 뒤에 결과를 호스트로 가져오는 편이 좋다. 가능한 한 GPU에서 계속 살다가 꼭 필요할 때만 CPU로 돌아오면 된다.
또 **메모리 코얼레싱(memory coalescing)**을 이해하는 것도 중요하다. 이는 여러 메모리 접근을 하나의 트랜잭션으로 합칠 수 있는 능력이다. 워프 내 스레드들이 연속적인 메모리 주소에 접근하면 코얼레싱이 일어난다. 최적 성능을 위해 메모리 접근 패턴이 가능한 한 코얼레싱되도록 해야 하며, 코얼레싱되지 않은 접근은 여러 메모리 트랜잭션이 필요해 성능에 큰 영향을 준다.
Shared memory를 사용할 때는 bank conflict도 성능에 영향을 줄 수 있다. 여러 스레드가 동일한 메모리 bank에서 서로 다른 주소를 동시에 접근하려 하면 bank conflict가 발생하고 접근이 직렬화된다. 이를 피하려면 shared memory 접근 패턴을 신중히 설계하고, bank에 걸친 접근 분포가 최적이 되도록 데이터 구조에 padding을 넣는 등의 방법을 고려해야 한다. 적절한 padding과 접근 패턴은 메모리 대역폭 활용을 극대화하는 데 도움이 된다.
CUDA 런타임을 관리하려면 in the Python API(C API를 감싼 래퍼)에 있는 몇 가지 핵심 함수를 사용한다. 다음은 Python 바인딩의 cuda.cuda 모듈에 있는 핵심 함수들이다.
cuInit - CUDA 드라이버 API를 초기화한다.cuDeviceGet - 첫 번째 컴퓨트 디바이스의 핸들을 가져온다.cudaSetDevice - 현재 디바이스를 설정한다.cuCtxCreate - 지정된 디바이스에 컴퓨트 컨텍스트를 만든다.cuModuleLoadData - PTX 문자열을 JIT 컴파일해 디바이스 바이너리로 만든다.cuModuleGetFunction - 모듈에서 커널 함수의 핸들을 가져온다.cuLaunchKernel - 지정된 디바이스에서 커널을 런치한다.cuMemcpyHtoD - 호스트에서 디바이스 메모리로 데이터를 복사한다.cuMemcpyDtoH - 디바이스에서 호스트 메모리로 데이터를 복사한다.cuMemFree - 디바이스 메모리를 해제한다.cuMemAlloc - 디바이스 메모리를 할당한다.cuCtxDestroy - 컴퓨트 컨텍스트를 파괴한다.cuModuleUnload - 디바이스에서 모듈을 언로드한다.CUDA Python API의 에러 처리는, 정수 리턴 코드 기반인 C API를 감싼 래퍼라서 다소 귀찮다. 따라서 이 함수들 모두에 대해 리턴 값을 확인하고 성공 값이 아니면 예외를 던져야 한다. 이를 쉽게 하기 위해 checkCudaErrors 헬퍼 함수를 제공하는데, 성공 값이 아니면 예외를 발생시킨다.
CUDA 컨텍스트를 설정하기 위해 다음 보일러플레이트 함수를 사용할 수 있다.
# gpu_setup.py
# Setup CUDA and create a context.
import cuda.cuda as cu # type: ignore
import cuda.cudart as cudart # type: ignore
import cuda.nvrtc as nvrtc # type: ignore
def _cudaGetErrorEnum(error):
if isinstance(error, cu.CUresult):
err, name = cu.cuGetErrorName(error)
return name if err == cu.CUresult.CUDA_SUCCESS else "<unknown>"
elif isinstance(error, cudart.cudaError_t):
return cudart.cudaGetErrorName(error)[1]
elif isinstance(error, nvrtc.nvrtcResult):
return nvrtc.nvrtcGetErrorString(error)[1]
else:
raise RuntimeError(f"Unknown error type: {error}")
def checkCudaErrors(result):
if result[0].value:
raise RuntimeError(
f"CUDA error code={result[0].value}({_cudaGetErrorEnum(result[0])})"
)
if len(result) == 1:
return None
elif len(result) == 2:
return result[1]
else:
return result[1:]
def findCudaDevice():
devID = 0
checkCudaErrors(cudart.cudaSetDevice(devID))
return devID
def findCudaDeviceDRV():
devID = 0
checkCudaErrors(cu.cuInit(0))
cuDevice = checkCudaErrors(cu.cuDeviceGet(devID))
return cuDevice
def setup_cuda(device_id=None):
print("Initializing CUDA...")
# Initialize CUDA
checkCudaErrors(cu.cuInit(0))
# Get device
if device_id is None:
device = findCudaDeviceDRV()
device_id = 0 # For printing purposes
else:
device = checkCudaErrors(cu.cuDeviceGet(device_id))
# Create context
context = checkCudaErrors(cu.cuCtxCreate(0, device))
print(f"CUDA context created on device {device_id}.")
return context
def cleanup_cuda(context):
if context:
print("Destroying CUDA context...")
checkCudaErrors(cu.cuCtxDestroy(context))
print("CUDA context destroyed.")
이제 CUDA Python API를 사용해 호스트(CPU)와 디바이스(GPU) 사이의 메모리를 관리하는 방법을 보자. 다음 예시는 GPU 프로그래밍에 필요한 핵심 메모리 연산을 보여준다.
cuMemAlloc으로 GPU에 메모리를 할당한다. 이는 데이터를 위한 디바이스 공간을 확보한다.cuMemcpyHtoD(Host to Device)로 호스트에서 디바이스로 데이터를 전송한다.cuMemcpyDtoH(Device to Host)로 결과를 가져온다.아래 예시에서는 간단한 float 배열을 만들어 GPU로 보내고, 다시 복사해 돌아오는지 확인한다. 이 패턴이 GPU 컴퓨팅의 기초가 된다.
# gpu_memory.py
# Minimal example demonstrating CUDA context setup and basic memory movement
from gpu_setup import setup_cuda, checkCudaErrors, cleanup_cuda
import numpy as np
import ctypes
import cuda.cuda as cu # type: ignore
# Setup CUDA context
context = setup_cuda()
try:
# Allocate memory on GPU
buffer_size = 5 * ctypes.sizeof(ctypes.c_float)
device_ptr = checkCudaErrors(cu.cuMemAlloc(buffer_size))
# Create and initialize host array
host_array = np.array([1.0, 2.0, 3.0, 4.0, 5.0], dtype=np.float32)
print(f"Host array: {host_array}")
# Copy data from host to device
checkCudaErrors(cu.cuMemcpyHtoD(device_ptr, host_array.ctypes.data, buffer_size))
# Create array for results
result_array = np.zeros(5, dtype=np.float32)
# Copy data back from device to host
checkCudaErrors(cu.cuMemcpyDtoH(result_array.ctypes.data, device_ptr, buffer_size))
print(f"Result array: {result_array}")
# Free GPU memory
checkCudaErrors(cu.cuMemFree(device_ptr))
finally:
# Always clean up the context
cleanup_cuda(context)
이제 아주 간단한 커스텀 커널을 런치해 기본 연산을 수행해 보자. 이 예시는 CUDA C++ 코드를 컴파일할 필요 없이 PTX 어셈블리에서 바로 GPU 코드를 실행하는 전체 워크플로우를 보여준다. 커널은 배열의 첫 번째 원소를 42로 설정할 뿐이지만, GPU 커널 실행의 필수 요소를 모두 담고 있다. CUDA 컨텍스트 초기화, GPU 메모리 할당, PTX로 커널 정의, CUDA 모듈로 PTX 로드, 실행 차원(블록 1개에 스레드 1개) 설정, 적절한 인자로 커널 런치, 완료 동기화, 결과를 호스트로 복사해 검증까지. PTX 코드에는 첫 스레드만 대입하도록 스레드/블록 ID 체크도 포함된다.
# gpu_memory.py
# Minimal example demonstraitng lauching a kernel
from gpu_setup import setup_cuda, checkCudaErrors, cleanup_cuda
import numpy as np
import cuda.cuda as cu # type: ignore
# This is the trivial kernel that sets the first element
# of the array to 42 using a single thread
# __global__ void set_value_kernel(int *data) {
# if (threadIdx.x == 0 && blockIdx.x == 0) {
# *data = 42;
# }
# }
ptx_kernel = """"
.visible .entry kernel(int*)( .param .u64 kernel(int*)_param_0) {
ld.param.u64 %rd1, [kernel(int*)_param_0];
mov.u32 %r1, %tid.x;
mov.u32 %r2, %ctaid.x;
or.b32 %r3, %r1, %r2;
setp.ne.s32 %p1, %r3, 0;
@%p1 bra $L__BB0_2;
cvta.to.global.u64 %rd2, %rd1;
mov.u32 %r4, 42;
st.global.u32 [%rd2], %r4;
$L__BB0_2:
ret;
}
"""
# Setup CUDA context
context = setup_cuda()
try:
# Allocate a single integer on the device
data_size = np.dtype(np.int32).itemsize
d_data = checkCudaErrors(cu.cuMemAlloc(data_size))
# Set up grid and block dimensions - just one thread
grid_dims = (1, 1, 1)
block_dims = (1, 1, 1)
# Prepare arguments for the kernel
args = [d_data]
arg_types = [None] # None for pointer types
# Load the module
module = checkCudaErrors(cu.cuModuleLoadData(ptx_kernel.encode("utf-8")))
# Get kernel function
kernel_func = checkCudaErrors(
cu.cuModuleGetFunction(module, "kernel".encode("utf-8"))
)
# Prepare kernel arguments
kernel_args = (tuple(args), tuple(arg_types))
# Launch kernel
checkCudaErrors(
cu.cuLaunchKernel(
kernel_func,
grid_dims[0],
grid_dims[1],
grid_dims[2],
block_dims[0],
block_dims[1],
block_dims[2],
0, # shared memory bytes
0, # stream
kernel_args, # kernel args
0, # extra
)
)
# Synchronize
checkCudaErrors(cu.cuCtxSynchronize())
# Copy result back to host
host_data = np.zeros(1, dtype=np.int32)
checkCudaErrors(cu.cuMemcpyDtoH(host_data.ctypes.data, d_data, host_data.nbytes))
# Verify result
print(f"Result: {host_data[0]}")
assert host_data[0] == 42, f"Expected 42 but got {host_data[0]}"
print("Success! Kernel executed correctly.")
# Free device memory
checkCudaErrors(cu.cuMemFree(d_data))
# Unload module
checkCudaErrors(cu.cuModuleUnload(module))
finally:
cleanup_cuda(context)
앞서 워프(일반적으로 32 스레드)로 스레드가 조직되고, 이 워프들이 블록으로 묶여 SM에 할당되어 실행된다고 설명했다. 이제 관점을 넓혀 보자. 스레드 블록은 shared memory와 동기화 메커니즘을 사용해 함께 작업하는 스레드들의 집합이며, 모두 동일한 SM에서 실행된다.
GPU를 점검하기 위해 cuda-python 라이브러리로 GPU 드라이버 API를 조회해 사용 가능한 병렬성 속성 정보를 얻을 수 있다.
from cuda.cuda import CUdevice_attribute, cuDeviceGetAttribute, cuDeviceGetName, cuInit
(err,) = cuInit(0)
err, DEVICE_NAME = cuDeviceGetName(128, 0)
DEVICE_NAME = DEVICE_NAME.decode("ascii").replace("\x00", "")
err, MAX_THREADS_PER_BLOCK = cuDeviceGetAttribute(
CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_THREADS_PER_BLOCK, 0
)
err, MAX_BLOCK_DIM_X = cuDeviceGetAttribute(
CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_BLOCK_DIM_X, 0
)
err, MAX_GRID_DIM_X = cuDeviceGetAttribute(
CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MAX_GRID_DIM_X, 0
)
err, SMs = cuDeviceGetAttribute(
CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, 0
)
print(f"GPU Device: {DEVICE_NAME}")
print(f"Number of multiprocessors: {SMs}")
print(f"Maximum number of threads per block: {MAX_THREADS_PER_BLOCK:10}")
print(f"Maximum number of blocks per grid: {MAX_BLOCK_DIM_X:10}")
print(f"Maximum number of threads per grid: {MAX_GRID_DIM_X:10}")
예를 들어 A100 GPU에서는 다음을 얻는다.
Device Name: A100
Maximum number of multiprocessors: 108
Maximum number of threads per multiprocessor: 1024
Maximum number of threads per block: 1024
Maximum number of blocks per grid: 2147483647
블록을 사용하는 이유는 실용적 제약 때문이다. 현대 GPU에서 단일 블록에 런치할 수 있는 스레드 수에는 하드웨어 한계가 있으며 보통 1024로 제한된다. 계산에 이보다 많은 스레드가 필요하면 여러 블록으로 나눠야 한다. 블록은 특정 데이터 구간을 함께 처리하는 스레드 팀이라고 생각하면 된다. 전체 그리드는 커널이 시작한 모든 블록을 나타내며, 여러 팀이 협력해 전체 작업을 끝내는 것을 뜻한다.
예를 들어 2048 원소의 두 배열을 더하고 싶다면, 1024 스레드짜리 블록 2개를 사용할 수 있다. 블록 0은 인덱스 0-1023, 블록 1은 1024-2047을 처리한다. 일반적으로 원소 수가 N이고 블록이 최대 B 스레드를 수용할 수 있다면, 모든 원소를 처리하기 위해 ceil(N/B)개의 블록을 런치한다.
블록은 스케일링과 스케줄링에도 매우 유용하다. 예를 들어 GPU에 SM이 여러 개(예: 20개) 있고, 각 SM은 자원 상황에 따라 몇 개의 블록을 동시에 돌릴 수 있다. 100개의 블록을 던지면, 동시에 20개만 실행 가능하다면(각 SM에 하나씩) GPU는 그 20개를 병렬로 실행하고, 하나가 끝날 때마다 대기 중인 다른 블록을 즉시 가져와 비워진 SM에서 실행한다.
프로그래머 관점에서는 100개의 블록이 모두 함께 답을 만들기 위해 일하는 셈이다. 마치 100 * blockSize 스레드가 모두 달리는 것처럼 보인다. 하지만 실제로는 GPU가 하드웨어 전반에 블록을 영리하게 분산한다. 동시에 처리 가능한 물리적 한계를 넘는 스레드를 런치하는 것에 대해 스트레스 받을 필요가 없다. 런타임이 필요에 따라 블록을 타임슬라이스해준다. 또한 블록은 여러 GPU로 작업을 나누거나, 공유 메모리나 레지스터 같은 자원 한계에 부딪힐 때 동시에 실행되는 병렬 작업량을 줄이는 자연스러운 단위가 된다.
같은 블록 안의 스레드는 특별히 강력한 협업 능력을 가진다. 빠른 온칩 메모리에 공동으로 접근할 수 있고, 배리어 연산으로 동기화할 수 있어 스레드 간 통신이 필요한 작업에서 효율적으로 협업할 수 있다. 하지만 이 협업은 같은 블록 내부로 엄격히 제한된다. 서로 다른 블록의 스레드는 통신을 위해 더 느린 글로벌 메모리를 써야 하며 직접 동기화할 수도 없다. 이 설계는 중요한 목적이 있다. 블록을 SM에 유연하게 스케줄링할 수 있게 하고, SM 개수가 서로 다른 다양한 GPU 아키텍처에서도 CUDA 프로그램이 자연스럽게 스케일되도록 한다.
보통 블록 크기는 2의 거듭제곱(128, 256, 512 등)으로 선택해 워프 크기 32에 맞추고 하드웨어 특성을 최적화한다. 이런 크기 선택은 여러 성능 요소에 영향을 준다. 자원 활용을 극대화하고, 메모리 접근 패턴을 최적화하며, 효율적인 스레드 스케줄링으로 메모리 지연을 더 잘 숨길 수 있게 한다.
# N is size of the array to process
N = 10000
# We calculate the number of blocks needed based on the total elements and threads per block
# Typically a power of 2 like 128, 256, or 512
threads_per_block = 256
# Ceiling division to ensure all elements are processed
num_blocks = (N + threads_per_block - 1) // threads_per_block
CUDA의 shared memory는 성능을 개선하는 데 사용할 수 있는 가장 강력한 최적화 도구 중 하나로, 메모리 접근 지연을 극적으로 줄이는 프로그래머블 캐시처럼 동작한다. 각 SM 안의 온칩에 위치한 shared memory는 글로벌 메모리보다 대역폭이 100배 높을 수 있고, 지연 시간은 대략 100배 낮을 수 있다. 이런 장점은 compute 코어와의 물리적 근접성에서 나오며, 글로벌 메모리가 오프칩에 있어 훨씬 느린 PCIe 버스나 메모리 컨트롤러를 거쳐야 하는 것과 대비된다. 블록의 스레드들이 같은 데이터를 사용하는 계산을 협업으로 수행할 때 shared memory는 중복 글로벌 메모리 로드를 제거하여, 블록 전체 스레드에 걸쳐 글로벌 메모리 접근 비용을 분산(상각)한다.
shared memory의 메커니즘은 CUDA 스레드 계층 설계를 그대로 반영한다. 커널이 런치되면 각 스레드 블록은 블록 수명 동안 유지되며 해당 블록 스레드만 볼 수 있는 전용 shared memory 영역을 할당받는다. 이 메모리는 bank로 조직되며(현대 GPU는 보통 32개 bank) 서로 다른 스레드들이 서로 다른 bank를 동시에 접근해 고대역폭 병렬 접근을 가능하게 한다. 하지만 여러 스레드가 같은 bank 내 서로 다른 주소(즉 bank conflict)를 접근하면 접근이 직렬화되어 성능이 떨어진다.
__syncthreads() 함수는 shared memory 작업의 핵심이다. 블록의 모든 스레드가 코드의 그 지점에 도달할 때까지 대기하도록 강제한 뒤 다음으로 진행한다. 이는 어떤 스레드는 아직 쓰기를 끝내지 않았는데 다른 스레드가 그 데이터를 읽어버리는 레이스 컨디션을 방지한다. 이런 문제는 미묘하고 디버깅이 어려운 오류의 흔한 원인이다.
다음 예시에서는 블록 내 모든 스레드가 볼 수 있는 __shared__ 배열을 선언한다. 글로벌 메모리에서 shared memory로 데이터를 로드하고, 모두 로드했는지 동기화한 뒤 shared memory에서 처리하고, 결과를 글로벌 메모리로 다시 쓴다.
#include <stdio.h>
__global__ void sharedMemoryExample(float* input, float* output, int n) {
// Declare shared memory array - visible to all threads in the block
__shared__ float sharedData[256];
// Calculate global thread ID
int tid = blockIdx.x * blockDim.x + threadIdx.x;
// Load data from global memory to shared memory
if (tid < n) {
sharedData[threadIdx.x] = input[tid];
}
// Synchronize to make sure all threads have loaded their data
__syncthreads();
// Process data in shared memory (simple example: add 1 to each element)
if (tid < n) {
sharedData[threadIdx.x] += 1.0f;
}
// Synchronize again before writing results back
__syncthreads();
// Write the processed data back to global memory
if (tid < n) {
output[tid] = sharedData[threadIdx.x];
}
}
shared memory를 사용하면, 그렇지 않았다면 비현실적인 알고리즘 기법들을 적용할 수 있다. 타일드 알고리즘에서는 데이터가 shared memory에 들어갈 수 있는 작은 덩어리(타일)로 처리되어, 행렬 곱 같은 연산이 글로벌 메모리 트래픽을 크게 줄여 이론적 피크 성능에 근접할 수 있다. 스텐실 계산에서는 shared memory로 스레드들이 겹치는 영역을 협업해 중복 글로벌 로드를 없앤다. 합/최댓값 같은 리덕션 연산에서는 shared memory가 블록 내에서 스레드들이 결과를 점진적으로 결합하는 효율적인 병렬 리덕션 패턴을 가능하게 한다. 핵심 디자인 패턴은 글로벌에서 shared로 로드하고, shared에서만 계산한 뒤, 결과를 글로벌로 다시 쓰는 것이다. 이렇게 하면 메모리 바운드 계산을 컴퓨트 바운드로 바꿔 GPU의 막대한 연산 처리량을 활용하고, 주된 병목인 메모리 접근 지연을 최소화한다.
이제 목표는, 고수준 텐서 연산을 동적으로(즉 런타임에) 지정하고 이를 MLIR로 적절한 GPU 코드로 로워링하는 것이다. MLIR의 점진적 로워링 기능을 활용해 고수준 텐서 연산을 일련의 잘 정의된 단계로 변환하여 최적화된 GPU 코드로 만들 것이다.
고수준 표현: linalg dialect로 표현된 추상 텐서 계산(행렬 곱, 컨볼루션, 원소별 연산)에서 시작한다. 이 연산들은 구현 세부사항 없이 순수한 수학적 의도를 나타내며, 코드를 도메인에 가깝게 유지한다.
Affine 변환: -convert-linalg-to-loops 또는 -convert-linalg-to-affine-loops 같은 패스를 사용해 고수준 연산을 affine dialect로 변환한다. affine은 루프와 메모리 접근을 정확한 수학적 관계로 표현한다. 이 단계는 메모리 지역성과 계산 효율을 높이는 타일링, 퓨전, 루프 인터체인지 같은 핵심 최적화 기회를 노출한다.
GPU 매핑: -convert-affine-for-to-gpu, -convert-parallel-loops-to-gpu 같은 패스를 통해 루프 구조를 GPU 실행 모델에 매핑한다. 이로써 코드가 gpu dialect로 변환되며, 스레드 블록과 스레드 같은 개념을 명시적으로 표현하지만 하드웨어에 구애받지 않는다.
하드웨어 특화 로워링: 마지막으로 nvvm dialect를 통해 NVIDIA 특화 기능으로 넘어가고, 최종적으로 GPU 인트린식을 포함한 LLVM IR로 내린다. 마지막 단계는 이를 NVPTX 어셈블리로 변환하고, 이어서 GPU에서 직접 실행되는 CUBIN 머신 코드로 만든다.
이 파이프라인 전체에서 각 로워링 단계는 프로그램 의미론을 보존하면서 타깃 하드웨어에 더 가까워진다. 수학 표현에서 명시적 루프 중첩으로, 다시 GPU 실행 구성요소로, 마지막으로 하드웨어 특화 코드로 체계적으로 변환한다. 이렇게 하면 계산은 고수준으로 표현하면서도 하드웨어 역량을 최대한 활용하는 고도로 최적화된 GPU 코드를 생성할 수 있다.
gpu DialectMLIR의 GPU dialect는 CUDA나 OpenCL과 유사한 프로그래밍 모델을 따르는 GPU 커널 런치를 위한 중간 수준 추상화를 제공한다. 이 dialect는 GPU 커널을 런치하는 데 필요한 디바이스/드라이버 특화 작업을 추상화하여, MLIR에서 GPU 실행으로 가는 경로를 간소화한다.
이 dialect는 gpu를 정식 접두사로 사용하며, 일반적인 GPU 프리미티브를 감싸는 연산들을 노출한다. 이상적인 미래에는 여러 GPU 백엔드를 타깃으로 할 수도 있겠지만, 지금은 PTX로 로워링하는 것을 목표로 NVIDIA 특화 연산에 집중하자.
디바이스 연산
gpu.launch - 지정된 그리드/블록 차원으로 GPU 커널을 런치한다. 블록/스레드 설정을 받고 커널 바디를 포함한다.gpu.launch_func - 지정된 그리드/블록 차원으로 GPU 함수를 런치한다. 함수 레퍼런스와 블록/스레드 설정을 받는다.gpu.barrier - 스레드 블록 내 모든 스레드를 동기화하여, 모두 이 지점에 도달할 때까지 진행을 멈춘다.gpu.binary - 컴파일된 GPU 코드(예: PTX 또는 CUBIN 포맷)를 담은 바이너리 블랍을 나타내며 로드/실행할 수 있다.gpu.printf - GPU 커널 내부에서 C의 printf처럼 포맷된 출력을 출력한다. GPU 코드 디버깅에 유용하다.gpu.return - GPU 함수/커널에서 값을 호스트로 반환한다.gpu.terminator - gpu.launch 바디의 마지막 연산으로 필요하며, GPU launch 영역의 끝을 표시한다.gpu.wait - 비동기 GPU 연산이 끝날 때까지 기다린 뒤 호스트 실행을 계속한다.gpu.yield - scf.yield와 유사하게, GPU 코드의 구조적 제어 흐름 영역에서 값을 산출한다.스레드 및 블록 연산
gpu.block_id - 지정된 차원(x, y, z)에서 그리드 내 현재 블록의 ID를 반환한다gpu.block_dim - 지정된 차원(x, y, z)에서 그리드의 블록 차원을 반환한다gpu.block_size - 지정된 차원(x, y, z)에서 블록 내 전체 스레드 수를 반환한다gpu.grid_id - 지정된 차원(x, y, z)에서 실행 중인 현재 그리드의 ID를 반환한다gpu.grid_dim - 지정된 차원(x, y, z)에서 그리드의 차원을 반환한다gpu.grid_size - 지정된 차원(x, y, z)에서 그리드 내 전체 블록 수를 반환한다gpu.thread_id - 지정된 차원(x, y, z)에서 블록 내 현재 스레드의 ID를 반환한다메모리 연산
gpu.alloc - GPU 디바이스에서 메모리를 할당한다. shape과 element type을 받아 할당된 메모리를 가리키는 memref를 반환한다.gpu.memcpy - 호스트↔디바이스 또는 디바이스 내 서로 다른 위치 간 메모리 복사를 수행한다. host-to-device, device-to-host, device-to-device를 처리할 수 있다.gpu.dealloc - gpu.alloc으로 할당한 디바이스 메모리를 해제한다.gpu.host_register - memref를 호스트 로컬 메모리로 등록하여 호스트에서 직접 접근 가능하게 한다.gpu.host_unregister - 디바이스에서 접근하던 memref의 호스트 등록을 해제한다.이 dialect는 특정 GPU 백엔드로 로워링 가능한 타깃 비종속 표현으로 설계되었다. 커널 호출에 대한 추상화를 제공하며, GPU용 LLVM IR 인트린식처럼 더 낮은 레벨에서는 제공되지 않는 디바이스 관리 기능이 장차 포함될 수도 있다.
이 dialect는 GPU 코드를 gpu.module 연산 안에 조직하고, 커널을 gpu.func 연산으로 표현하는 것을 기대한다. 이 구조는 호스트 코드와 디바이스 코드를 명확히 분리하여 컴파일/실행 파이프라인을 처리하기 쉽게 만든다. 커널이 아닌 함수(예: 디바이스 라이브러리 호출)는 func.func나 다른 비-GPU dialect 연산으로 정의할 수 있어, GPU 프로그램 구조에 어느 정도 유연성을 제공한다.
module {
// GPU module
gpu.module @gpu_module {
// GPU function
gpu.func @kernel() {
%0 = arith.addi %arg0, %arg0 : i32
%1 = arith.addi %arg0, %arg0 : i32
%2 = arith.addi %0, %1 : i32
gpu.return
}
}
// Host function
func.func @main() {
gpu.launch_func @gpu_module::@kernel
blocks(%0, %1, %2) in (%3 = %c1, %4 = %c1, %5 = %c1)
threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1)
args()
}
}
대안으로 gpu.launch 연산을 사용할 수 있는데, 이는 launch 연산에 전달되는 region에 커널 함수를 임베드한다. gpu.terminator 연산은 region의 끝을 표시하는 데 사용된다.
func.func @main() {
%c2 = arith.constant 2 : index
%c1 = arith.constant 1 : index
gpu.launch
blocks(%0, %1, %2) in (%3 = %c1, %4 = %c1, %5 = %c1)
threads(%6, %7, %8) in (%9 = %c2, %10 = %c1, %11 = %c1) {
gpu.printf "Hello from %d\n" %6 : index
gpu.terminator
}
return
}
nvgpu Dialectnvgpu dialect는 MLIR 생태계에서 중간자 역할을 하며, gpu나 vector 같은 고수준 타깃 비종속 dialect와 NVIDIA GPU를 위한 저수준 NVVM dialect를 연결한다. PTX 특화 연산을 표현하면서도 memref와 tensor dialect 같은 MLIR의 고수준 추상화를 계속 사용함으로써, 양쪽 세계의 장점을 유지하는 다리를 제공한다.
이 dialect를 사용하면 복잡한 NVVM 인트린식을 직접 다루지 않고도 NVGPU dialect를 통해 NVIDIA 특화 하드웨어 기능에 접근할 수 있다. 이 중간 레이어는 효율적인 GPU 코드 생성을 크게 단순화하면서도 고수준 MLIR 표현의 명료함과 표현력을 유지해준다.
이 dialect는 주로 Tensor Memory Accelerator(TMA)를 통해 글로벌↔shared memory 간 효율적인 텐서 전송 같은 NVIDIA GPU의 고급 기능을 노출하는 데 초점을 맞춘다. 또한 비동기 메모리 연산을 제공해 계산과 메모리 전송을 겹쳐 지연을 숨길 수 있다. 메모리 접근 패턴은 bank conflict를 줄이는 swizzling 기법으로 최적화되며, 메모리 배리어(mbarrier)는 스레드 간 연산을 조율하기 위한 정교한 동기화 도구를 제공한다.
워프 레벨 프로그래밍도 NVGPU dialect에서 특히 중요하게 다뤄진다. 워프 행렬 multiply-accumulate 연산과 동기화 프리미티브 같은 하드웨어 특화 기능을 노출하는 세밀한 제어를 제공한다. 이런 기능은 Hopper와 Blackwell 같은 최신 NVIDIA GPU에서 특수 하드웨어 유닛을 활용해 고성능 행렬 연산 및 기타 연산 집약 워크로드를 구현하는 데 필수적이다.
컴파일 파이프라인에서 이 dialect의 위치는 전략적으로 중요하다. 일반 GPU dialect와 하드웨어 특화 NVVM dialect 사이에 놓여 있어, 고수준 개념을 최적화된 하드웨어 특화 연산으로 번역하면서 개발자를 불필요한 복잡함으로부터 보호한다. 최종 코드를 생성할 때는 -convert-nvgpu-to-nvvm 패스로 NVGPU dialect를 NVVM으로 로워링할 수 있으며, NVGPU 연산을 대응되는 NVVM dialect 인트린식으로 변환한다.
지금은 이런 기능을 일단 미뤄두겠지만, 나중에 특히 TMA와 async 메모리 연산 같은 기능을 더 자세히 살펴볼 것이다.
nvvm DialectNVVM dialect는 NVIDIA GPU를 위한 LLVM IR을 표현하는 타깃 특화 dialect다. 스레드/블록 인덱싱, 동기화, 메모리 연산 등 GPU 특화 구성요소를 위한 연산들을 포함한다.
NVVM IR은 LLVM IR에 기반한 컴파일러 IR(중간 표현)이다. NVVM IR은 GPU 컴퓨트 커널(예: CUDA 커널)을 표현하도록 설계되었다. CUDA C 컴파일러 프런트엔드 같은 고수준 언어 프런트엔드는 NVVM IR을 생성할 수 있다. LLVM 기반 NVVM 컴파일러는 NVVM IR에서 PTX 코드를 생성한다.
| CUDA Builtin | NVVM Intrinsic | MLIR Operation |
|---|---|---|
threadId.{x,y,z} | @llvm.nvvm.read.ptx.sreg.tid.{x,y,z} | gpu.thread_id {x,y,z} |
blockIdx.{x,y,z} | @llvm.nvvm.read.ptx.sreg.ctaid.{x,y,z} | gpu.block_id {x,y,z} |
blockDim.{x,y,z} | @llvm.nvvm.read.ptx.sreg.ntid.{x,y,z} | gpu.block_dim {x,y,z} |
gridDim.{x,y,z} | @llvm.nvvm.read.ptx.sreg.nctaid.{x,y,z} | gpu.grid_dim {x,y,z} |
__syncthreads() | @llvm.nvvm.barrier0() | gpu.barrier |
MLIR Python bindings는 MLIR C++ 내부에 접근할 수 있는 인터페이스를 제공하여, 커맨드라인 도구에 의존하지 않고도 MLIR의 중간 표현을 프로그래밍적으로 조작할 수 있게 해준다. 이 바인딩은 Python을 통해 MLIR의 핵심 기능을 노출하며, MLIR 텍스트 파싱, MLIR AST 구성/수정, 변환 패스 적용, 심지어 Python에서 MLIR 코드를 JIT 컴파일 및 실행하는 것까지 가능하다.
Python 바인딩으로 작업하면 기존 커맨드라인 방식에서 발생하는 많은 마찰 지점을 없앨 수 있다. 여러 커맨드라인 호출을 체인으로 엮는 대신, Python 안에서 종단 간 컴파일 파이프라인을 구성할 수 있어 개발 과정이 더 인터랙티브하고 디버깅하기 쉬워진다. 다만 LLVM과 MLIR에 대한 복잡한 의존성 때문에 소스에서 이 바인딩을 빌드하는 일은 어려울 수 있다. 도입을 쉽게 하기 위해, 필요한 모든 의존성을 포함한 사전 빌드 wheel이 제공되어 pip 명령처럼 간단히 설치할 수 있다.
바인딩 설치는 pip로 할 수 있다.
pip install mlir_python_bindings -f https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest
Poetry를 사용한다면 pyproject.toml에 다음을 추가한다.
[tool.poetry.dependencies]
python = "^3.10"
mlir-python-bindings = { version = "*", source = "mlir-wheels"}
[[tool.poetry.source]]
name = "mlir-wheels"
url = "https://github.com/makslevental/mlir-wheels/releases/expanded_assets/latest"
priority = "supplemental"
이제 Python 바인딩으로 간단한 벡터 덧셈 커널을 PTX로 로워링해 보자.
module {
func.func @square(%input: tensor<10x10xf32>, %output: tensor<10x10xf32>) -> tensor<10x10xf32> {
%x0 = linalg.square ins(%input : tensor<10x10xf32>) outs(%output : tensor<10x10xf32>) -> tensor<10x10xf32>
return %x0 : tensor<10x10xf32>
}
}
여기서는 다음 패스들을 적용한다.
from mlir.ir import Context, Module
from mlir.passmanager import PassManager
mlir_module_str = open("vecadd.mlir").read()
with Context():
# Parse the input module
module = Module.parse(mlir_module_str)
pm = PassManager()
pm.enable_ir_printing(print_after_change=True)
pm.add("canonicalize")
pm.add(
"one-shot-bufferize{ bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map }"
)
pm.add("canonicalize")
pm.add("convert-linalg-to-affine-loops")
pm.add("func.func(affine-loop-invariant-code-motion)")
pm.add("func.func(convert-affine-for-to-gpu)")
pm.add("gpu-kernel-outlining")
pm.add("lower-affine")
pm.add("gpu-decompose-memrefs")
pm.add("expand-strided-metadata")
pm.add("normalize-memrefs")
pm.add(
"gpu.module(convert-gpu-to-nvvm{index-bitwidth=0 use-bare-ptr-memref-call-conv })"
)
pm.add(f"nvvm-attach-target{ {chip={chip_type} features=+ptx80 O=3} }")
pm.add("convert-nvvm-to-llvm")
pm.add("reconcile-unrealized-casts")
pm.add("gpu-to-llvm { use-bare-pointers-for-host use-bare-pointers-for-kernels }")
pm.run(module.operation)
print(module)
여기에는 많은 일이 벌어지므로 각 플래그를 쪼개서 보자.
one-shot-bufferize{ bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map }: 텐서 연산을 단일 패스로 버퍼 연산으로 변환하며, identity layout mapping으로 함수 경계를 처리한다.convert-linalg-to-affine-loops: 고수준 linalg 연산을 텐서 원소를 명시적으로 순회하는 affine 루프 네스트로 변환한다.func.func(affine-loop-invariant-code-motion): affine 루프에서 루프 불변 코드 이동을 수행하여 가능하면 계산을 루프 밖으로 옮겨 중복 계산을 줄인다.func.func(convert-affine-for-to-gpu): affine 루프를 GPU 실행 모델에 매핑하여 반복을 GPU 스레드와 블록에 분배한다.gpu-kernel-outlining: GPU 커널 영역을 별도의 GPU 함수로 추출해 호스트 코드에서 런치할 수 있게 한다.lower-affine: affine dialect 연산을 표준 제어 흐름 및 산술 연산으로 변환한다.gpu-decompose-memrefs: GPU 백엔드가 처리할 수 있도록 복잡한 memref 타입을 더 단순한 형태로 분해한다.expand-strided-metadata: stride 기반 메모리 접근의 메타데이터를 명시적 계산으로 확장한다.normalize-memrefs: GPU 백엔드가 기대하는 형태로 memref를 정규화한다.gpu.module(convert-gpu-to-nvvm{index-bitwidth=0 use-bare-ptr-memref-call-conv }): GPU dialect 연산을 NVVM dialect(NVIDIA의 LLVM 기반 IR)로 변환하며, memref에 대해 bare pointer 호출 규약을 사용한다.nvvm-attach-target{chip={chip_type} features=+ptx80 O=3}: NVVM 모듈에 타깃 특화 정보를 붙여 GPU 아키텍처, PTX 8.0 기능, 최적화 레벨 3을 지정한다.convert-nvvm-to-llvm: NVVM dialect를 표준 LLVM dialect로 번역해 후속 처리를 가능하게 한다.reconcile-unrealized-casts: 남아 있는 타입 변환 문제를 해결하기 위해 unrealized cast 연산을 정리한다.gpu-to-llvm { use-bare-pointers-for-host use-bare-pointers-for-kernels }: 남아 있는 GPU dialect 연산을 LLVM dialect로 변환하며, 호스트/디바이스 코드 모두에 bare pointer를 사용한다.convert-linalg-to-affine-loops로 linalg 연산을 affine 루프로 로워링하면 다음 MLIR 모듈을 얻는다.
module {
func.func @square(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) -> memref<10x10xf32> {
affine.for %arg2 = 0 to 10 {
affine.for %arg3 = 0 to 10 {
%0 = affine.load %arg0[%arg2, %arg3] : memref<10x10xf32>
%1 = arith.mulf %0, %0 : f32
affine.store %1, %arg1[%arg2, %arg3] : memref<10x10xf32>
}
}
return %arg1 : memref<10x10xf32>
}
}
convert-affine-for-to-gpu 패스 이후. 이 패스는 affine 루프를 GPU 커널 런치 영역으로 변환한다.
func.func @square(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) -> memref<10x10xf32> {
%c0 = arith.constant 0 : index
%c10 = arith.constant 10 : index
%0 = arith.subi %c10, %c0 : index
%c1 = arith.constant 1 : index
%c0_0 = arith.constant 0 : index
%c10_1 = arith.constant 10 : index
%1 = arith.subi %c10_1, %c0_0 : index
%c1_2 = arith.constant 1 : index
%c1_3 = arith.constant 1 : index
gpu.launch
blocks(%arg2, %arg3, %arg4) in (%arg8 = %0, %arg9 = %c1_3, %arg10 = %c1_3)
threads(%arg5, %arg6, %arg7) in (%arg11 = %1, %arg12 = %c1_3, %arg13 = %c1_3) {
%2 = arith.addi %c0, %arg2 : index
%3 = arith.addi %c0_0, %arg5 : index
%4 = affine.load %arg0[%2, %3] : memref<10x10xf32>
%5 = arith.mulf %4, %4 : f32
affine.store %5, %arg1[%2, %3] : memref<10x10xf32>
gpu.terminator
}
return %arg1 : memref<10x10xf32>
}
gpu-kernel-outlining 패스를 적용하면, GPU 커널 런치 영역을 GPU 인트린식을 포함한 별도 GPU 함수로 추출한다(블록/스레드 인덱스를 참조하기 위한 것).
module attributes {gpu.container_module} {
func.func @square(%arg0: memref<10x10xf32>, %arg1: memref<10x10xf32>) -> memref<10x10xf32> {
%c0 = arith.constant 0 : index
%c10 = arith.constant 10 : index
%0 = arith.subi %c10, %c0 : index
%c1 = arith.constant 1 : index
%c0_0 = arith.constant 0 : index
%c10_1 = arith.constant 10 : index
%1 = arith.subi %c10_1, %c0_0 : index
%c1_2 = arith.constant 1 : index
%c1_3 = arith.constant 1 : index
gpu.launch_func @square_kernel::@square_kernel
blocks in (%0, %c1_3, %c1_3)
threads in (%1, %c1_3, %c1_3)
args(%c0 : index, %c0_0 : index, %arg0 : memref<10x10xf32>, %arg1 : memref<10x10xf32>)
return %arg1 : memref<10x10xf32>
}
gpu.module @square_kernel {
gpu.func @square_kernel(%arg0: index, %arg1: index, %arg2: memref<10x10xf32>, %arg3: memref<10x10xf32>) kernel {
%block_id_x = gpu.block_id x
%block_id_y = gpu.block_id y
%block_id_z = gpu.block_id z
%thread_id_x = gpu.thread_id x
%thread_id_y = gpu.thread_id y
%thread_id_z = gpu.thread_id z
%grid_dim_x = gpu.grid_dim x
%grid_dim_y = gpu.grid_dim y
%grid_dim_z = gpu.grid_dim z
%block_dim_x = gpu.block_dim x
%block_dim_y = gpu.block_dim y
%block_dim_z = gpu.block_dim z
%0 = arith.addi %arg0, %block_id_x : index
%1 = arith.addi %arg1, %thread_id_x : index
%2 = affine.load %arg2[%0, %1] : memref<10x10xf32>
%3 = arith.mulf %2, %2 : f32
affine.store %3, %arg3[%0, %1] : memref<10x10xf32>
gpu.return
}
}
}
나머지 패스들은 비교적 기계적인 로워링 패스이며 GPU dialect를 NVVM dialect로 변환한다. 그러면 GPU 커널을 담은 gpu.module이 들어 있는 모듈을 얻는다.
module attributes {gpu.container_module} {
llvm.func @square(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> !llvm.ptr {
%0 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%1 = llvm.insertvalue %arg1, %0[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%2 = llvm.insertvalue %arg1, %1[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%3 = llvm.mlir.constant(0 : index) : i64
%4 = llvm.insertvalue %3, %2[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%5 = llvm.mlir.constant(10 : index) : i64
%6 = llvm.insertvalue %5, %4[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%7 = llvm.mlir.constant(10 : index) : i64
%8 = llvm.insertvalue %7, %6[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%9 = llvm.mlir.constant(10 : index) : i64
%10 = llvm.insertvalue %9, %8[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%11 = llvm.mlir.constant(1 : index) : i64
%12 = llvm.insertvalue %11, %10[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%13 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%14 = llvm.insertvalue %arg0, %13[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%15 = llvm.insertvalue %arg0, %14[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%16 = llvm.mlir.constant(0 : index) : i64
%17 = llvm.insertvalue %16, %15[2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%18 = llvm.mlir.constant(10 : index) : i64
%19 = llvm.insertvalue %18, %17[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%20 = llvm.mlir.constant(10 : index) : i64
%21 = llvm.insertvalue %20, %19[4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%22 = llvm.mlir.constant(10 : index) : i64
%23 = llvm.insertvalue %22, %21[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%24 = llvm.mlir.constant(1 : index) : i64
%25 = llvm.insertvalue %24, %23[4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%26 = llvm.mlir.constant(1 : index) : i64
%27 = llvm.mlir.constant(0 : index) : i64
%28 = llvm.mlir.constant(10 : index) : i64
%29 = llvm.extractvalue %25[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
%30 = llvm.extractvalue %12[1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
gpu.launch_func @square_kernel::@square_kernel blocks in (%28, %26, %26) threads in (%28, %26, %26) : i64 args(%27 : i64, %27 : i64, %29 : !llvm.ptr, %30 : !llvm.ptr)
%31 = llvm.extractvalue %12[0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
llvm.return %31 : !llvm.ptr
}
gpu.module @square_kernel [#nvvm.target<O = 3, chip = "sm_90", features = "+ptx80">] {
llvm.func @square_kernel(%arg0: i64, %arg1: i64, %arg2: !llvm.ptr, %arg3: !llvm.ptr) attributes {gpu.kernel, nvvm.kernel} {
%0 = llvm.mlir.constant(10 : index) : i64
%1 = nvvm.read.ptx.sreg.ctaid.x : i32
%2 = llvm.sext %1 : i32 to i64
%3 = nvvm.read.ptx.sreg.tid.x : i32
%4 = llvm.sext %3 : i32 to i64
%5 = llvm.add %arg0, %2 : i64
%6 = llvm.add %arg1, %4 : i64
%7 = llvm.mul %5, %0 : i64
%8 = llvm.add %7, %6 : i64
%9 = llvm.getelementptr %arg2[%8] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%10 = llvm.load %9 : !llvm.ptr -> f32
%11 = llvm.fmul %10, %10 : f32
%12 = llvm.mul %5, %0 : i64
%13 = llvm.add %12, %6 : i64
%14 = llvm.getelementptr %arg3[%13] : (!llvm.ptr, i64) -> !llvm.ptr, f32
llvm.store %11, %14 : f32, !llvm.ptr
llvm.return
}
}
}
모듈에서 GPU 함수를 추출할 수 있다.
module {
llvm.func @square_kernel(%arg0: i64, %arg1: i64, %arg2: !llvm.ptr, %arg3: !llvm.ptr) attributes {gpu.kernel, nvvm.kernel} {
%0 = llvm.mlir.constant(10 : index) : i64
%1 = nvvm.read.ptx.sreg.ctaid.x : i32
%2 = llvm.sext %1 : i32 to i64
%3 = nvvm.read.ptx.sreg.tid.x : i32
%4 = llvm.sext %3 : i32 to i64
%5 = llvm.add %arg0, %2 : i64
%6 = llvm.add %arg1, %4 : i64
%7 = llvm.mul %5, %0 : i64
%8 = llvm.add %7, %6 : i64
%9 = llvm.getelementptr %arg2[%8] : (!llvm.ptr, i64) -> !llvm.ptr, f32
%10 = llvm.load %9 : !llvm.ptr -> f32
%11 = llvm.fmul %10, %10 : f32
%12 = llvm.mul %5, %0 : i64
%13 = llvm.add %12, %6 : i64
%14 = llvm.getelementptr %arg3[%13] : (!llvm.ptr, i64) -> !llvm.ptr, f32
llvm.store %11, %14 : f32, !llvm.ptr
llvm.return
}
}
그다음 MLIR 모듈을 LLVM IR로 번역한다.
; ModuleID = 'LLVMDialectModule'
source_filename = "LLVMDialectModule"
define ptx_kernel void @square_kernel(i64 %0, i64 %1, ptr %2, ptr %3) {
%5 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%6 = sext i32 %5 to i64
%7 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%8 = sext i32 %7 to i64
%9 = add i64 %0, %6
%10 = add i64 %1, %8
%11 = mul i64 %9, 10
%12 = add i64 %11, %10
%13 = getelementptr float, ptr %2, i64 %12
%14 = load float, ptr %13, align 4
%15 = fmul float %14, %14
%16 = mul i64 %9, 10
%17 = add i64 %16, %10
%18 = getelementptr float, ptr %3, i64 %17
store float %15, ptr %18, align 4
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.ctaid.x() #0
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #0
attributes #0 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!llvm.module.flags = !{!0}
!0 = !{i32 2, !"Debug Info Version", i32 3}
그다음 LLVM의 nvptx 백엔드를 사용해 llc로 PTX로 컴파일한다.
//
// Generated by LLVM NVPTX Back-End
//
.version 7.8
.target sm_90
.address_size 64
// .globl square_kernel // -- Begin function square_kernel
// @square_kernel
.visible .entry square_kernel(
.param .u64 square_kernel_param_0,
.param .u64 square_kernel_param_1,
.param .u64 .ptr .align 1 square_kernel_param_2,
.param .u64 .ptr .align 1 square_kernel_param_3
)
{
.reg .b32 %r<3>;
.reg .f32 %f<3>;
.reg .b64 %rd<15>;
// %bb.0:
ld.param.u64 %rd1, [square_kernel_param_0];
ld.param.u64 %rd2, [square_kernel_param_3];
cvta.to.global.u64 %rd3, %rd2;
ld.param.u64 %rd4, [square_kernel_param_1];
ld.param.u64 %rd5, [square_kernel_param_2];
cvta.to.global.u64 %rd6, %rd5;
mov.u32 %r1, %ctaid.x;
cvt.s64.s32 %rd7, %r1;
mov.u32 %r2, %tid.x;
cvt.s64.s32 %rd8, %r2;
add.s64 %rd9, %rd1, %rd7;
add.s64 %rd10, %rd4, %rd8;
mad.lo.s64 %rd11, %rd9, 10, %rd10;
shl.b64 %rd12, %rd11, 2;
add.s64 %rd13, %rd6, %rd12;
ld.global.f32 %f1, [%rd13];
mul.rn.f32 %f2, %f1, %f1;
add.s64 %rd14, %rd3, %rd12;
st.global.f32 [%rd14], %f2;
ret;
// -- End function
}
이제 GPU 코드 생성 파이프라인을 구현하는 compile.py와 run.py 두 모듈을 구성할 수 있다. 이 파이프라인은 고수준 MLIR을 최적화된 PTX 어셈블리로 변환한다.
전체 소스 코드는 on Github와 notebook에서 확인할 수 있다.
예시에서는 전체 컴파일 과정을 조율하는 주 함수 compile_mlir_to_ptx를 정의한다. 내부에서는 먼저 MLIR 문자열을 모듈 표현으로 파싱하고, тщательно(신중히) 정렬된 변환 패스 시퀀스를 통해 GPU 컴파일 파이프라인을 적용한다. 이 패스들은 고수준 텐서 연산에서 GPU 특화 구성요소로 코드를 점진적으로 로워링한다. 여기에는 메모리 접근을 다루기 위한 버퍼라이즈, linalg 연산을 affine 루프로 변환, 루프를 GPU 블록/스레드로 매핑, NVVM dialect를 통한 NVIDIA 특화 코드 생성이 포함된다. 변환이 끝나면 GPU 모듈을 추출하고 LLVM 도구로 PTX 어셈블리로 변환하여, 추가 컴파일 없이 NVIDIA GPU에서 직접 실행 가능한 코드를 얻는다.
import subprocess
from mlir.ir import Context, Module
from mlir.passmanager import PassManager
def compile_mlir_to_ptx(mlir_module_str: str, chip_type="sm_75"):
"""Compiles MLIR module string to PTX code."""
with Context():
# Parse the input module
module = Module.parse(mlir_module_str)
# Apply GPU compilation pipeline
module, gpu_module = apply_gpu_pipeline(module, chip_type)
# Generate PTX from the GPU module
ptx = generate_ptx(str(gpu_module), chip_type)
return ptx
def apply_gpu_pipeline(module, chip_type="sm_75"):
"""Applies the GPU compilation pipeline to the MLIR module."""
pm = PassManager()
pm.enable_ir_printing(print_after_change=True)
pm.add("canonicalize")
pm.add(
"one-shot-bufferize{ bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map }"
)
pm.add("canonicalize")
pm.add("convert-linalg-to-affine-loops")
pm.add("func.func(affine-loop-invariant-code-motion)")
pm.add("func.func(convert-affine-for-to-gpu)")
pm.add("gpu-kernel-outlining")
pm.add("lower-affine")
pm.add("gpu-decompose-memrefs")
pm.add("expand-strided-metadata")
pm.add("normalize-memrefs")
pm.add(
"gpu.module(convert-gpu-to-nvvm{index-bitwidth=0 use-bare-ptr-memref-call-conv })"
)
pm.add(f"nvvm-attach-target{ {chip={chip_type} features=+ptx80 O=3} }")
pm.add("convert-nvvm-to-llvm")
pm.add("reconcile-unrealized-casts")
pm.add("gpu-to-llvm { use-bare-pointers-for-host use-bare-pointers-for-kernels }")
pm.run(module.operation)
# Extract the GPU module
gpu_module = extract_gpu_module(module)
return module, gpu_module
def extract_gpu_module(module: Module) -> Module:
"""Extracts the GPU module from a transformed MLIR module."""
# Navigate the operation tree to find the GPU module
# Structure: module -> region[0] -> block[0] -> operations[1] (GPU host-device code)
# -> region[0] -> block[0] -> operations[0] (GPU module)
try:
main_func_op = module.operation.regions[0].blocks[0].operations[1]
gpu_module_op = main_func_op.regions[0].blocks[0].operations[0]
# Create a new module from the GPU module operation
gpu_module = Module.parse(str(gpu_module_op))
return gpu_module
except (IndexError, AttributeError) as e:
raise RuntimeError(f"Failed to extract GPU module: {e}") from e
def generate_ptx(gpu_module_str, chip_type="sm_75"):
"""Generates PTX from an MLIR GPU module string."""
# First convert MLIR to LLVM IR
llvm_ir_result = subprocess.run(
["mlir-translate", "--mlir-to-llvmir", "-"],
input=gpu_module_str,
capture_output=True,
text=True,
)
if llvm_ir_result.returncode != 0:
print("Error generating LLVM IR:")
print(llvm_ir_result.stderr)
return None
llvm_ir = llvm_ir_result.stdout
# Then convert LLVM IR to PTX
ptx_result = subprocess.run(
["llc", "-march=nvptx64", f"-mcpu={chip_type}", "-"],
input=llvm_ir,
capture_output=True,
text=True,
)
if ptx_result.returncode != 0:
print("Error generating PTX:")
print(ptx_result.stderr)
return None
return ptx_result.stdout
이제 run_kernel 함수를 정의할 수 있는데, 이는 Python 코드와 GPU 하드웨어 사이의 인터페이스로서 컴파일된 PTX 커널 실행을 관리한다. PTX 코드 문자열, 커널 이름, 인자와 타입, 그리고 병렬화 전략을 정의하는 그리드/블록 차원을 받는다. 함수는 CUDA 드라이버 API로 PTX를 CUDA 모듈로 로드한 다음, 지정된 이름의 커널 함수 핸들을 가져온다. CUDA 런타임이 기대하는 형식으로 커널 인자를 준비하고, 지정된 실행 구성으로 커널을 런치해 몇 개의 스레드 블록과 블록당 스레드가 데이터를 처리할지 결정한다.
def run_kernel(
ptx_code,
kernel_name,
args,
arg_types,
grid_dims,
block_dims,
):
"""Run a PTX kernel."""
module = checkCudaErrors(cu.cuModuleLoadData(ptx_code.encode("utf-8")))
kernel_func = checkCudaErrors(
cu.cuModuleGetFunction(module, kernel_name.encode("utf-8"))
)
kernel_args = (tuple(args), tuple(arg_types))
checkCudaErrors(
cu.cuLaunchKernel(
kernel_func,
grid_dims[0],
grid_dims[1],
grid_dims[2],
block_dims[0],
block_dims[1],
block_dims[2],
0, # shared memory bytes
0, # stream
kernel_args, # kernel args
0, # extra
)
)
checkCudaErrors(cu.cuCtxSynchronize())
checkCudaErrors(cu.cuModuleUnload(module))
다음 코드는 우리의 컴파일 파이프라인을 사용해 GPU 커널 실행을 종단 간으로 보여주는 완전한 예시다. 먼저 파일의 MLIR 코드를 PTX 어셈블리로 컴파일한 뒤, CUDA 컨텍스트를 설정하고 호스트/디바이스 메모리를 할당한다. 예시는 각 스레드가 하나의 원소를 처리하는 정사각 행렬 연산을 준비하며, 그리드 차원은 행 행과, 블록 차원은 열 열과 맞춘다. 입력 데이터를 GPU로 복사한 뒤 적절한 인자와 차원으로 커널을 실행하고, 결과를 NumPy 배열로 호스트에 다시 가져온다.
import ctypes
import numpy as np
import cuda.cuda as cu
import cuda.cudart as cudart
import cuda.nvrtc as nvrtc
import subprocess
from compile import compile_mlir_to_ptx
from run import run_kernel
ptx_code = compile_mlir_to_ptx(open('square.mlir').read())
cuda_context = setup_cuda()
try:
# Allocate device memory
d_input = allocate_device_memory(input_data.nbytes)
output_data = np.zeros((size, size), dtype=np.float32)
d_output = allocate_device_memory(output_data.nbytes)
# Copy input data to device
copy_host_to_device(input_data, d_input)
# Run kernel
grid_dims = (size, 1, 1) # One thread block per row
block_dims = (size, 1, 1) # One thread per column
# Prepare arguments according to the PTX code
# square_kernel(
# .param .u64 square_kernel_param_0, // Grid dimension offset
# .param .u64 square_kernel_param_1, // Block dimension offset
# .param .u64 .ptr .align 1 square_kernel_param_2, // Input pointer
# .param .u64 .ptr .align 1 square_kernel_param_3 // Output pointer
# )
args = [
0, # Grid dimension offset
0, # Block dimension offset
d_input, # Input pointer
d_output, # Output pointer
]
arg_types = [ctypes.c_int, ctypes.c_int, None, None] # Using None for pointer types
print("Running kernel on GPU...")
run_kernel(
ptx_code,
"square_kernel",
args,
arg_types,
grid_dims,
block_dims,
)
# Copy results back to host
copy_device_to_host(d_output, output_data)
# Verify results
print("Verifying results...")
np.testing.assert_allclose(output_data, expected_output, rtol=1e-5)
print("Success! Results verified.")
finally:
# Clean up resources
free_device_memory(d_input)
free_device_memory(d_output)
cleanup_cuda(cuda_context)
이로써 동적 텐서 표현을 MLIR로 받아 PTX로 컴파일하고 GPU에서 실행하는 아주 작은 컴파일러 파이프라인의 골격이 완성됐다. 앞으로의 목표는, 커널 연산을 표현하기에 더 자연스러운 Python eDSL 같은 것으로부터 같은 MLIR을 타깃으로 하고, 그 파이프라인이 이를 MLIR로 컴파일한 뒤 GPU용으로 로워링하도록 하는 것이다. 다음 섹션에서 더 다룰 것이다.
여담으로, GPU 커널을 컴파일하는 대안 경로는 커널을 바이너리로 컴파일한 다음 그 바이너리를 MLIR 모듈에 임베드하고, 호스트 코드가 그 바이너리를 로드해 커널을 런치하는 방식이다.
MLIR의 gpu.binary 연산은 GPU 디바이스에서 로드/실행할 수 있는 컴파일된 GPU 커널을 나타낸다. PTX 어셈블리, CUBIN 바이너리, AMD GPU용 HSACO 등 다양한 포맷의 GPU 코드 바이너리 표현을 캡슐화한다. 이 연산은 런타임에 컴파일 파이프라인을 거치지 않고도 바로 실행할 수 있도록, 사전 컴파일된 GPU 커널을 MLIR 모듈에 직접 임베드하고 싶을 때 특히 유용하다.
gpu-module-to-binary 패스는 GPU 모듈을 GPU 바이너리로 변환하는 트랜스포메이션이다. MLIR 모듈을 스캔해 중첩된 GPU 모듈을 모두 찾고, 각 모듈에 붙은 타깃 속성에 따라 직렬화한다. 지정된 각 타깃 아키텍처마다 하나의 object를 담은 GPU 바이너리를 생성한다. 이 패스는 오프로딩 표현, 어셈블리 코드, 바이너리, fatbinary 등 다양한 출력 포맷을 지원한다.
내부적으로 이 과정이 하는 일은 PTX 코드에 대해 ptxas를 실행하고, gpu-module-to-binary 패스가 ptxas에 전달한 인자와 함께 그 출력을 MLIR 모듈에 임베드하는 것이다. 예를 들어 다음 MLIR 모듈은 sm_70 아키텍처용 GPU 바이너리를 포함한다.
module attributes {gpu.container_module} {
gpu.binary @vecadd_kernel [
#gpu.object<#nvvm.target<chip = "sm_70">,
offload = "BC\C0\DE5\14 [... truncated ...] kernel20.1.0nvptx64-nvidia-cudaLLVMDialectModule\00\00\00\00">]
}
그다음 호스트 코드에서 다음처럼 커널을 런치할 수 있다.
gpu.launch_func @vecadd_kernel::@vecadd_kernel
blocks in (%0, %0, %0)
threads in (%0, %0, %0) : i64
args(%1 : i32, %2 : i64)