Scaling up Test time compute 하기 위한 아주 novel한 방법을 제안하는 논문이다. 보통은 Chain of thought을 하면서 모델이 더 많은 토큰을 뱉도록 하는 방식으로 scaling 을 했는데, 이렇게 하면 inference 할때 Sequenth length 가 길어질수록, memory 사용량이 linear 하게 증가하는 문제가 있다. 이 논문에서는 이런 비효율적인 memory 사용량을 recurrent 모델을 통해서 해결을 한다. 단순히 많은 토큰을 내뱉게 해서 test time compute을 scaling up 하는 대신에, recurrent 모델을 쓰는데 어떻게 했는지 하나씩 보자.
Model Architecture
아이디어만 보자면, 우선 hello 에 대한 latent output e가 normal distribution에서 sampling된 S0과 합쳐져서 s1 이 나온다. 그리고 이 s1 이 다시 e 와 합쳐져서 s2가 나오고~~ 이런식으로 쭉 반복해서, 마지막에는 Sr 이 나온다. 이거를 transformer 에 태워서 다음에 최종적으로 next prediction을 하게 된다.
위 그림을 식으로 표현하면 다음과 같다.
이 식이 뭘 의미하는지 위에서부터 하나씩 자세히 보자.
1) 첫번쨰 식 (e=p(x))
x가 들어오면 일단 vocab 차원으로 맞춰주고, 이걸 encoding을 통해 (n x h) hidden state 차원으로 맞춰서 transformer 에 태운다.
2) 두번째, 세번째 식
(e=p(x))를 통해 나온 e와, S0 (노이즈) 를 섞어서, concat을 시킨다. 그럼 (nxh) + (nxh) 니까 (nx2h) 가 되니까, 인코더를 태워서 다시 nxh 로 맞춰준다. 그리고 transformer를 태워서 RMS norm 까지 거치면 Si가 나온다.
이거를 다시 원래 첫번째 식에서 나온 e랑 다시 섞어서 concat을 하고 똑같은 과정을 계속 recurrent 하게 반복한다.
3) 네번째 식
마지막 recurrent를 쭉 돌려서 나온 Sr을 transformer를 태우고 디코더를 태워서 (nxv) vocab size 차원에 맞춘 p를 output 하면 끝!
Training Objective
전체 loss function은 다음과 같다.
여기서 \( m_\theta(x, r) \)은 입력 \( x \)에 대해 \( r \)번 반복하여 생성한 모델 출력이고, \( x' \)는 시퀀스 \( x \)를 왼쪽으로 한 칸 shift한 값으로, 다음 토큰들이다. 즉, 모델이 생성한 다음 토큰과 실제 정답 토큰의 차이를 비교해 loss를 계산한다. 여기까지는 일반적인 next token prediction loss이다.
\( \mathbb{E}_{r \sim \Lambda} \) 이 부분이 조금 특이해서 잘 봐야하는데
반복 횟수 \( r \)은 log-normal Poisson 분포를 따르며, 평균 반복 횟수가 \( \bar{r} + 1 \)이 되도록 다음의 절차를 통해 샘플링한다.
- 먼저 로그 정규 분포를 따르는 \( \tau \)를 샘플링한다:\[ \tau \sim \mathcal{N}\left( \log(\bar{r}) - \frac{1}{2} \sigma^2, \sigma \right) \]
- 이후 \( \tau \)의 지수값을 평균으로 하는 포아송 분포에서 \( r \)을 샘플링하고, 최종적으로 1을 더해준다:\[ r \sim \mathcal{P}(e^\tau) + 1 \]
이렇게 하면 평균은 \( \bar{r} + 1 \)이 되며, 대부분의 값은 \( \bar{r} \)보다 작지만 가끔 매우 큰 반복 횟수가 샘플링되는 heavy tail 분포의 특성을 갖는다. 이로 인해 모델은 일반적인 경우와 드물지만 긴 시퀀스까지 모두 잘 처리할 수 있도록 학습된다.