이 글에서는, 연속적 Normalizing Flow 등을 다룰 때 “최종 분포” p1
의 밀도를 구하거나, 특히 log p1(x)
값을 계산하고자 하는 과정을 설명한다.
차원이 큰 문제에서 Jacobian determinant를 직접 계산하기가 복잡하고 비싸므로, continuity equation과 flow 개념을 활용해 “로그밀도”의 시간 변화를 추적한 뒤, 초기 로그밀도와의 관계로 log p1(x)
를 효율적으로 구하는 방법, ODE solver 의 코드까지 살펴보겠다.
먼저 Continuity Equation 에서 시작해보자.
이렇게 쭉 전개하면 다음과 같은 식이 나온다.
최종 식을 해석해보자면, "최종 로그밀도(우리가 구하려고자 하는)” = “초기 로그밀도" - divergence적분인거다.
어느정도 간단하게 만들어놓긴 했지만 아직도 div 부분 계산이 어렵다. 이식을 간단하게 풀기 위해 Hutchinson’s Trick를 사용한다.
Hutchinson’s Trick이란,,
행렬 \(A\)의 trace(\(\mathrm{trace}(A)\))를 직접 구하지 않고,
임의의 무작위 벡터 \(\mathbf{z}\)로부터 \(\mathbf{z}^\top A\,\mathbf{z}\)를 샘플링해 평균 냄으로써 \(\mathrm{trace}(A)\)를 추정하는 방식이다. 샘플이 많으면 lotp 에 의해 저렇게 근사가 가능하다고 하니 참 근사하다.
그래서 결국 우리가 구하려고 하는 logP1(x)는
이 식에 의해 다음과 같이 써줄 수 있다.
즉, 시간 1에서 \(x\)를 주고 \(\bigl(f(1),g(1)\bigr)\)=(x,0) 으로 역방향으로 \(f(0),g(0)\)로부터 \(\log p_1(x)\)를 구한다.
Flow Matching ODE Solver 코드
아래 코드는 flow_matching 라이브러리에서 제공하는 ODESolver
클래스와 학습된 velocity_model
을 사용해서, 주어진 점 x_1
에 대한 log p1(x)
를 계산할 수 있게 해준다. 즉, 초기 분포를 p0
로 두고, 미분방정식을 역방향으로 적분해서 “시간 1에서 위치가 x_1
”인 점들의 초기 위치 x_0
와 최종 로그밀도 log_p1
값을 구하는 코드이다.
from flow_matching.solver import ODESolver
from flow_matching.utils import ModelWrapper
from torch.distributions.normal import Normal
# velocity_model: model(x_t, t) -> u_t(x_t)
velocity_model: ModelWrapper = ...
# 우리가 log p1(x)를 구하고 싶은 점(여기서는 표준정규분포에서 임의 생성)
x_1 = torch.randn(batch_size, *data_dim)
# ODESolver 생성
solver = ODESolver(velocity_model=velocity_model)
# 초기 분포 p0 = Gaussian(0, I)의 log_density 정의
gaussian_log_density = Normal(
torch.zeros(size=data_dim),
torch.ones(size=data_dim)
).log_prob
num_steps = 100
# compute_likelihood를 호출해서 x_0, log_p1를 구한다
x_0, log_p1 = solver.compute_likelihood(
x_1=x_1,
method='midpoint',
step_size=1.0 / num_steps,
log_p0=gaussian_log_density
)
최종 시점에서 x_1
이라는 점이 있을 때, 이를 ODE 역방향으로 추적해 보면 초기 시점(0)에서 어느 위치 x_0
였고, 전체 div 적분을 통해 log p1(x1)
를 어떻게 계산할 수 있는가를 자동으로 해주는 구조이다.
따라서 실제 데이터셋을 x_1
자리에 넣으면, 그 점의 log p1
를 구할 수 있고, 동시에 그 점이 " 초기 분포 p0" 공간에서 어떤 위치(x_0
)였는지도 알수있다.
Training flow models
그럼 Instantaneous Change of Variables를 통해 만들어진 ODE식을 어떻게 실제 학습과정에서 활용할까. 결국 “Flow-based” 모델을 “데이터 분포에 맞게” 학습하기 위해서는, log p1θ(x)
를 계산할 수 있어야 하고, 이를 최대화(또는 KL발산 최소화)해야 한다.
학습목표
Flow-based 모델에서, 학습 파라미터 θ
를 갖는 속도장 utθ
를 정의한다고 하자. 이를 통해 “시간 1에서의 분포” p1θ
가 형성된다. 우리는 이 p1θ
가, 실제 데이터 분포 q
와 가깝도록 하고 싶다. 즉,
$$ p_1^\theta \;\approx\; q. \quad $$
예를 들어, KL발산(Kullback-Leibler divergence)을 이용해 아래와 같은 목표함수를 최소화한다고 할 수 있다:
$$ \mathcal{L}(\theta) \;=\; D_{\mathrm{KL}}\!\bigl(q,\;p_1^\theta\bigr) \;=\; -\,\mathbb{E}_{Y \sim q}\bigl[\log p_1^\theta(Y)\bigr] \;+\; \text{constant.} \quad $$
여기서 p1θ
는 “X1 = ψ1θ(X0)”의 분포, ψtθ
는 속도장 utθ
에 의해 결정되는 “흐름(flow)”이고, 우리가 구하고자 하는 것은 log p1θ(Y)
의 값이다.
ODE를 통한 log p1θ(x) 계산
이전에 언급한 Instantaneous Change of Variables을 통해, log p1θ(x)
를 unbiased하게 추정할 수 있다. 다시 말해서, ODE를 정확히 simulate하면 log p1θ(x)
를 divergence 적분과 초기 밀도 p0
를 통해 산출 가능하다.
그러나 학습 과정에서 매 스텝마다 “정확한 ODE 해”를 구하는 것은 많은 계산 비용이 들고, 근사로 처리할 경우 그라디언트가 편향될 수 있다는 단점이 있다.
위와 같은 KL발산(ℒ(θ)
)을 최소화하려면, $$ \nabla_\theta \; \mathcal{L}(\theta) $$ 를 반복적으로 구해야 한다. 그런데 log p1θ(x)
가 “ODE 적분(시간 0~1)을 통해 정의”된 함수라면, 매번 정밀한 ODE 시뮬레이션이 필요하고, 이는 매우 비싼 연산이 될 수 있다.
Flow matching?
뒤에서 소개될 Flow matching은 학습 과정에서 매번 ODE를 풀지 않아도 되는 “simulation-free” 방식이다. 즉 기존 Flow based 모델이 Instantaneous Change of Variables 을 통해 학습 중 ODE를 풀었기에 계산 부담이 컸던 반면, Flow Matching은 그 부분을 효율적으로 해결한다.