
- PyTorch 教程
- PyTorch - 主頁
- PyTorch - 介紹
- PyTorch - 安裝
- 神經網路的數學基礎知識
- PyTorch - 神經網路基礎
- 機器學習的通用工作流程
- 機器學習與深度學習
- 實現第一個神經網路
- 將神經網路應用到功能模組
- PyTorch - 術語
- PyTorch - 載入資料
- PyTorch - 線性迴歸
- PyTorch - 卷積神經網路
- PyTorch - 迴圈神經網路
- PyTorch - 資料集
- PyTorch - 卷積神經網路簡介
- 從頭開始訓練卷積神經網路
- PyTorch - 卷積神經網路中的特徵提取
- PyTorch - 卷積神經網路視覺化
- 使用卷積神經網路處理序列
- PyTorch - 詞嵌入
- PyTorch - 遞迴神經網路
- PyTorch 有用資源
- PyTorch - 快速指南
- PyTorch - 有用資源
- PyTorch - 討論
PyTorch - 從頭開始訓練卷積神經網路
在本章中,我們將重點介紹從頭開始建立卷積神經網路。這包括使用 torch 建立相應的卷積神經網路或樣本神經網路。
步驟 1
使用相應引數建立必要的類。這些引數包括具有隨機值的權重。
class Neural_Network(nn.Module): def __init__(self, ): super(Neural_Network, self).__init__() self.inputSize = 2 self.outputSize = 1 self.hiddenSize = 3 # weights self.W1 = torch.randn(self.inputSize, self.hiddenSize) # 3 X 2 tensor self.W2 = torch.randn(self.hiddenSize, self.outputSize) # 3 X 1 tensor
步驟 2
建立帶有 sigmoid 函式的前饋模式函式。
def forward(self, X): self.z = torch.matmul(X, self.W1) # 3 X 3 ".dot" does not broadcast in PyTorch self.z2 = self.sigmoid(self.z) # activation function self.z3 = torch.matmul(self.z2, self.W2) o = self.sigmoid(self.z3) # final activation function return o def sigmoid(self, s): return 1 / (1 + torch.exp(-s)) def sigmoidPrime(self, s): # derivative of sigmoid return s * (1 - s) def backward(self, X, y, o): self.o_error = y - o # error in output self.o_delta = self.o_error * self.sigmoidPrime(o) # derivative of sig to error self.z2_error = torch.matmul(self.o_delta, torch.t(self.W2)) self.z2_delta = self.z2_error * self.sigmoidPrime(self.z2) self.W1 + = torch.matmul(torch.t(X), self.z2_delta) self.W2 + = torch.matmul(torch.t(self.z2), self.o_delta)
步驟 3
建立如下所示的訓練和預測模型 -
def train(self, X, y): # forward + backward pass for training o = self.forward(X) self.backward(X, y, o) def saveWeights(self, model): # Implement PyTorch internal storage functions torch.save(model, "NN") # you can reload model with all the weights and so forth with: # torch.load("NN") def predict(self): print ("Predicted data based on trained weights: ") print ("Input (scaled): \n" + str(xPredicted)) print ("Output: \n" + str(self.forward(xPredicted)))
廣告