91视频免费?看_蜜芽MY188精品TV在线观看_国产免费无遮挡在线观看视频_深夜国产_亚洲精品欧洲精品_欧美黑人粗暴多交

徐土豆
認證:優質創作者
所在專題目錄 查看專題
圖文多模態語義融合前的語義對齊——一種單雙混合塔多模態模型
在多模態模型訓練時,如何合適地融合單模態損失
FILIP: 一種基于交互的細粒度圖文預訓練模型
ERNIE VIL 2.0,多模態模型的一種多視角預訓練范式
VQ-VAE的實現方法分析——一種基于梯度回調的方法
【論文極速讀】視頻檢索中的模態均衡方法
作者動態 更多
給定計算預算下的最佳LLM模型尺寸與預訓練數據量分配
05-19 09:33
大模型推理時的尺度擴展定律
05-18 10:32
世界多胞體與世界模型
05-13 09:42
獎勵模型中的尺度擴展定律和獎勵劫持
05-12 08:41
MeCo——給預訓練數據增加源信息,就能減少33%的訓練量并且提升效果
05-08 09:13

VQ-VAE的實現方法分析——一種基于梯度回調的方法

筆者在前文 [2] 中曾經介紹過VQ-VAE模型,如Fig 1.所示,該模型基于最近鄰查找的方式從字典中查找其索引,作為其稀疏化后的令牌,具體細節可見博文[2]。

Fig 1. 通過最近鄰方法在字典里面查找稀疏令牌,作為稀疏編碼的結果,然后通過反查字典可以對feature map進行恢復。整個框架中有若干參數需要學習,分別是encoder,decoder網絡參數和Embedding space字典的參數。然而稀疏編碼的過程由于出現了最近鄰方法,這個過程顯然是無法傳遞梯度的,為了實現編碼器的更新,可以考慮將解碼器的梯度直接拷貝到編碼器中。假設對于編碼后恢復的而言,其每個元素表示為,那么對于其中某個元素的梯度表示為,同理,對于編碼后的而言,同樣有? ,令? 。

那么對于編碼器的梯度就可以表示為 。在詳細分析代碼實現邏輯之前,讓我們回顧下其損失函數,如(1-1)所示,其中的為停止梯度函數,表示該函數無梯度傳導。decoder的參數通過第一項損失項進行更新(這部分損失可通過MSE損失建模),稱之為重建損失。encoder參數通過第一項和第三項損失進行更新,其中第一項是重建損失,第三項是為了encoder編碼產出和embedding space進行對齊而設計的,由于此時通過函數停止了梯度,因此此時的參數不會得到更新。Embedding space的參數通過第二項損失項進行更新,通過將encoder編碼結果進行停止梯度,我們只對E \mathcal{E}E進行參數更新。

Fig 2. 通過梯度拷貝,將decoder的梯度拷貝到encoder中。

那么在代碼中如何實現這些邏輯呢?我們首先可以參考[3]項目中的實現。我們首先分析model.py文件中的forward函數,字典定義為一個nn.Embedding層(Code 1.1),其參數就是self.dict.weight,那么求最近鄰的操作就如Code 1.2所示。Code 1.3將最近鄰的索引結果(也即是稀疏化后的視覺令牌),在字典中進行查詢,對feature map進行恢復。因此W_j的形狀和Z是一致的。此時Code 1.4中對Z和W_j進行detach,這個detach的作用之前在博文[4]中闡述過,本文不進行累述,其主要作用可視為是停止了該節點開始的梯度傳導,也即是用于實現公式(1-1)中的。

Code 1. model.py的主要邏輯

def __init__(self,...):
	...
	self.dict = nn.Embedding(k_dim, z_dim) # Code 1.1
	
def forward(self, x):
     h = self.encoder(x) # (?, z_dim*2, 1, 1)
     sz = h.size()
     
     # BCWH -> BWHC
     org_h = h
     h = h.permute(0,2,3,1)
     h = h.contiguous()
     Z = h.view(-1,self.z_dim)
     W = self.dict.weight
	 
	 # Code 1.2
     def L2_dist(a,b):
         return ((a - b) ** 2)
     # Sample nearest embedding
     j = L2_dist(Z[:,None],W[None,:]).sum(2).min(1)[1]
	 
	 # Code 1.3
     W_j = W[j]

     # Code 1.4, Stop gradients
     Z_sg = Z.detach()
     W_j_sg = W_j.detach()

     # BWHC -> BCWH
     h = W_j.view(sz[0],sz[2],sz[3],sz[1])
     h = h.permute(0,3,1,2)
	 
	 # Code 1.5, gradient hook register
     def hook(grad):
         nonlocal org_h
         self.saved_grad = grad
         self.saved_h = org_h
         return grad

     h.register_hook(hook)
     
     # Code 1.6, losses
     return self.decoder(h), L2_dist(Z,W_j_sg).sum(1).mean(), L2_dist(Z_sg,W_j).sum(1).mean()

# Code 1.7, back propagation for encoder
def bwd(self):
    self.saved_h.backward(self.saved_grad)

