ESRGAN 小註解

我參考 ESRGAN [11] 的論文導讀[1], 把其中相關的名詞做了個整理. 

基本上兩大重點是 RRDB 和 GAN. 這個網路模型前者的修改是拔掉 BN, 後續又出了一篇 [9] 加上 noise 處理. 後者是修改 GAN 的對抗方式, [9] 又做了一些調整就略過. 以下把相關的名詞稍微整理如下.

首先看到 BN 這個詞, 文章中指的是第二種, 但是也順便把第一種 BN 列進來.

Bayesian Network (BN)

在 AI 中的主要應用是把人類的知識加入類神經網路 [2]. 如果單純使用 data training 叫做 machine learning. 既然它的名字裡面有 Bayesian, 顯然它考慮了條件機率, 所以在 training 的過程中, 我們要引入條件機率表 (conditional probability table – CPT).

條件式的因果關係就是 P(B|A) = C 的形式, 因此 BN 網路需要有一群輸入層的 feature 值, 然後以因果關係建立網路結構, 接著用條件機率把每個 node 到下一層 node 的機率以 CPT 表示, 最後會有一個輸出層, 告訴我們有幾種可能的 state, 以及各自有多少機率.

Batch Normalization (BN)

這是另外一個 BN, 在 RealSRGAN 裡面指的是這個, 事實上它也比較常用. 原理就是把輸入正規化為 N(0,1). 即 mean = 0, variance = 1, 這方法在 pattern recognition 和 classification 中也常用.  不過對 NN 來說, 它是指在 internal layer 去做 normalization, 不只是在輸入層.

RRDBNet Network [3]

在 RRDB 的實作中, 裡面有許多的 basic block 串接. Basic block  包括 Dense Block (DB) [4] 和 Residial Block (RB) [5]. Residual in Resisual Dense Block 就是把 dense block 以 residual network 的方式連起來 [3], 把這些 DB 都當作 RB 裡面的一層 layer 看待.

RRDB 在 [1] 講得比較多. BTW, 關於噪音的改進. 可以參考 [9] 這篇的開頭.

Dense Block [4]

主要是每層都和其他幾層相連. 我把它當作 full connection 的分批簡化形式. Dense Block 的串接連成 Dense Network.

Residual Block (RB) [5]

Residual Block 的概念, 像是 Dense block 的變形. 也就是可以跳過一層網路不連.

GAN

GAN 用國文來解釋很簡單, 就是正邪兩派各自去認親, 能夠訓練到邪不勝正, 那麼正的網路就大功告成了.

本圖取材自 [10].

從 coding 的角度來說, 當然還是要看到程式比較有感覺. 我覺得 [6] 的寫法滿好懂的. 若對於只熟悉C語言, 不熟 Pythone 的人需要克服這個底線障礙. 這邊的底線是 [7] 五種底線的第二種, ‘_’ 表示傳回的值有些是 don’t care. 例如:

validity, _, _ = discriminator(gen_imgs)

GAN 的基本流程當然是把生成對抗網路都各 train 一輪.  此時每輸入每一張圖, 都要生成一堆有噪音的圖 (靠隨機變數) – train generator. 然後把這一大把圖, 拿去 train discrimiator. 在生成 generator 的時候, 可以看到有製造噪音的 code 和最後的 back propogation.

# Train Generator
 
optimizer_G.zero_grad()
 
# Sample noise and labels as generator input
z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))
label_input = to_categorical(np.random.randint(0, opt.n_classes, batch_size), num_columns=opt.n_classes)
code_input = Variable(FloatTensor(np.random.uniform(1, 1, (batch_size, opt.code_dim))))
 
# Generate a batch of images
gen_imgs = generator(z, label_input, code_input)
 
# Loss measures generator’s ability to fool the discriminator
validity, _, _ = discriminator(gen_imgs)
g_loss = adversarial_loss(validity, valid)
 
g_loss.backward()
optimizer_G.step()

訓練 discrimator 網路時. 要得到 real image 在這個網路的 loss 評分 (終極目標是 loss = 0), 以及 fake image 的 loss 評分 (終極目標是 loss 很大). 兩組 loss 輸出, 取其平均值做 backward propogation 的分數. adversarial_loss 的實作細節看起來是在 Torch 的 loss.py 裡面 [8].

 
# Train Discriminator
 
optimizer_D.zero_grad()
 
# Loss for real images
real_pred, _, _ = discriminator(real_imgs)
d_real_loss = adversarial_loss(real_pred, valid)
 
# Loss for fake images
fake_pred, _, _ = discriminator(gen_imgs.detach())
d_fake_loss = adversarial_loss(fake_pred, fake)
 
# Total discriminator loss
d_loss = (d_real_loss + d_fake_loss) / 2
 
d_loss.backward() 
optimizer_D.step()

以上是普通的 GAN 的處理方式. ESRGAN 的 GAN 不是比較 real image 和 fake image 誰的分數高, 因為我們早已經知道 real image 和 fake image 各自要扮演的角色. 花太多時間讓參數長對有點浪費時間. 

ESRGAN 用的方式是讓真的和假的直接去 PK, 而不是對標準答案. 真的贏過假的參數要加分, 假的贏過真的的參數要扣分. 以下直接貼 [1] 的內容, 但用顏色標註重點.

我们可以看SRGAN的loss:

         l_d_real = self.cri_gan(real_d_pred, True, is_disc=True)
         l_d_fake = self.cri_gan(fake_d_pred, False, is_disc=True)

而ESRGAN的loss:

l_d_real = self.cri_gan( real_d_pred - torch.mean(fake_d_pred), True, is_disc=True)         
l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred), False, is_disc=True)

以上就是這篇論文會用到的重要模組. 沒有什麼創見, 就是幫入門者的障礙減少一點.

[Note]

  1. https://zhuanlan.zhihu.com/p/258532044
  2. https://towardsdatascience.com/how-to-train-a-bayesian-network-bn-using-expert-knowledge-583135d872d7
  3. https://blog.csdn.net/gwplovekimi/article/details/90032735
  4. DenseNet 學習心得
  5. https://ithelp.ithome.com.tw/articles/10204727
  6. https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/infogan/infogan.py
  7. https://towardsdatascience.com/5-different-meanings-of-underscore-in-python-3fafa6cd0379
  8. https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/loss.py
  9. https://zhuanlan.zhihu.com/p/393350811
  10. https://github.com/jonbruner/generative-adversarial-networks/blob/master/gan-notebook.ipynb
  11. https://aiqianji.com/blog/article/1
分類: 技術,標籤: , , 。這篇內容的永久連結

發佈留言

發佈留言必須填寫的電子郵件地址不會公開。