Diffusion 부터 복습해보자면!
Forward Process (original image에서 노이즈 추가하는 과정)
우리는 x0에서 노이즈를 추가해서 t시점의 이미지를 만들어내야하는데, 첫번째 식처럼 하나하나 이전 단계의 이미지를 이용해서 곱하고 있으면 너무 시간이 오래걸림.
결국 t-1에서 t를 얻어내기 위해서는 t-1시점에서 노이즈를 추가해주면 되는건데 (이게 두번째 식) 이렇게 써주면 좋은 이유가 있음
이렇게 표현 가능하다는건데, 하나씩 다 곱할 필요없이, reparameterize trick을 통해 nice 한 property를 얻어내는게 가능해진다는거
이 식을 조금 더 직관적으로 이해해보면, 알파t 바 부분이 t가 무한대로 커질수록 0 에 가까워짐(1보다 작은 알파값을 계속 곱하는거니까) ,
즉 q(xt|x0) 식에서 t가 무한대로 갈수록 N(0, I) 에 가까워지는것임(standard normal distribution에 수렴)
Reverse Process
이 식은 x0 과 xt 가 주어졌을 때 xt-1 을 구할 수 있는 식
근데 우리는 x0을 모르는거. 그래서 x0이 주어져 있지 않고 이전 step만 주어졌을때, 확률분포 구하는거를 네트워크를 통해 학습시켜야함.
그래서 학습시키는 함수 P에 대한 식을 새로 만들어줬는데 다음과 같음
모델 구조는 전체적으로 다음과 같이 생겼다
xt를 넣어주고, t(어떤 시점인지 함께) 넣어서 unet에 태워주면 분포의 평균과 분산을 예측하는 네트워크임.
이렇게 해서 나온 확률분포와, 정답 확률분포간을 비교하면서(kl divergence) 학습이 진행된다.
이부분이 DDPM이 나오게 된 배경이기도한데, 이런식으로 loss function을 만들게 되면, 시점별로 하나하나씩 다 비교를 해줘야 함.
그래서 DDPM에서 제안한거는
이전 스텝만을 예측하지 말고, 아예 x0(원본이미지)를 t 시점에서 바로 예측해버리자!
즉, 원래 diffusion 모델에서는 unet 네트워크가 그 다음 스텝의 mean이랑 variance를 예측했다면, 이렇게 하지 말고 어떤 시점이든간에 x0의 확률 분포를 바로 예측하자는것!
x0, xt, e 의 관계?
결국 xt는 x0이랑 엡실론 의 linear combination이라고 할 수 있음.
즉, x0이랑 엡실론을 알면 xt를 구할 수 있고, xt랑 엡실론을 알면 x0을 구할수 있고,,,
그래서 원래 디퓨전 모델에서 하려고 했던거는 xt와 엡실론을 이용해서 x0을 구하자! 였는데, DDPM에서는 x0이랑 xt를 통해서 엡실론을 학습시켜보자! 로 바뀐거.(x0을 바로 구해도 되는데 이렇게 했을때 성능이 더 잘나왔다고 함..)
즉 이미지의 확률분포를 예측하는게 아니라 엡실론(노이즈)의 분포를 예측하는거
xt 시점에서 x0으로 가는 노이즈를 예측하고, 그걸 통해 xt-1을 구하고, xt-1에서 x0으로 가는 노이즈 예측하고 그걸 통해서 xt-3 구하고~~ 계속 반복
Loss function 은 다음과 같다
KL divergence를 썼던 diffusion 초기 모델이랑 달리, t시점과 x0 시점으로 가는 노이즈를 예측한거랑, ground truth 노이즈의 차이를 최소화하는게 DDPM의 loss function
DDPM의 문제점?
너무 느리다..
x0시점을 바로 예측을 했지만, 결국 보면 t시점에서 x0을 예측해서 t-1 시점으로 가고, 예측해서 t-1시점으로 가고 … timestep 별로 하나하나 다 해줘야 하는것.
그러면 solution으로 t를 띄엄띄엄 해주면 안되나? —> timestep이 적어질수록 ddpm의 퀄리티가 심하게 낮아진다는 문제점이 있음
기존 DDPM Reverse
𝑞(𝑥𝑡−1∣𝑥𝑡,𝑥0)=𝑁(𝑥𝑡−1;𝜇~(𝑥𝑡,𝑥0),𝛽~𝑡I)q(xt−1∣xt,x0)
=N(xt−1;μ~(xt,x0),β~tI)
𝑤ℎ𝑒𝑟𝑒,𝜇~𝑡(𝑥𝑡,𝑥0):=𝛼ˉ𝑡−1𝛽𝑡1−𝛼𝑡ˉ𝑥0 +𝛼𝑡(1−𝛼ˉ𝑡−1)1−𝛼ˉ𝑡𝑥𝑡, 𝛽𝑡~:=(1−𝛼ˉ𝑡−1)1−𝛼ˉ𝑡𝛽𝑡where,μ~t(xt,x0):=1−αtˉαˉt−1βtx0 +1−αˉtαt(1−αˉt−1)xt, βt~:=1−αˉt(1−αˉt−1)βt
이거를 좀 더 generalized 형태인 다음과 같은 수식으로 바꿈
이 식이 나오게 된게, 밑에 식은 DDPM에서도 언급되었던 식인데, 밑에 식의 조건이 만족되도록 설계한것.
(25) 번 식을 잘 보면 시그마 t는 우리가 정할 수 있는 값인데, 이 시그마 t를 다음과 같이 대입하면 DDPM이랑 똑같아지고
시그마 t를 0으로 두게 된다면, covariance가 0이 되어버리니까 deterministic 한 function이 됨.
이게 DDIM!
DDPM 에서는 t 에서 t-1로 갈때 어느정도의 stochasticity가 있었다면, DDIM에서는 이걸 아예 0으로 줘서, deterministic하게 딱 정해진 함수로 만든것.
아까 DDPM에서 timestep을 적게하면 퀄리티가 급격하게 낮아진다고 했었는데, DDIM의 성능을 보면 timestep을 줄여도 성능이 크게 줄어들지는 않는걸 볼수 있음
근데 여기서도 좀 재밌는건 step수가 적으면 ddim 이 더 점수가 좋은데 step이 1000으로 가면 DDPM 더 점수가 좋음
전체적인 학습 과정
여기서 나온 xt 로 xt-1 구하고, xt-1에서 xt-2 구하고 ~~