GrowNet:梯度提升神經網路
介紹
GrowNet是一個新穎的梯度提升框架,它使用梯度提升技術從淺層深度神經網路構建複雜的深度神經網路。淺層深度神經網路用作弱學習器。如今,GrowNet正在各個領域和行業中得到應用。
梯度提升演算法的簡要回顧。
梯度提升是一種按順序構建模型的技術,這些模型試圖減少先前模型產生的誤差。這是透過對先前模型產生的殘差或誤差構建模型來完成的。它可以使用數值方法的最佳化來估計函式。最常見的梯度提升函式型別是決策樹,其中每個決策都是透過擬合先前樹的負梯度來建模的。
梯度提升可以是用於迴歸任務的梯度提升迴歸器,也可以是用於分類任務的梯度提升分類器。
GrowNet——應用於神經網路的新型提升思想
梯度提升演算法背後的主要概念或思想是,它使用較低級別的簡單模型作為構建塊來構建更強大且功能更強大的模型(通常是更高階的模型),方法是使用一階和二階梯度導數進行順序梯度提升。在這些模型中,弱學習器提高了高階模型的效能。
在每個提升步驟中,初始輸入特徵是
原始輸入特徵擴充套件到當前迭代的先前層輸出。此合併的特徵集用作輸入,用於使用基於提升的機制(使用當前殘差)訓練下一組弱學習器。所有來自順序訓練的模型的輸出都經過加權並組合以給出最終輸出。
假設一個數據集有m個特徵,每個特徵有d維,則
$$ \mathrm{T = {{(xi, yi)|xi ∈ R^d,yi ∈ R,|T| = m}}} $$
假設Grownet進行N次迭代,
$$ \mathrm{ŷ_i\:=\:∅(x_i)\:=\:\displaystyle\sum\limits_{n=0}^N α_n\:Fk(x_i), \: n ∈ F} $$
其中F = 空間中的乘數,αn = 步長大小。Fn表示每個具有輸出層的淺層神經網路。
如果l是可微的損失函式,則要最小化的目標函式為以下等式
$$ \mathrm{L(\epsilon)\:=\:\sum_{i=0}^nl(y_i,\hat{y}_{i})} $$
我們可以進一步新增正則化。
設$\mathrm{\hat{y}_{i}^{(t-1)}\:=\:\sum_{k=0}^{t-1}\alpha_kf_k(x_i)}$是xi樣本在t−1階段的Grownet輸出,則
$$ \mathrm{L^{(t)}\:=\:\sum_{i=0}^nl(y_i,\hat{y}_{i}^{(t-1)}+f_i(x_i))} $$
弱學習器的目標函式將給出為。
$$ \mathrm{L^{(t)}\:=\:\sum_{i=0}^n(\tilde{y_{i}}\:-\:f_i(x_i))^2} $$
其中,
$$ \mathrm{\tilde{y_{i}}\:=\:-gi/hi} $$
引入校正步驟。
在每個步驟(提升階段)t,都會更新第t個弱學習器的引數,並且所有先前的(t−1)弱學習器都不會更改。在這個過程中,模型可能會在學習過程中陷入區域性最小值,這可以透過αn來緩解。因此,我們引入一個校正步驟,在每個校正步驟中,允許每個t−1學習器透過反向傳播更新引數。
GrowNet的應用
GrowNet可用於迴歸和分類。
對於迴歸。
對於迴歸任務,採用MSE損失函式。如果l是均方損失,則用一階和二階以及t個階段獲得yi為
$$ \mathrm{g_i\:=\:2(\hat{y}^{(t-1)}\:-\:y_i), \:\:\:h_i\:=\:2} $$
$$ \mathrm{\tilde{y}_{i}\:=\:y_i\:-\:\hat{y}^{(t-1)}} $$
然後,透過對每個xi,yi(i = 1,2…)進行最小二乘迴歸來訓練即將到來的弱學習器,並且在校正狀態下,使用MSE損失再次更新GrowNet中的所有模型引數。
對於分類
在二元交叉熵示例的情況下,交叉熵損失函式是可微的。取標籤yi ∈ {−1, +1},在任何點t,一階梯度給出為
$$ \mathrm{g_i\:=\:-\frac{-2y_i}{1\:+\:e^{2y_i\hat{y_i}^{(t-1)}}},\: \: h_i\:=\:\frac{4y_i^{}2e^{2y_i\hat{y_i}^{(t-1)}}}{(1\:+\:e^{-2y_i\hat{y_i}^{(t-1)}})^2}} $$
$$ \mathrm{\tilde{y}_i\:=\:-g_i/h_i\:=\:y_i(1\:+\:e^{-2y_i\hat{y_i}^{(t-1)}})/2} $$
使用最小二乘法使用二階導數擬合即將到來的弱學習器。最終使用二元交叉熵損失更新所有函式的引數。
結論
GrowNet是一種應用於深度神經網路的梯度提升技術的新方法,我們可以使用一個框架中的機器學習來完成許多工。它是簡單深度神經網路的更好替代方案,因為它提供了更好的效能並且訓練時間更短。
資料結構
網路
關係資料庫管理系統 (RDBMS)
作業系統
Java
iOS
HTML
CSS
Android
Python
C語言程式設計
C++
C#
MongoDB
MySQL
Javascript
PHP