Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach

2025. 4. 20. 14:08·Computer Vision Paper Review
반응형

Scaling up Test time compute 하기 위한 아주 novel한 방법을 제안하는 논문이다. 보통은 Chain of thought을 하면서 모델이 더 많은 토큰을 뱉도록 하는 방식으로 scaling 을 했는데, 이렇게 하면 inference 할때 Sequenth length 가 길어질수록, memory 사용량이 linear 하게 증가하는 문제가 있다. 이 논문에서는 이런 비효율적인 memory 사용량을  recurrent 모델을 통해서 해결을 한다. 단순히 많은 토큰을 내뱉게 해서 test time compute을 scaling up 하는 대신에, recurrent 모델을 쓰는데 어떻게 했는지 하나씩 보자.


Model Architecture

아이디어만 보자면, 우선 hello 에 대한 latent output e가 normal distribution에서 sampling된 S0과 합쳐져서 s1 이 나온다. 그리고 이 s1 이 다시 e 와 합쳐져서 s2가 나오고~~ 이런식으로 쭉 반복해서, 마지막에는 Sr 이 나온다. 이거를 transformer 에 태워서 다음에 최종적으로 next prediction을 하게 된다.

 

위 그림을 식으로 표현하면 다음과 같다.

이 식이 뭘 의미하는지 위에서부터 하나씩 자세히 보자.

 

1) 첫번쨰 식 (e=p(x))

x가 들어오면 일단 vocab 차원으로 맞춰주고, 이걸 encoding을 통해 (n x h)  hidden state 차원으로 맞춰서 transformer 에 태운다. 

 

2) 두번째, 세번째 식

 (e=p(x))를 통해 나온 e와, S0 (노이즈) 를 섞어서, concat을 시킨다. 그럼 (nxh) + (nxh) 니까 (nx2h) 가 되니까, 인코더를 태워서 다시 nxh 로 맞춰준다. 그리고 transformer를 태워서 RMS norm 까지 거치면 Si가 나온다.

이거를 다시 원래 첫번째 식에서 나온 e랑 다시 섞어서 concat을 하고 똑같은 과정을 계속 recurrent 하게 반복한다.

 

3) 네번째 식

마지막 recurrent를 쭉 돌려서 나온 Sr을 transformer를 태우고 디코더를 태워서 (nxv) vocab size 차원에 맞춘 p를 output 하면 끝!


Training Objective

전체 loss function은 다음과 같다.

여기서 \( m_\theta(x, r) \)은 입력 \( x \)에 대해 \( r \)번 반복하여 생성한 모델 출력이고, \( x' \)는 시퀀스 \( x \)를 왼쪽으로 한 칸 shift한 값으로, 다음 토큰들이다. 즉, 모델이 생성한 다음 토큰과 실제 정답 토큰의 차이를 비교해 loss를 계산한다. 여기까지는 일반적인 next token prediction loss이다.

 

 \( \mathbb{E}_{r \sim \Lambda} \) 이 부분이 조금 특이해서 잘 봐야하는데

반복 횟수 \( r \)은 log-normal Poisson 분포를 따르며, 평균 반복 횟수가 \( \bar{r} + 1 \)이 되도록 다음의 절차를 통해 샘플링한다.

  1. 먼저 로그 정규 분포를 따르는 \( \tau \)를 샘플링한다:\[ \tau \sim \mathcal{N}\left( \log(\bar{r}) - \frac{1}{2} \sigma^2, \sigma \right) \]
  2. 이후 \( \tau \)의 지수값을 평균으로 하는 포아송 분포에서 \( r \)을 샘플링하고, 최종적으로 1을 더해준다:\[ r \sim \mathcal{P}(e^\tau) + 1 \]

이렇게 하면 평균은 \( \bar{r} + 1 \)이 되며, 대부분의 값은 \( \bar{r} \)보다 작지만 가끔 매우 큰 반복 횟수가 샘플링되는 heavy tail 분포의 특성을 갖는다. 이로 인해 모델은 일반적인 경우와 드물지만 긴 시퀀스까지 모두 잘 처리할 수 있도록 학습된다.

반응형
'Computer Vision Paper Review' 카테고리의 다른 글
  • DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models
  • DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via Reinforcement Learning
  • Deepseek v3 해부식
  • Vision-R1: Evolving Human-Free Alignment in Large Vision-Language Models via Vision-Guided Reinforcement Learning
happy88
happy88
  • happy88
    happy8825
    happy88
  • 전체
    오늘
    어제
    • 분류 전체보기 (97) N
      • NLP (7)
      • Computer Vision Paper Revie.. (53) N
      • 이것저것 (5)
      • About me (3)
      • Linear Algebra (7)
      • 개발 (2)
      • Statistics (12)
      • Flow Matching (7)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
    • 글쓰기
  • 링크

  • 공지사항

  • 인기 글

  • 태그

  • 최근 댓글

  • 최근 글

  • hELLO· Designed By정상우.v4.10.0
happy88
Scaling up Test-Time Compute with Latent Reasoning: A Recurrent Depth Approach
상단으로

티스토리툴바