14. Score-based Model 2
Main topic: NCSN (Noise Conditional Score Network)!
Key challenge in building complex generative models
$x \rightarrow f_{\theta}(x) \rightarrow e^{f_{\theta}(x)} \rightarrow \frac{e^{f_{\theta}(x)}}{Z_{\theta}} = p_{\theta}(x)$
- $f_{\theta}(x)$ may be negative, so exponentiate to guarantee positivity.
- $Z_{\theta} = \int e^{f_{\theta}(x)}\,dx$: normalizing constant.
Computing $Z_\theta$ in general:
- Intractable.
How to tackle intractable $Z_{\theta} = \int e^{f_{\theta}(x)}dx$?
① Approximate $Z_\theta$ → energy-based models
- ↳ inaccurate $p(x)$
② Restrict the model so $Z_\theta$ stays tractable
- ↳ restricted to flow-based, autoregressive, VAE models
③ Model the generation process only
- ↳ no evaluation of $p(x)$ (e.g., GANs)
Data modeling by Estimating Score Functions
$p(x)$: probability density function.
$\nabla_{x}\log p(x)$: score function (preserves all information about $p$).
- ↳ points in the direction of steepest ascent of $\log p(x)$.
c.f.) $\nabla_\theta \log p(x \mid \theta)$: Fisher score function (gradient w.r.t. the parameters $\theta$, not $x$).
Bypass $Z_\theta$
$p_{\theta}(x) = \frac{e^{f_{\theta}(x)}}{Z_{\theta}}$
$\log p_{\theta}(x) = f_{\theta}(x) - \log Z_{\theta}$
$\nabla_{x}\log p_{\theta}(x) = \nabla_{x}f_{\theta}(x) - \underbrace{\nabla_{x}\log Z_{\theta}}_{=\,0}$ ($Z_{\theta}$ does not depend on $x$)
- $= \nabla_{x}f_{\theta}(x) = S_{\theta}(x)$ (target)
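A minimal sketch (PyTorch, with a hypothetical energy `f_theta`) of why this matters in practice: the score of an unnormalized model is just the gradient of the energy, so autograd gives $S_{\theta}(x)$ without ever touching $Z_{\theta}$.

```python
import torch

def f_theta(x):
    # Hypothetical unnormalized log-density (energy); Z_theta is unknown.
    return -0.5 * (x ** 2).sum(dim=-1) + torch.sin(x).sum(dim=-1)

x = torch.randn(4, 2, requires_grad=True)   # batch of 4 points in R^2
energy = f_theta(x).sum()                   # sum over batch -> one backward pass
score = torch.autograd.grad(energy, x)[0]   # S_theta(x) = grad_x f_theta(x)
print(score.shape)                          # torch.Size([4, 2])
```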
Learning objective
Given: $x_{1}, \dots, x_{N} \sim p_{data}(x)$
Goal: approximate $\nabla_x \log p_{data}(x)$ with a score model.
Score model
How do we compare the model with the true score?
$S_{\theta}(x): \mathbb{R}^{d} \rightarrow \mathbb{R}^{d}, \quad S_{\theta}(x) \approx \nabla_x \log p_{data}(x)$
Fisher divergence
$\frac{1}{2}E_{p_{data}}[\|\nabla_x \log p_{data}(x) - S_\theta(x)\|_{2}^{2}]$: learning target. But $\nabla_x \log p_{data}(x)$ is unknown.
Score matching (integration by parts) removes the unknown term:
$FD = E_{p_{data}}[\frac{1}{2}\|S_\theta(x)\|_{2}^{2} + \mathrm{tr}(\nabla_{x}S_{\theta}(x))] + \text{const.} \approx \frac{1}{N}\sum_{i=1}^{N}[\frac{1}{2}\|S_{\theta}(x_{i})\|_{2}^{2} + \mathrm{tr}(\nabla_{x}S_{\theta}(x_{i}))]$
ex) $S_{\theta}(x) = [3x_{1}, 4x_{2}^{2}, 5]^{\top}$
- $\|S_{\theta}(x)\|_{2}^{2} = 9x_{1}^{2} + 16x_{2}^{4} + 25$
- $\mathrm{tr}(\nabla_{x}S_{\theta}(x)) = 3 + 8x_{2} + 0$
$\nabla_{x}S_{\theta}(x) = \begin{bmatrix} \frac{\partial S_{\theta,1}(x)}{\partial x_{1}} & \cdots & \frac{\partial S_{\theta,1}(x)}{\partial x_{d}} \\ \vdots & \ddots & \vdots \\ \frac{\partial S_{\theta,d}(x)}{\partial x_{1}} & \cdots & \frac{\partial S_{\theta,d}(x)}{\partial x_{d}} \end{bmatrix}$ (Jacobian)
- The trace is the sum of the diagonal entries, but we cannot afford all $d$ of them.
- Each diagonal entry needs its own backprop, i.e., $\mathcal{O}(\dim(x))$ backward passes (think ~2 million dimensions), and the Jacobian $\nabla_{x}S_{\theta}(x) \in \mathbb{R}^{d \times d}$ is itself huge.
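A minimal sketch of the scalability problem, using a hypothetical stand-in network `s_theta`: each diagonal entry of the Jacobian costs one backward pass, so the trace costs $d$ of them.

```python
import torch

d = 8
s_theta = torch.nn.Sequential(torch.nn.Linear(d, 64), torch.nn.Tanh(),
                              torch.nn.Linear(64, d))

def jacobian_trace(x):
    x = x.requires_grad_(True)
    s = s_theta(x)
    trace = 0.0
    for i in range(d):  # O(dim(x)) backprops -- prohibitive for d ~ millions
        trace = trace + torch.autograd.grad(s[i], x, create_graph=True)[0][i]
    return trace

print(jacobian_trace(torch.randn(d)))
```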
Denoising score matching (exact score matching is hard → mix in noise and estimate instead)
$x \sim p_{data}(x)$, perturbed with Gaussian noise of scale $\sigma$:
$q_{\sigma}(\tilde{x}) = \int p_{data}(x)\,q_{\sigma}(\tilde{x} \mid x)\,dx$ ↳ Denoising score matching (Vincent, 2011)
$\frac{1}{2} E_{p_{data}} E_{q_\sigma(\tilde{x} \mid x)} [\|\nabla_{\tilde{x}} \log q_{\sigma}(\tilde{x} \mid x) - S_\theta(\tilde{x})\|_{2}^{2}]$ — match the score of the noise kernel instead.
- Easy to compute: no Jacobian trace, no per-dimension backprop.
- $q_{\sigma}(\tilde{x} \mid x) = \mathcal{N}(\tilde{x}; x, \sigma^{2}I)$, so $\nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x} \mid x) = -\frac{\tilde{x}-x}{\sigma^{2}}$ (closed form) — train $S_{\theta}$ to minimize the distance to this target.
- Scalable.
- But $S_{\theta}(\tilde{x}) \approx \nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x}) \ne \nabla_{x}\log p_{data}(x)$: we learn the noisy score.
How?
$\{x_{1}, x_{2}, \dots, x_{n}\} \sim p_{data}(x)$
- $\downarrow$ add noise
$\{\tilde{x}_{1}, \tilde{x}_{2}, \dots, \tilde{x}_{n}\} \sim q_{\sigma}(\tilde{x})$
$\tilde{x}_{i} \sim q_{\sigma}(\tilde{x}_{i} \mid x_{i})$
Estimate denoising score matching loss
$\frac{1}{2n}\sum_{i=1}^{n}[\|S_{\theta}(\tilde{x}_{i}) - \nabla_{\tilde{x}}\log q_{\sigma}(\tilde{x}_{i} \mid x_{i})\|_{2}^{2}]$ - ↳ substitute the closed form
$\frac{1}{2n}\sum_{i=1}^{n}[\|S_{\theta}(\tilde{x}_{i}) + \frac{\tilde{x}_{i}-x_{i}}{\sigma^{2}}\|_{2}^{2}]$ → minimize with SGD.
- $\Rightarrow$ But ultimately we want the noise-free score… (a training sketch follows below)
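A minimal sketch (PyTorch) of one DSM training step on toy data, with a hypothetical score network `s_theta`: perturb, form the closed-form target, regress.

```python
import torch

d, sigma = 2, 0.1
s_theta = torch.nn.Sequential(torch.nn.Linear(d, 64), torch.nn.Tanh(),
                              torch.nn.Linear(64, d))

x = torch.randn(128, d)                    # {x_i} ~ p_data (toy stand-in)
x_tilde = x + sigma * torch.randn_like(x)  # x_tilde_i ~ q_sigma(. | x_i)
target = -(x_tilde - x) / sigma ** 2       # closed-form grad log q_sigma
loss = 0.5 * ((s_theta(x_tilde) - target) ** 2).sum(dim=-1).mean()
loss.backward()                            # one SGD step would follow
```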
Pitfall
① Should we just use a small $\sigma$ (so that $q_{\sigma} \approx p_{data}$)?
② But the loss variance blows up as $\sigma \rightarrow 0$, making training hard.
i.e., there is a trade-off $\rightarrow$ sliced score matching as an alternative
- (sending $\sigma \rightarrow 0$ causes the following problem)
$J = \frac{1}{2}E_{x \sim p_{data}}E_{z \sim \mathcal{N}(0, I)}[\|S_\theta(x+\sigma z) + \frac{z}{\sigma}\|_{2}^{2}]$ - substituting $\tilde{x} = x + \sigma z$
$= \frac{1}{2}E_{x \sim p_{data}}E_{z \sim \mathcal{N}(0, I)}[\|S_{\theta}(x+\sigma z)\|_{2}^{2} + 2S_{\theta}(x+\sigma z)^{\top}\frac{z}{\sigma} + \frac{\|z\|_{2}^{2}}{\sigma^{2}}]$, and as $\sigma \rightarrow 0$:
$\mathrm{Var}(\frac{z}{\sigma}) \rightarrow \infty$
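A quick numeric check of the blow-up: the per-coordinate variance of the target $z/\sigma$ is $1/\sigma^{2}$.

```python
import torch

for sigma in [1.0, 0.1, 0.01]:
    z = torch.randn(100_000)
    print(sigma, (z / sigma).var().item())  # ~1, ~100, ~10000
```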
Sampling
How to sample new data?
[Drawing of data points converging to a manifold]
$S_{\theta}(x) \approx \nabla_{x}\log P_{data}(x)$
Assume $S_\theta$ is trained. If we just follow the score function deterministically, all samples collapse onto the modes:
$\tilde{x}_{t+1} \leftarrow \tilde{x}_{t} + \frac{\epsilon}{2}S_{\theta}(\tilde{x}_{t})$ in data space.
Langevin dynamics
↳ Follow noisy version of score function!
$\tilde{x}_{t+1} \leftarrow \tilde{x}_{t} + \frac{\epsilon}{2}S_{\theta}(\tilde{x}_{t}) + \sqrt{\epsilon}z_t$ where $z_t \sim \mathcal{N}(0,I)$
Procedure:
$x^{0} \sim \Pi(x)$ (any initial distribution)
for $t \leftarrow 1, 2, \dots, T$:
- $z^{t} \sim \mathcal{N}(0,I)$
$x^{t} \leftarrow x^{t-1} + \frac{\epsilon}{2}\nabla_{x}\log P(x^{t-1}) + \sqrt{\epsilon}z^{t}$
- New pos $\leftarrow$ Current pos + Score function + Gaussian noise
As $\epsilon \rightarrow 0$ and $T \rightarrow \infty$, $x^{T} \sim p(x)$.
$S_{\theta}(x) \approx \nabla_{x}\log P(x)$
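A minimal sketch of the procedure for a case where the true score is known: for $p(x) = \mathcal{N}(0, I)$ the score is $-x$, so the Langevin iterates should converge to a standard Gaussian from any $\Pi(x)$.

```python
import torch

eps, T = 0.01, 5_000
x = 10 * torch.rand(1000, 2) - 5             # x^0 ~ Pi(x): uniform on [-5, 5]^2
for _ in range(T):
    z = torch.randn_like(x)                  # z^t ~ N(0, I)
    x = x + eps / 2 * (-x) + eps ** 0.5 * z  # score of N(0, I) is -x
print(x.mean(dim=0), x.var(dim=0))           # ~ (0, 0) and ~ (1, 1)
```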
Plain score matching + Langevin dynamics does not work well…
Accurate vs Inaccurate regions
$\Rightarrow$ ① Low-density regions: not enough samples → inaccurate score estimates there.
② Langevin dynamics cannot recover the relative weights of modes separated by low-density regions (slow mixing).
③ Manifold hypothesis: if data lie on a low-dimensional manifold, the score is ill-defined in ambient space.
$\rightarrow$ Fix: Gaussian perturbation.
Use large noise to guide Langevin dynamics.
- ↳ But $\sigma$ must be small for $q_{\sigma} \approx p_{data}$, so $\sigma$ has to be reduced afterwards.
| Noise level | High | Low |
|---|---|---|
| Sample quality | $\downarrow$ | $\uparrow$ |
| Score estimation (low-density regions) | $\uparrow$ | $\downarrow$ |
So we need multiple scales (a cascade) of noise:
$\sigma_1 > \sigma_2 > \dots > \sigma_L$
Add noise to guide sampling $\rightarrow$ remove it again for sample quality.
Annealed Langevin dynamics
$\mathcal{N}_{\sigma_1} \rightarrow \mathcal{N}_{\sigma_2} \rightarrow \cdots \rightarrow \mathcal{N}_{\sigma_L}$
(cool down)
Draw samples at each level: $\hat{x}_{t} \leftarrow \hat{x}_{t-1} + \frac{\alpha_{i}}{2}S_{\theta}(\hat{x}_{t-1}, \sigma_{i}) + \sqrt{\alpha_{i}}\,z_{t}$
- $\alpha_i$: per-level step size ($\epsilon$ scheduling).
↳ Do we then need a separate score function for every noise level?
$\Rightarrow$ Conditional score model (NCSN)
- ↳ Feed $\sigma$ to $S_\theta(\cdot, \cdot)$ as an additional input and train over all levels jointly.
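A minimal sketch of annealed Langevin dynamics, assuming a trained conditional model `s_theta(x, sigma)` (hypothetical here); the step-size schedule $\alpha_i = \epsilon \cdot \sigma_i^2 / \sigma_L^2$ follows the NCSN paper.

```python
import torch

def annealed_langevin(s_theta, sigmas, eps=2e-5, steps_per_level=100, n=64, d=2):
    x = torch.rand(n, d)                            # x^0 ~ Pi(x)
    for sigma in sigmas:                            # sigma_1 > ... > sigma_L
        alpha = eps * sigma ** 2 / sigmas[-1] ** 2  # per-level step size
        for _ in range(steps_per_level):
            z = torch.randn_like(x)
            x = x + alpha / 2 * s_theta(x, sigma) + alpha ** 0.5 * z
    return x
```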
Loss Function
$\mathcal{L} = \frac{1}{L}\sum_{i=1}^{L} \lambda(\sigma_i)\, E_{p_{\sigma_i}} [\|\nabla_{x}\log p_{\sigma_i}(x) - S_{\theta}(x, \sigma_i)\|_{2}^{2}]$ — one $S_\theta$ trained across all noise levels.
$\lambda(\sigma_i)$: per-noise-level weight.
Common choice: $\lambda(\sigma_i) = \sigma_i^{2}$.
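A minimal sketch of this loss in its denoising form, again with a hypothetical `s_theta(x, sigma)`: with $\lambda(\sigma_i) = \sigma_i^2$ the weighted residual becomes $\|\sigma_i S_\theta(\tilde{x}, \sigma_i) + z\|^2$, which is similarly scaled at every noise level.

```python
import torch

def ncsn_loss(s_theta, x, sigmas):
    loss = 0.0
    for sigma in sigmas:                 # average over the L noise levels
        z = torch.randn_like(x)
        x_tilde = x + sigma * z          # x_tilde ~ q_sigma(. | x)
        target = -z / sigma              # grad log q_sigma(x_tilde | x)
        residual = s_theta(x_tilde, sigma) - target
        loss = loss + sigma ** 2 * 0.5 * (residual ** 2).sum(dim=-1).mean()
    return loss / len(sigmas)
```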
Practical recommendations
$\sigma_{1} > \sigma_{2} > \dots > \sigma_{L}$ (same ordering as the annealing schedule above)
$\sigma_L$: must be sufficiently small.
$\sigma_1$: should be on the order of the maximum pairwise distance between training points — this much noise covers the whole data range.
$L \approx 100 \sim 1000$
$S_\theta(x, \sigma_i)$: U-Net with skip connections.