此時有一個比較有意思的函數調用,如Code 1.5所示,此處的h.register_hook(hook_fn)表示對張量h注冊了個回調鉤子函數 hook_fn,我們先看下這個函數具體作用是什么,從官網的API信息[5]中可以知道,當每次對這個張量進行梯度計算的時候,都會調用這個回調函數hook_fn。hook_fn的輸入是該張量的原始梯度grad_orig,hook_fn會對梯度進行變換得到grad_new = hook_fn(grad_orig),并且將grad_orig更新為grad_new。這個功能可以讓我們實現將decoder的梯度賦值到encoder中,我們且看是如何實現的。我們留意到其對h,也即是W_j的結果進行了注冊回調,我們也知道W_j和Z的形狀是一致的,此時我們希望 ,因此我們需要以某種方式緩存下Z和W_j的梯度,在梯度反向傳播的時候,將W_j的梯度賦值到Z的梯度上,這也就是回調hook的目的——緩存下此時W_j的梯度和原始的Z節點。 在Code 1.6就開始構建decoder的輸出以及? 和這兩個loss了,那么何時我們對其encoder的梯度進行賦值呢?我們繼續看到solver.py文件~

def hook(grad):
	nonlocal org_h
	 self.saved_grad = grad
	 self.saved_h = org_h
	 return grad

在solver.py中,最主要的邏輯如下所示,其中的self.G(x)即是Code 1所示的forward()邏輯,對于其輸出的解碼器輸出out,構建重建損失,對重建損失loss_rec和其他倆對齊損失loss_e1和loss_e2進行加和后得到loss,對loss進行梯度計算(注意此時需要將retain_graph設置為True,以保留葉子節點的梯度,具體作用見博文[6])。注意到此時由于最近鄰查表的引入,loss.backward(retain_graph=True)只對decoder進行了梯度計算,此時為了對encoder也進行梯度計算,還需要進行self.G.bwd(),這個也正是我們剛才提到的,將W_j的梯度賦值到Z的梯度上,我們且看看如何實現的。如Code 1.7所示,self.G.bwd()的邏輯很簡單,對緩存的Z進行梯度『賦值』為緩存下來的W_j梯度,但是準確的說,此處并不是對Z的梯度賦值,而是制定了計算Z梯度的前繼梯度為self.saved_grad(梯度計算是鏈式法則,這意味著梯度計算勢必有前繼和后續),我們在附錄里面會舉個例子說明tensor.backward()和tensor.register_hook()的作用??偠灾?,通過調用self.G.bwd()我們可以對encoder的梯度也進行計算了,最后調用optimizer.step()進行參數更新即可了。

def bwd(self):
    self.saved_h.backward(self.saved_grad)

Code 2. solver.py的主要邏輯

# ================== Train G ================== #
# Train with real images (VQ-VAE)
out, loss_e1, loss_e2 = self.G(x)
loss_rec = reconst_loss(out, x)

loss = loss_rec + loss_e1 + self.vq_beta * loss_e2
self.g_optimizer.zero_grad()

# For decoder
loss.backward(retain_graph=True)

# For encoder
self.G.bwd()

self.g_optimizer.step()

附錄A. tensor.backward()和tensor.register_hook()的作用

>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2)  # 梯度翻倍
>>> v.backward(torch.tensor([1., 2., 3.])) # v的梯度前繼為[1, 2, 3]
>>> v.grad # 因此輸出的梯度為[2, 4, 6]

 2
 4
 6
[torch.FloatTensor of size (3,)]

>>> h.remove()  # removes the hook

Reference

[1]. Van Den Oord, Aaron, and Oriol Vinyals. “Neural discrete representation learning.” Advances in neural information processing systems 30 (2017).

[2]. https://blog.csdn.net/LoseInVain/article/details/129224424, 【論文極速讀】VQ-VAE:一種稀疏表征學習方法

[3]. https://github.com/nakosung/VQ-VAE

[4]. https://blog.csdn.net/LoseInVain/article/details/105461904, 在pytorch中停止梯度流的若干辦法,避免不必要模塊的參數更新

[5]. https://pytorch.org/do

聲明:本內容為作者獨立觀點,不代表電子星球立場。未經允許不得轉載。授權事宜與稿件投訴,請聯系:editor@netbroad.com
覺得內容不錯的朋友,別忘了一鍵三連哦!
贊 1
收藏 2
關注 52
成為作者 賺取收益
全部留言
0/200
成為第一個和作者交流的人吧
主站蜘蛛池模板: 康保县| 晋中市| 如东县| 睢宁县| 大荔县| 枞阳县| 台中县| 牟定县| 湖口县| 邯郸县| 格尔木市| 阜平县| 福泉市| 安新县| 云霄县| 潜江市| 三明市| 甘南县| 灵璧县| 亳州市| 堆龙德庆县| 循化| 张北县| 高密市| 新巴尔虎右旗| 喀喇沁旗| 巴林左旗| 灵寿县| 荥经县| 滦平县| 张掖市| 尉犁县| 平罗县| 南陵县| 宜州市| 本溪| 滦南县| 平凉市| 福清市| 南昌市| 华蓥市|