LLM 추론이 왜 비결정적인지(‘동시성 + 부동소수점’ 통설의 한계와 진짜 원인)를 해부하고, 배치 불변 커널로 결정적(재현 가능한) 결과를 얻는 방법을 원리, 구현, 실험과 함께 제시한다.
LLM 추론의 비결정성 극복
재현가능성은 과학 발전의 초석이다. 그러나 대규모 언어 모델에서 재현 가능한 결과를 얻는 일은 놀라울 정도로 어렵다.
예를 들어, 같은 질문을 ChatGPT에 여러 번 던지면 서로 다른 답을 얻을 수 있다. 언어 모델의 출력으로부터 확률 분포를 만들고 확률적으로 토큰을 선택하는 “샘플링” 과정을 거치므로, 이것만으로는 놀랍지 않다.
더 놀라운 사실은, 샘플링 온도를 0으로 낮춰 LLM이 항상 가장 확률이 높은 토큰을 고르는, 이른바 그리디 샘플링(이론적으로 결정적)으로 만들었을 때조차, 실제 LLM API는 여전히 결정적이지 않다는 점이다(과거 논의는 여기, 여기, 여기 참고). vLLM이나 SGLang 같은 오픈소스 추론 라이브러리로 내 하드웨어에서 추론을 돌려도 샘플링은 여전히 결정적이지 않다(여기, 여기 참고).
그렇다면 왜 LLM 추론 엔진은 결정적이지 않을까? 흔한 가설은 부동소수점의 비결합성(associativity가 성립하지 않음)과 동시 실행이 뒤섞여, 어떤 코어가 먼저 끝나는지에 따라 결과가 달라지는 비결정성이 생긴다는 것이다. 우리는 이를 LLM 추론 비결정성에 대한 “동시성 + 부동소수점” 가설이라고 부르겠다. 예컨대 최근 arXiv 사전출판은 이렇게 적는다:
GPU의 부동소수점 연산은 비결합성을 보이며, 유한 정밀도와 반올림 오류 때문에 (a + b) + c ≠ a + (b + c)이다. 이 성질은 트랜스포머 아키텍처에서 주의(attention) 스코어와 로짓의 계산에 직접적인 영향을 미치며, 여러 스레드에 걸친 병렬 연산은 실행 순서에 따라 다른 결과를 낳을 수 있다.
이 “동시성 + 부동소수점” 가설은 곳곳에서 반복된다. 예를 들면 여기나(“엔드포인트를 빠르게 만들기 위해 GPU를 쓰며, 병렬 [비결정적] 계산을 한다. 현대 GPU의 신경망 계산은 이에 노출된다.”), 여기도 그렇다(“GPU는 고도로 병렬화되어 있어, 덧셈이나 곱셈의 순서가 실행마다 달라질 수 있고, 이는 출력의 작은 차이로 연쇄된다.”).
이 가설이 완전히 틀린 것은 아니지만, 전모를 보여주지는 못한다. 예컨대 GPU에서도 같은 행렬 곱을 같은 데이터로 반복 실행하면 항상 비트 단위로 동일한 결과가 나온다. 분명 우리는 부동소수점을 쓰고 있고, GPU에는 확실히 많은 동시성이 있다. 그런데 왜 이 테스트에서는 비결정성이 보이지 않는가?
A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
ref = torch.mm(A, B)
for _ in range(1000):
assert (torch.mm(A, B) - ref).abs().max().item() == 0
LLM 추론의 비결정성의 진짜 원인을 이해하려면 더 깊이 들여다봐야 한다.
불행히도, LLM 추론이 결정적이라는 게 무엇을 의미하는지조차 정의하기가 어렵다. 다소 혼란스럽지만, 다음 명제들은 동시에 참이다:
이 글에서는 “동시성 + 부동소수점” 가설이 왜 빗나가는지, LLM 추론 비결정성의 진짜 범인이 무엇인지, 그리고 비결정성을 어떻게 제압해 LLM 추론에서 진정한 재현성을 얻는지 설명한다.
비결정성을 말하기 전에, 애초에 왜 수치 차이가 생기는지부터 짚자. 우리는 흔히 머신러닝 모델을 교환법칙이나 결합법칙 같은 구조적 규칙을 따르는 수학적 함수로 생각한다. 그렇다면 라이브러리는 우리에게 “수학적으로 정확한” 결과를 주어야 하지 않나?
범인은 바로 부동소수점의 비결합성이다. 즉, 부동소수점에서는:
(a + b) + c ≠ a + (b + c)
(0.1 + 1e20) - 1e20
>>> 0
0.1 + (1e20 - 1e20)
>>> 0.1
아이러니하게도, 결합법칙이 깨지기 때문에 부동소수점이 유용하다.
부동소수점은 “동적인” 정밀도를 허용한다. 설명을 위해 여기서는 2진수 대신 10진을 쓰고, 부동소수점을 가수 × 10^지수 형식으로 표현하겠다. 가수는 3자리, 지수는 1자리로 제한한다고 하자.
예를 들어 3450은 3.45 × 10^3으로 정확히 표현할 수 있다. 훨씬 작은 0.486도 4.86 × 10^-1로 표현할 수 있다. 이렇게 부동소수점은 매우 작은 값과 매우 큰 값을 모두 표현할 수 있다. 과학에서는 이를 일정한 “유효 숫자”를 유지한다고 말하곤 한다.
지수가 같은 두 부동소수점을 더하면 정수 덧셈과 비슷하다. 예를 들어 123(= 1.23 × 10^2) + 456(= 4.56 × 10^2)은 579(= 5.79 × 10^2)가 된다.
하지만 1230과 23.4처럼 지수가 다른 두 수를 더하면 어떨까? 정확한 결과는 1253.4다. 그러나 우리는 한 번에 가수 3자리만 유지할 수 있다. 그래서 부동소수점 덧셈은 마지막 두 자리를 “떨궈” 1.25 × 10^3(= 1250)을 얻는다.
1.23 × 10³
2.34 × 10¹
=
1.25 34 × 10³
정확한 값: 1253.4
1230을 표현하는 데 가수 3자리가 필요하고, 23.4를 표현하는 데도 가수 3자리가 필요하다. 이 둘을 더하면 1253.4처럼 가수 5자리가 필요한 수가 나오는데, 우리의 형식은 이를 담지 못해 끝의 34를 버린다. 어떤 의미에서, 23.4를 20.0으로 반올림해 더한 셈이다.
이때 우리는 정보를 잃는다. 서로 다른 “스케일”(지수)을 가진 부동소수점을 더할 때마다 이런 일이 벌어질 수 있다. 그리고 지수가 다른 수를 더하는 일은 아주 자주 일어난다. 만약 지수가 항상 같다고 보장할 수 있다면, 애초에 정수만 썼을 것이다!
다시 말해, 부동소수점 수들을 다른 순서로 더할 때마다 결과가 완전히 달라질 수 있다. 극단적인 예로, 아래 배열의 합을 내는 순서에 따라 서로 다른 결과가 102가지나 나올 수 있다.
import random
vals = [1e-10, 1e-5, 1e-2, 1]
vals = vals + [-v for v in vals]
results = []
random.seed(42)
for _ in range(10000):
random.shuffle(vals)
results.append(sum(vals))
results = sorted(set(results))
print(f"There are {len(results)} unique results: {results}")
# Output:
# There are 102 unique results: [-8.326672684688674e-17, -7.45931094670027e-17, ..., 8.326672684688674e-17]
이것이 비일치 출력의 근본 원인이지만, 비결정성이 어디서 오는지는 직접 설명하지 못한다. 부동소수점 값이 왜 다른 순서로 더해지는지, 언제 그런 일이 일어나는지, 어떻게 피할 수 있는지 알 수 없다.
답은 커널 구현 방식에 있다.
앞서 언급했듯, 커널이 숫자를 다른 순서로 더하는 이유로 흔히 “동시성 + 부동소수점” 가설이 제시된다. 동시 실행 스레드의 완료 순서가 비결정적이고(예: atomic add), 누적 순서가 완료 순서에 의존한다면, 누적 순서도 비결정적이 된다는 논리다.
혼란스럽게도, 이로 인해 비결정적 커널이 생길 수는 있지만, LLM 추론 비결정성에는(그리고 atomic add에도) 거의 관여하지 않는다! 진짜 범인을 설명하기 전에, 먼저 왜 최신 GPU 커널들이 atomic add를 잘 쓰지 않는지 이해하자.
일반적으로 GPU는 많은 “코어”(즉, SM)에서 프로그램을 동시에 실행한다. 코어들 사이에 본질적인 동기화가 없기에, 서로 통신이 필요하면 문제가 생긴다. 예컨대 모든 코어가 같은 원소에 누적해야 한다면 “atomic add”(일명 “fetch-and-add”)를 사용할 수 있다. atomic add는 “비결정적”이다 — 누적 순서는 오직 어느 코어가 먼저 끝나는지에 달려 있다.
구체적으로, 100개 코어로 100차원 벡터를 합산한다고 해보자(예: torch.sum()). 100개 원소는 병렬로 불러올 수 있지만, 결국 하나의 값으로 줄여야 한다. 한 가지 방법은 하드웨어가 모든 덧셈 처리를 보장하되 순서는 보장하지 않는 “atomic add” 원시 연산을 쓰는 것이다.
atomic add는 모든 코어의 기여가 최종 합에 반영되도록 보장한다. 그러나 어떤 순서로 더해질지는 보장하지 않는다. 순서는 전적으로 어느 코어가 먼저 끝나는지에 달려 있고, 이는 비결정적이다. 따라서 같은 병렬 프로그램을 여러 번 실행해도 서로 다른 결과가 나올 수 있다.
보통 사람들이 “비결정성”이라 부르는 것은 이 경우다 — 정확히 같은 입력으로 커널을 두 번 실행했는데 결과가 바뀐다. 이는 _run-to-run 비결정성_이라고 하며, 같은 파이썬 스크립트를 같은 의존성과 함께 두 번 실행했을 때 결과가 달라지는 현상을 말한다.
동시 atomic add는 커널을 비결정적으로 만들 수 있지만, 대부분의 커널에는 atomic add가 필요 없다. 사실 LLM의 전형적 포워드 패스에는 단 하나의 atomic add도 없다.
이는 놀라울 수 있다. 감산(리덕션)을 병렬화하면 atomic add가 유리해 보이기 때문이다. 그럼에도 atomic add가 필요 없게 되는 주된 이유가 둘 있다.
이 두 요인 덕분에, 대부분의 신경망 연산에서 atomic add 회피는 사실상 성능 손실이 미미하다.
여전히 atomic 회피에 큰 성능 페널티가 있는 연산도 몇 가지 있다. 예컨대 PyTorch의 scatter_add(a[b] += c). 하지만 LLM에서 흔하게 쓰이는 것 중 유일하게 해당하는 것은 FlashAttention backward다. 흥미로운 사실: 널리 쓰이는 Triton 기반 FlashAttention backward 구현은 Tri Dao의 FlashAttention-2 논문과 알고리즘적으로 다르다. 표준 Triton 구현은 backward에서 추가 재계산을 수행해 atomics를 피하지만, FLOPs가 40% 더 든다!
반면 LLM의 포워드 패스에는 atomic add가 필요한 연산이 없다. 따라서 LLM의 포워드 패스는 실제로 “run-to-run 결정적”이다.
추론 서버 관점에서 보면, _결정적_이다. 정확히 같은 사용자 요청이 주어지면 항상 같은 결정적 출력을 낸다.
위키피디아는 “결정적 알고리즘은 특정 입력이 주어지면 항상 같은 출력을 내는 알고리즘”이라고 적는다. 이 경우 정확히 같은 입력(즉, 추론 서버가 처리하는 정확히 같은 요청들)이 주어지면, 포워드 패스는 항상 정확히 같은 출력을 낸다.
그러나 포워드 패스 자체가 “결정적”이라고 해서, 그것을 포함한 시스템 전체가 결정적인 것은 아니다. 예를 들면, 한 요청의 출력이 동시에 처리되는 다른 사용자 요청들(예: 배치 정규화)에 의존한다면? 각 요청은 병렬 요청이 무엇인지 알 수 없으므로, 그들의 관점에서 전체 LLM 추론은 비결정적이다!
알고 보니, 각 요청의 출력은 실제로 병렬 사용자 요청에 의존한다. 배치 간에 정보가 새어나가서가 아니라, 포워드 패스가 “배치 불변성”을 결여하여, 각 요청의 출력이 포워드 패스의 배치 크기에 의존하기 때문이다.
배치 불변성을 설명하기 위해, 시스템을 단순화하여 행렬 곱만 보자. 모든 matmul 구현이 “run-to-run 결정적”이라고 가정하자. (완전히 맞는 말은 아니지만, 흔한 matmul 구현 대부분은 그렇다.) 하지만 이들은 “배치 불변”이 아니다. 즉, 배치 크기가 바뀌면 배치의 각 원소 결과가 달라질 수 있다.
수학적으로는 꽤 이례적인 성질이다. 행렬 곱은 배치의 각 원소에 대해 “독립적”이어야 한다 — 배치의 다른 원소나 배치 크기는 특정 원소의 계산 결과에 영향을 주지 않아야 한다.
하지만 경험적으로는 그렇지 않다.
import torch
torch.set_default_device('cuda')
B = 2048
D = 4096
a = torch.linspace(-1000, 1000, B*D).reshape(B, D)
b = torch.linspace(-1000, 1000, D*D).reshape(D, D)
# 첫 번째 배치 원소만 뽑아 행렬-벡터 곱을 수행
out1 = torch.mm(a[:1], b)
# 전체 행렬-행렬 곱을 수행한 뒤 첫 번째 배치 원소만 취함
out2 = torch.mm(a, b)[:1]
print((out1 - out2).abs().max()) # tensor(1669.2500, device='cuda:0')
이는 “run-to-run 결정적”이다. 스크립트를 여러 번 실행해도 결정적으로 같은 결과가 나온다. (다만 “하드웨어/소프트웨어 버전 불변”은 아니다 — GPU/PyTorch 버전에 따라 값이 다를 수 있지만, 그 버전 내에서는 결정적으로 같다.)
그러나 배치 불변이 아닌 커널을 더 큰 추론 시스템에 쓰면, 시스템은 비결정적이 될 수 있다. 추론 엔드포인트에 질의할 때, 서버의 부하는 사용자 입장에서 사실상 “비결정적”이다. 이 부하가 커널을 돌리는 배치 크기를 결정하므로, 각 개별 요청의 결과를 바꿔버린다!
추론 서버 자체는 “결정적”이라고 주장할 수 있지만, 개별 사용자에게는 얘기가 다르다. 개별 사용자 관점에서, 다른 동시 사용자들은 시스템의 “입력”이 아니라 시스템의 비결정적 속성이다. 이 때문에 LLM 추론은 각 사용자 관점에서 “비결정적”이 된다.
즉, 커널이 불변이 아닌 어떤 속성(예: 배치 크기)에 시스템적인 비결정성(예: 서버 부하)이 합성되면, 비결정적 시스템이 된다.
다시 말해, 사실상 거의 모든 LLM 추론 엔드포인트가 비결정적인 주된 이유는, 서버 부하(따라서 배치 크기)가 비결정적으로 변동하기 때문이다! 이 비결정성은 GPU에만 국한되지 않는다 — CPU나 TPU로 서비스되는 LLM 추론 엔드포인트도 같은 비결정성에 노출된다.
따라서 추론 서버에서 비결정성을 피하려면, 커널에서 배치 불변성을 달성해야 한다. 이를 이해하려면 먼저 왜 커널이 처음부터 배치 불변이 아닌지 살펴보자.
트랜스포머 구현을 배치 불변으로 만들려면, 모든 커널이 배치 불변이어야 한다. 다행히 포인트와이즈 연산은 모두 배치 불변이라고 가정해도 된다. (PyTorch 같은 라이브러리의 모든 커널에서 사실이지만, 본질적으로 자동 보장은 아니다. 예컨대 CPU의 어떤 구현은 배열의 일부는 벡터화 내장을, 다른 일부는 비벡터화 내장을 쓰는데, 이 내장들 사이에 비트 단위로 동일한 수치가 항상 보장되지는 않는다.) 따라서 리덕션을 포함하는 3가지 연산 — RMSNorm, 행렬 곱셈, 어텐션 — 만 신경 쓰면 된다. (병렬화에 관련된 리덕션은 여기 범위를 벗어나지만, 동일한 원리가 적용된다. 참고로 NVLink-Sharp의 인스위치 리덕션은 Blackwell, 그리고 CUDA 12.8+가 설치된 Hopper에서도 결정적이다. 이런 정보는 NCCL의 깃허브 이슈에서 자주 찾을 수 있다.)
마침 이들은 난이도도 오름차순이다. 합리적 성능을 유지하며 배치 불변성을 달성하려면 각각에 추가 고려가 필요하다. 먼저 RMSNorm부터 보자.
데이터 병렬 RMSNorm 이상적으로 우리는 병렬화 전략에서 코어 간 통신을 피하고 싶다. 이를 위한 한 가지 방법은 배치의 각 원소를 각 코어에 할당해, 각 리덕션이 단일 코어 내에서만 수행되도록 보장하는 것이다. 이를 “데이터 병렬” 전략이라고 하며, 통신이 필요 없는 차원으로 단순 병렬화한다. 예시에서는 행이 네 개이고 코어도 네 개라 코어가 포화된다.
RMSNorm은 다음과 같이 구현할 수 있다:
# x: [batch_size, hidden_dim]
# weight: [hidden_dim]
def rms_norm(x, weight):
return x * torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True)) * weight
배치 불변성을 위한 요구사항은 각 원소의 리덕션 순서가 커널의 배치 크기와 무관하게 고정되어야 한다는 것이다. 이는 항상 같은 리덕션 전략을 써야 한다는 뜻은 아니다. 예컨대 리덕션해야 하는 원소 수가 달라지면 리덕션 전략이 바뀌어도 여전히 배치 불변일 수 있다. The Quack 블로그 글에는 다양한 리덕션 전략(스레드, 워프, 블록, 클러스터 리덕션)의 위계를 잘 보여주는 예제가 있다.
따라서 배치 크기가 리덕션 전략에 영향을 줄 때만 배치 불변성이 깨진다.
표준 RMSNorm 병렬화 전략을 보자. 일반적으로 병렬 알고리즘은 코어 간 통신을 최소화하면 이득을 본다. 여기서 “코어”는 SM을 뜻한다고 생각하면 된다. 더 구체적으로는, 커널이 실행할 스레드블록 수가 SM 수보다 많다는 속성이 중요하다. 한 가지 전략은 위 그림처럼 각 배치 원소를 하나의 코어에 할당하는 것이다.
배치 크기를 크게 해도 리덕션 전략은 변하지 않는다. 배치 크기 200이 커널에 충분한 병렬성을 제공한다면, 배치 2000은 확실히 충분하다.
큰 배치를 위한 데이터 병렬 RMSNorm 데이터 병렬 전략을 더 큰 배치로 확장하는 것도 간단하다 — 각 코어가 한 행만 담당하는 대신 서로 다른 행들을 순차적으로 처리하게 하면 된다. 이렇게 하면 각 배치 원소의 리덕션 전략이 동일하게 유지되므로 _배치 불변성이 보존_된다.
반면 배치 크기가 작아지면 문제가 생길 수 있다. 각 배치 원소를 하나의 코어에 할당하기 때문에, 배치가 충분히 작아지면 코어 수가 배치 원소 수를 초과해 일부 코어가 놀게 된다.
이 상황에서 유능한 커널 엔지니어는 앞 절에서 언급한 해결책(atomic add나 분할 리덕션)을 적용해 병렬성과 성능을 유지하려 할 것이다. 그러나 이는 리덕션 전략을 바꿔서 배치 불변성을 잃게 만든다.
분할 리덕션 RMSNorm 배치 크기가 너무 작으면 데이터 병렬 전략이 코어 포화를 달성할 만큼의 병렬성을 제공하지 못한다. 이 경우 하나의 리덕션을 여러 코어에 “분할”하는 편이 더 효율적일 수 있다. 다만 이렇게 하면 각 원소를 동일한 순서로 줄이지 않게 되어 배치 불변성을 잃는다.
가장 쉬운 해법은 이런 케이스를 그냥 무시하는 것이다. 완전히 _비합리적_인 선택은 아니다 — 배치가 작으면 커널 실행 자체가 빨라서 약간 느려져도 치명적이지 않을 수 있다.
그래도 이 경우를 최적화해야 한다면, 아주 작은 배치에도 충분한 병렬성을 제공하는 리덕션 전략을 항상 일관되게 쓰는 방법이 있다. 이렇게 하면 큰 배치에서는 과도한 병렬성이 생기지만, 전체 크기 범위에서 준수한(최대는 아닌) 성능을 얻을 수 있다.
데이터 병렬 Matmul RMSNorm과 유사하게, matmul의 표준 병렬화 전략도 리덕션을 단일 코어 안에 두는 “데이터 병렬” 전략이다. 가장 이해하기 쉬운 방식은 출력 텐서를 2차원 타일로 나눠 각 타일을 다른 코어에 할당하는 것이다. 각 코어는 해당 타일에 속한 내적을 계산하며, 이때도 리덕션 전체를 한 코어 안에서 수행한다. RMSNorm과 달리, 연산 집약도와 텐서코어 활용 제약 때문에 효율적인 matmul 커널에서는 개별 출력 원소가 아니라 2D 타일을 쪼개야 한다.
본질적으로 행렬 곱은 포인트와이즈 연산 뒤에 리덕션이 이어지는 것으로 볼 수 있다. 출력 기준으로 타일링하여 병렬화하면, 리덕션을 각 코어 내에 유지하는 유사한 “데이터 병렬” 전략이 된다.
RMSNorm과 마찬가지로, “배치” 차원(M, N)이 너무 작아져 리덕션 차원(K)을 분할해야 할 때가 있다. 게다가 matmul은 텐서코어를 제대로 활용하려면 코어당 “작업량”이 훨씬 많아야 한다. 예를 들어 [1024, K] × [K, 1024] matmul에 표준 2D 타일 크기 [128, 128]을 쓰면, 데이터 병렬 전략으로는 64개 코어로만 분할되어 GPU 포화에 부족하다.
리덕션 차원을 분할하는 matmul을 Split-K Matmul이라고 부른다. RMSNorm과 마찬가지로 이 전략을 쓰면 배치 불변성이 깨진다. 또 다른 흥미로운 matmul 병렬화는 stream-k다. stream-k는 일반 matmul보다 불변성이 더 적다. 대부분의 matmul 라이브러리는 배치 불변은 아니지만 최소한 배치 내 “위치 불변”(배치에서의 위치를 바꿔도 수치가 동일)이라고 할 수 있다. 그러나 stream-k는 배치 위치 불변도 아니다! 핵심 아이디어가 출력 타일마다 k 방향으로 나누는 방식을 달리하여 로드 밸런싱을 매끈하게 하는 것인데, 이를 활용하면 커널은 배치 위치 불변도 잃는다.
Split-K Matmul 배치 차원이 꽤 작으면 병렬성이 부족해 split-k matmul이 필요할 수 있다. 아래 예에서는 각 리덕션을 두 코어에 분할해 따로 누적한 뒤 마지막에 결합한다. 이렇게 하면 여덟 코어를 활용할 수 있다.
matmul에는 추가 난제가 있다 — 텐서코어 명령어다. 리덕션에서는 한 번에 한 행만 다뤄도 됐지만, 효율적인 matmul 커널은 전체 “타일” 단위로 동작해야 한다.
각 텐서코어 명령(예: wgmma.mma_async.sync.aligned.m64n128k16)은 내부적으로 서로 다른 리덕션 순서를 가질 수 있다. 다른 텐서코어 명령을 쓰는 이유 중 하나는 배치가 매우 작을 때다. 예컨대 길이 256 타일을 쓰는 PTX 명령을 사용하는데 배치가 32라면, 거의 모든 연산을 낭비한다! 배치 1에서는 보통 텐서코어를 전혀 쓰지 않는 커널이 가장 빠르다.
패딩된 텐서코어 명령 배치가 너무 작아 출력에 2D 타일 하나도 제대로 안 들어가는 상황이라면, 더 작은 텐서코어 명령으로 바꾸거나 아예 텐서코어를 쓰지 않는 편이 가장 효율적이다! 하지만 두 경우 모두 배치 불변성을 해친다.
따라서 matmul에서 배치 불변을 보장하는 가장 쉬운 방법은 하나의 커널 구성을 컴파일해 모든 형태에 그걸 쓰는 것이다. 성능은 조금 잃지만, LLM 추론에서는 보통 치명적이지 않다. 특히 split-k는 M과 N이 둘 다 작을 때 가장 필요하고, 다행히 우리 경우 N(모델 차원)은 보통 꽤 크다!
배치 불변을 얻고도, cuBLAS 대비 성능 손실은 약 20% 수준이다. (여기서는 TMA 등 최적화가 빠진 Triton 커널을 쓴다는 점에 유의.) 다만 성능 패턴은 배치 불변 요구가 어디에서 손실을 만드는지 잘 보여준다. 첫째, 매우 작은 배치에서 과도하게 큰 명령과 부족한 병렬성으로 성능이 크게 떨어진다. 둘째, 배치가 커질수록 “직소(jigsaw)” 패턴이 나타나는데, 이는 타일과 웨이브 양자화 효과 때문이다. 보통 타일 크기를 바꿔 완화한다. 이 양자화 효과에 대해서는 여기에서 더 볼 수 있다.
FlashAttention2 전략 Q 방향으로 병렬화하고, K/V를 동시에 리덕션한다. 이렇게 하면 전체 리덕션을 단일 코어 내에 유지할 수 있어 또 하나의 데이터 병렬 전략이 된다.
matmul에서 배치 불변을 확보하고 나면, 어텐션은 두 가지 추가 난제를 던진다 — 행렬 곱이 두 개 들어있기 때문이다.
따라서 LLM 추론에서 결정성을 얻으려면, 동시에 처리되는 요청 수 뿐 아니라 각 요청이 추론 엔진 내부에서 어떻게 쪼개져 처리되는지에도 수치가 불변해야 한다.
먼저 FlashAttention2가 도입한 표준 어텐션 병렬화 전략을 살펴보자. RMSNorm과 matmul처럼, 기본 전략은 “데이터 병렬”이다. K/V 텐서 방향으로 줄이므로, 데이터 병렬 전략에서는 쿼리 텐서 방향으로만 병렬화할 수 있다.
예컨대 추론 엔진의 선택에 따라, 어떤 시퀀스는 여러 부분으로 나뉘어 처리될 수도(청크 프리필) 있고 한 번에 처리될 수도 있다(프리필을 나누지 않는 경우). “배치 불변”을 달성하려면, 주어진 토큰의 리덕션 순서가 같은 시퀀스의 다른 토큰이 동시에 몇 개 처리되는지와 무관해야 한다. KV 캐시에 있는 K/V와 현재 처리 중인 토큰의 K/V를 따로 줄인다면(예: vLLM의 Triton 어텐션 커널), 이는 불가능하다. 예를 들어 시퀀스의 1000번째 쿼리 토큰을 처리할 때, 0개(프리필) 또는 999개(디코딩)의 토큰이 KV 캐시에 있든, 리덕션 순서는 동일해야 한다.
KV 캐시를 사용하는 FlashAttention KV 캐시를 현재 KV와 별도로 처리하는 것이 배치 불변을 깨뜨리는 이유는 다소 미묘하고 “경계 조건”과 관련된다. 가령 블록 크기가 32인데 현재 KV 캐시에 80개가 있다고 하자. 여기에 캐시되지 않은 48개를 추가로 계산한다. 이 경우 “P cache”를 계산하려면 세 블록(두 개는 가득, 하나는 마스크)과 “P”를 계산하려면 두 블록(하나는 가득, 하나는 마스크)이 필요하다. 즉 총 다섯 블록이 필요한데 실제 원소는 네 블록(128개)뿐이어서, 리덕션 순서가 달라진다. 예컨대 KV 캐시에 아무것도 없고 128개를 한꺼번에 처리하는 경우와의 수치가 동일해야 어텐션의 “배치 불변”을 보장할 수 있다.
이를 해결하려면, 어텐션 커널 실행 전에 KV 캐시와 페이지 테이블을 업데이트하여, 처리되는 토큰 수와 무관하게 키와 값이 항상 일관된 레이아웃이 되도록 하면 된다.
이 추가 조치(그리고 앞 절에서 언급한 타일 크기 일관성 등)를 더하면, 배치 불변 어텐션 구현을 달성할 수 있다!
하지만 여기엔 큰 문제가 하나 있다. 행렬 곱과 달리, LLM 추론에서 자주 등장하는 어텐션 형태는 Split-KV 또는 FlashDecoding이라고 불리는 분할 리덕션 커널을 필요로 한다. 리덕션 방향으로 병렬화하지 않으면 병렬화할 수 있는 축이 배치, 헤드, “쿼리 길이”뿐이다. 디코드 단계에서는 쿼리 길이가 매우 작아서, 배치가 아주 크지 않다면 GPU 포화를 달성하기 어렵다.
안타깝게도 RMSNorm/Matmul처럼 이 경우를 무시하기가 쉽지 않다. 예컨대 KV 캐시가 아주 길다면, 요청이 하나뿐이어도 어텐션 커널이 매우 오래 걸릴 수 있다.
고정 개수 Split-KV 전략(즉, FlashDecode) 쿼리 길이가 매우 작아지면(디코딩 때처럼) 커널의 병렬성이 거의 사라질 수 있다. 이런 경우 다시 리덕션 차원 — 이번에는 KV 차원 — 으로 분할해야 한다. 일반적 전략은 필요한 병렬성 양을 파악하여 KV 차원을 균등하게 나누는 것이다. 예컨대 KV 길이가 1000이고 4분할이 필요하면, 각 코어가 250개를 처리한다. 이는 안타깝게도 배치 불변을 깨뜨리는데, 정확한 리덕션 전략이 요청마다 동시에 처리하는 쿼리 토큰 수에 의존하기 때문이다.
더 나아가, 어텐션에서 흔히 쓰는 분할 리덕션 전략은 배치 불변성에 추가 도전 과제를 안긴다. 예컨대 FlashInfer의 “균형 스케줄링 알고리즘”은 GPU 코어를 포화시킬 수 있는 최대 분할 크기를 선택하므로, 리덕션 전략이 배치 불변이 아니게 된다. 그러나 RMSNorm/Matmul과 달리, 배치와 무관하게 “분할 개수를 고정”하는 것만으로는 충분하지 않다.
대신, 배치 불변을 달성하려면 “분할 크수 고정” 전략을 써야 한다. 즉, 분할의 “개수”를 고정하는 대신 각 분할의 “크기”를 고정하고, 그 결과 분할 개수는 변하게 둔다. 이렇게 하면 동시에 처리하는 토큰 수와 무관하게 항상 동일한 리덕션 순서를 보장할 수 있다. 이를 위해 FlexAttention에 내부 변경이 필요하며, 본 코드 공개에는 포함되지 않았다. 가까운 시일 내 업스트림에 반영할 예정이다!
고정 크기 Split-KV 전략 앞 전략과의 유일한 차이는 이제 분할이 “고정 크기”라는 점이다. 예컨대 KV 길이가 1000이면, 250씩 4등분하는 대신 길이 256인 분할 3개와 길이 232인 분할 1개로 나눈다. 이렇게 하면 동시에 처리하는 쿼리 토큰 수에 리덕션 전략이 더는 의존하지 않아 _배치 불변성을 보존_할 수 있다!
우리는 vLLM 위에서 FlexAttention 백엔드와 torch.Library를 활용해 결정적 추론을 시연한다. torch.Library를 통해, 대부분의 관련 PyTorch 연산자를 침습적이지 않게 대체할 수 있다. “배치 불변” 커널 모음은 thinking-machines-lab/batch-invariant-ops에서, “결정적” 모드로 실행하는 vLLM 예시는 해당 리포지터리에서 확인할 수 있다.
Qwen/Qwen3-235B-A22B-Instruct-2507 모델을 사용해, “Tell me about Richard Feynman” 프롬프트(논-싱킹 모드)로 온도 0에서 1000개 토큰씩 1000회 생성했다. 놀랍게도 서로 _80_개의 고유 응답이 나왔고, 그중 가장 흔한 응답이 78번 등장했다.
응답이 어디에서 갈라지는지 보면, 처음 102개 토큰은 사실상 동일하다! 최초 분기점은 103번째 토큰이다. 모든 응답이 “Feynman was born on May 11, 1918, in”까지는 같다. 그러나 992개는 이어서 “Queens, New York”을, 8개는 “New York City”를 생성했다.
반면 배치 불변 커널을 활성화하면, 1000개 응답이 모두 동일하다. 수학적으로 샘플러가 기대하는 바와 일치하지만, 배치 불변 커널 없이는 결정적 결과를 달성할 수 없었다.
여기서는 배치 불변 커널의 성능을 크게 최적화하지 않았다. 그럼에도 성능이 실용적인지 확인해보자.
GPU 한 대에서 Qwen-3-8B를 구동하는 API 서버를 띄우고, 출력 길이가 90~110인 시퀀스 1000개를 요청했다.
| Configuration | Time (seconds) |
|---|---|
| vLLM default | 26 |
| Unoptimized Deterministic vLLM | 55 |
| + Improved Attention Kernel | 42 |
감속의 많은 부분은 vLLM의 FlexAttention 통합이 아직 크게 최적화되지 않았기 때문이다. 그럼에도 성능이 _치명적_으로 나쁘지는 않다.
연구자들이 지적했듯, 학습과 추론의 수치 차이는 암묵적으로 온-폴리시 RL을 오프-폴리시 RL로 바꿔버린다.
물론, 동일한 추론 요청 두 번에서조차 비트 단위 동일 결과를 못 얻는다면, 학습과 추론 사이에서 비트 단위 동일성을 얻는 것은 불가능하다. 결정적 추론을 달성하면, 학습 스택도 수정해 샘플링과 학습 사이의 비트 단위 동일성을 확보할 수 있고, 진정한 온-폴리시 RL이 가능해진다.
우리는 Bigmath 환경에서 RLVR 설정으로 실험을 수행했고, 정책은 Qwen 2.5-VL instruct 8B에서 초기화했으며, 최대 롤아웃 길이는 4096이었다.
오프-폴리시 보정(즉, 중요도 가중치)을 쓰지 않고 학습하면 보상이 학습 도중 붕괴한다. 반면 오프-폴리시 보정 항을 추가하면 학습이 원활히 진행된다. 그리고 샘플러와 트레이너 사이에서 비트 단위 동일성을 이루면, 완전한 온-폴리시(즉, KL 발산 0)로도 매끄럽게 학습할 수 있다.
샘플러와 트레이너 간 로그확률의 KL 발산을 그려보면, 세 가지 실행 모두 상이한 양상을 보인다. 중요도 가중치를 쓰면 0.001 부근에서 가끔 스파이크가 난다. 반면 중요도 가중치 없이 실행하면 보상이 붕괴하는 시점과 비슷한 때 KL 발산이 급등한다. 물론 “진정한 온-폴리시 RL”에서는 KL 발산이 0에 평평하게 유지되어, 학습 정책과 샘플링 정책 사이에 어떠한 괴리도 없음을 보여준다.
중요도 가중치 없이 실행한 경우 Step 318 부근에서 손실이 크게 솟구치며, 이에 상응해 로그확률 KL 발산도 급등한다. 반면 오프-폴리시 보정을 쓰거나 “진정한 온-폴리시”로 실행하면 RL이 안정적으로 이어진다. “진정한 온-폴리시”를 나타내는 파란 선이 버그처럼 보일 수 있는데 — 단지 0에서 평평한 선일 뿐이다.
현대 소프트웨어 시스템은 여러 추상화 층으로 이뤄져 있다. 머신러닝에서 비결정성과 미묘한 수치 차이를 만나면, 이를 덮어두고 넘어가고 싶은 유혹을 받기 쉽다. 어차피 시스템이 “확률적”인데, 비결정성이 조금 더 있어도 괜찮지 않을까? 실패하는 유닛 테스트의 atol/rtol을 올리면 되지 않을까? 트레이너와 샘플러 사이의 로그확률 차이는 진짜 버그가 아닐지도?
우리는 이런 체념을 거부한다. 약간의 노력으로 비결정성의 근본 원인을 _이해_하고, 심지어 _해결_할 수 있다! 이 글이 커뮤니티에 추론 시스템의 비결정성을 해소하는 탄탄한 이해를 제공하고, 각자 시스템을 온전히 이해하려는 시도를 고무하길 바란다.
인용 형식은 다음과 같다:
He, Horace and Thinking Machines Lab, "Defeating Nondeterminism in LLM Inference",
Thinking Machines Lab: Connectionism, Sep 2025.
또는 BibTeX:
@article{he2025nondeterminism,
author = {Horace He and Thinking Machines Lab},
title = {Defeating Nondeterminism in LLM Inference},
journal = {Thinking Machines Lab: Connectionism},
year = {2025},
note = {https://thinkingmachines.ai/blog/defeating-nondeterminism-in-llm-inference/},
doi = {10.64434/tml.20250910}
}