設為首頁 - 加入收藏 - 網站地圖
當前位置:首頁 > 產品動態 > 正文

ICCV 2019 推薦Pytorch實現一種無需原始訓練數據的模型壓縮算法

時間:2019-11-08 08:13:16 來源:本站 閱讀:3996493次

背景

大多數深層神經網絡(CNN)往往消耗巨大的計算資源和存儲空間為了將模型部署到性能受限的設備(如移動設備),通常需要加速網絡的壓縮現有的一些加速壓縮算法,如知識蒸餾等,可以通過訓練數據獲得有效的結果。然而,在實際應用中,由于隱私、傳輸等原因,訓練數據集通常不可用因此,作者提出了一種不需要原始訓練數據的模型壓縮方法。

原理

上圖是本文提出的總體結構通過一個給定的待壓縮網絡(教師網絡),作者訓練一個生成器生成與原始訓練集分布相似的數據然后,利用生成的數據,基于知識提取算法對學生網絡進行訓練,從而實現無數據的模型壓縮。

那么,在沒有數據的情況下,如何在給定的教師網絡上訓練一個可靠的生成器呢作者提出了以下三個損失來指導發電機的學習。

(1)在圖像分類任務中,對于真實數據,網絡的輸出往往接近一個熱向量其中,分類類別的輸出接近于1,其他類別的輸出接近于零因此,如果生成器生成的圖像接近真實數據,那么它在教師網絡上的輸出應該類似于一個熱向量因此,作者提出了一個One-hotloss:

其中YT是通過教師網絡生成的圖片的輸出,T是偽標簽,并且由于生成的圖片不具有標簽,所以作者將YT中的最大值設置為偽標簽。Hcross表示交叉熵函數。

(2)另外,在神經網絡中,輸入真實數據往往比輸入的隨機噪聲在特征圖上有更大的響應值因此,作者建議激活損失約束生成的數據:

其中fT表示通過教師網絡提取生成的數據的特征,||·||1表示|1范數。

(3)此外,為了使網絡得到更好的訓練,訓練數據往往需要類別平衡因此,為了平衡同一類別中生成的數據,引入信息熵損失來度量類別平衡度:

其中,Hinfo表示信息熵,yT表示每張圖片的輸出如果信息熵較大,則對輸入的圖片集中的每個類別的平均數進行平均,從而確保生成的圖片類別的平均數。


最后,結合以上三個損耗函數,可以得到發電機培訓使用的損耗:

通過優化上述損失,您可以訓練生成器,然后通過生成器生成的樣本執行知識蒸餾在知識提取中,要壓縮的網絡(教師網絡)通常具有較高的精度,但存在冗余參數學生網絡是一個輕量級設計和隨機初始化網絡利用教師網絡的輸出來指導學生網絡的輸出,可以提高學生網絡的精度,達到模型壓縮的目的這個過程可以用以下公式表示:

其中,ys和yt分別表示學生網絡和教師網絡的輸出,Hcross表示交叉熵函數。

算法1表示項目方法的流程首先,通過優化上述損耗,獲得與原始數據集具有相似分布的發生器其次,通過生成器生成的圖像,將教師網絡的輸出通過知識蒸餾遷移到學生網絡中學生網絡的參數較少,支持無數據壓縮方法。

結果

MNIST數據集上的分類結果。

所提出的無數據學習方法的不同組成部分的有效性。

CIFAR數據集上的分類結果。

CelebA數據集上的分類結果

在各種數據集上的分類結果。

可視化每個類別中的平均圖像(從0至9)

第一卷積層中過濾器的可視化,在MNIST數據集上學習。第一行顯示訓練有素的過濾器,使用原始訓練數據集,并且底線顯示使用通過所提出的方法生成的樣本獲得的過濾器。

總結

常規方法需要原始訓練數據集,用于微調壓縮的深度神經網絡具有可接受的精度。但是,訓練集和給定深度網絡的詳細架構信息,由于某些隱私和傳輸限制,通常無法使用。

作者在本文中,我們提出了一個新穎的框架來訓練生成器以逼近原始沒有訓練數據的數據集。然后,一個便攜式網絡通過知識提煉方案可以有效地學習。

在基準數據集上的實驗表明,所提出的方法DAFL方法能夠無需任何培訓即可學習便攜式深度神經網絡數據。

論文地址:

https://arxiv.org/pdf/1904.01186.pdf

摘要:機鋒網,機鋒論壇,機鋒,劉精松,劉統勛,劉羽禪
頂一下
踩一下
TAGS標簽:機鋒網,機鋒論壇,機鋒,劉精松,劉統勛,劉羽禪