PyTorch series | Getting started in 3 steps


Author: reallocing | Published: 2018-12-27 12:18
1. Define the network

Subclass nn.Module and implement the __init__ and forward methods.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # layers with learnable weights are created here
        self.fc1 = nn.Linear(240, 120)  # input length 240, output length 120
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 10)

    def forward(self, x):
        # define how the layers are connected
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
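A quick way to check the definition is to push a dummy batch through it. The sketch below is a self-contained copy of Net. The batch of 4 random samples is a hypothetical input. Calling net(x) rather than net.forward(x) lets nn.Module run forward() for you.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(240, 120)
        self.fc2 = nn.Linear(120, 60)
        self.fc3 = nn.Linear(60, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

net = Net()
x = torch.randn(4, 240)  # hypothetical batch: 4 samples, 240 features each
out = net(x)             # calling the module invokes forward()
print(out.shape)         # torch.Size([4, 10])
```

The first dimension is the batch size, so the same network accepts any number of samples of length 240.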

As a rule of thumb, define every layer that has weights to be updated (such as nn.Linear) in __init__, so that PyTorch registers its parameters with the network. Operations without weights (such as the ReLU activation, called here through torch.nn.functional) can be applied directly inside forward.
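To see the rule in action, the sketch below uses a hypothetical TinyNet. The nn.Linear created in __init__ is registered as a submodule, so its weight and bias appear in net.parameters(), which is what gets handed to the optimizer in step 3. ReLU contributes no parameters at all.

```python
import torch.nn as nn
import torch.nn.functional as F

class TinyNet(nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(240, 120)  # has weights -> created and registered here

    def forward(self, x):
        return F.relu(self.fc(x))      # no weights -> applied functionally

net = TinyNet()
names = [name for name, _ in net.named_parameters()]
n_params = sum(p.numel() for p in net.parameters())
relu_params = sum(p.numel() for p in nn.ReLU().parameters())

print(names)        # ['fc.weight', 'fc.bias']
print(n_params)     # 240 * 120 weights + 120 biases = 28920
print(relu_params)  # 0
```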

2. Load the data

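PyTorch feeds data to a model through Dataset and DataLoader. A Dataset returns individual samples, and a DataLoader groups them into batches (optionally shuffled). To use our own data, we subclass Dataset.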

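import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader

class ExampleDataset(Dataset):
    def __init__(self, csv_file):
        self.data_frame = pd.read_csv(csv_file)

    def __len__(self):
        # number of samples = number of rows in the CSV
        return len(self.data_frame)

    def __getitem__(self, idx):
        # .iloc selects row idx (plain [] on a DataFrame selects a column)
        row = self.data_frame.iloc[idx]
        # assumption: all columns are numeric, the last one is the target and the rest are features;
        # use dtype=torch.long for the target if it is a class index for nn.CrossEntropyLoss
        features = torch.tensor(row.iloc[:-1].values, dtype=torch.float32)
        target = torch.tensor(row.iloc[-1], dtype=torch.float32)
        return features, target

example_dataset = ExampleDataset('my_datasets.csv')

# num_workers: number of subprocesses that load batches in parallel
example_data_loader = DataLoader(example_dataset, batch_size=10, shuffle=True, num_workers=2)

## Loop over the data
## enumerate() yields the batch index and the batch itself
for batch_index, batch in enumerate(example_data_loader):
    print(batch_index, batch)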
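Because ExampleDataset depends on a CSV file, here is a minimal sketch that swaps it for a hypothetical in-memory ToyDataset of 25 random samples. It shows what DataLoader does with __len__ and __getitem__. It fetches individual (features, target) pairs and stacks them into batches. With the default drop_last=False, the last batch is simply smaller. num_workers is left at its default of 0, so loading happens in the main process.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):  # hypothetical stand-in for ExampleDataset, no CSV needed
    def __init__(self, n_samples=25, n_features=240):
        self.features = torch.randn(n_samples, n_features)
        self.targets = torch.randn(n_samples)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # one (features, target) pair; DataLoader stacks these into batch tensors
        return self.features[idx], self.targets[idx]

loader = DataLoader(ToyDataset(), batch_size=10, shuffle=True)

batch_sizes = []
for batch_index, (data, targets) in enumerate(loader):
    print(batch_index, data.shape, targets.shape)  # e.g. 0 torch.Size([10, 240]) torch.Size([10])
    batch_sizes.append(data.shape[0])

print(batch_sizes)  # [10, 10, 5]: 25 samples in batches of 10, last one partial
```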

3. Train the network
import torch.optim as optim
import torch.nn as nn

## instantiate the network
net = Net()
## optimizer: plain SGD over all of the network's parameters
optimizer = optim.SGD(net.parameters(), lr=1e-3)

## define the loss function
## MSELoss expects targets shaped like the output, (N, 10);
## for 10-class classification with integer labels, use nn.CrossEntropyLoss() instead
criterion = nn.MSELoss()

for epoch in range(10):  ## 10 epochs
    for i, batch in enumerate(example_data_loader):
        # get the inputs
        data, targets = batch

        # zero the gradient buffers (gradients accumulate across backward() calls)
        optimizer.zero_grad()
        ## pass the data through the network (calling net(...) runs forward())
        output = net(data)
        ## calculate the loss
        loss = criterion(output, targets)
        ## propagate the loss back
        loss.backward()
        ## update all weights of the network
        optimizer.step()
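The loop above can be run end to end without a CSV file. This sketch makes several assumptions:

- The data is hypothetical and synthetic: 200 samples whose 10 targets come from a random linear map.
- nn.Sequential stands in as a compact equivalent of Net.
- The learning rate is raised to 1e-2 so the drop in loss is visible within 10 epochs.

It compares the loss over the whole dataset before and after training.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# hypothetical synthetic regression data: 10 targets from a random linear map of 240 features
X = torch.randn(200, 240)
true_w = torch.randn(240, 10) * 0.1
Y = X @ true_w

loader = DataLoader(TensorDataset(X, Y), batch_size=20, shuffle=True)

# same layer sizes as Net, written as nn.Sequential for brevity
net = nn.Sequential(
    nn.Linear(240, 120), nn.ReLU(),
    nn.Linear(120, 60), nn.ReLU(),
    nn.Linear(60, 10),
)
optimizer = optim.SGD(net.parameters(), lr=1e-2)
criterion = nn.MSELoss()

def full_loss():
    # loss over the whole dataset, without tracking gradients
    with torch.no_grad():
        return criterion(net(X), Y).item()

initial = full_loss()
for epoch in range(10):
    for data, targets in loader:
        optimizer.zero_grad()                     # clear old gradients
        loss = criterion(net(data), targets)      # forward pass + loss
        loss.backward()                           # backpropagate
        optimizer.step()                          # update weights
final = full_loss()

print(f"loss before: {initial:.4f}, after: {final:.4f}")
```

For a classification task, swap in nn.CrossEntropyLoss() and pass it integer class labels of shape (N,) instead of float targets of shape (N, 10).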


Article link: https://www.haomeiwen.com/subject/mzollqtx.html