作 者: 月牙眼的楼下小黑
联 系:zlf111@mail.ustc.edu.cn
声 明: 欢迎转载本文中的图片或文字,请说明出处
参考资料:
[1].PyTorch常用工具模块
1 数据处理
import torch
from torch.utils import data
import os
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
show = ToPILImage()
1.1 数据加载
在 Pytorch 中, 数据加载可通过自定义一个继承Dataset类的数据集对象, 并实现两个方法:
- __ getitem __ : 返回一条数据
- __ len __ : 返回样本数量
class DogCat(data.Dataset):
def __init__(self, root):
imgs = os.listdir(root) # root:图片所在文件夹路径
self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
def __getitem__(self, index):
img_path = self.imgs[index]
if 'dog' in img_path.split('/')[-1]:
label =1
else:
label = 0
pil_img= Image.open(img_path) # 利用 python 图像处理标准库的 open 方法打开图片
array = np.asarray(pil_img) # 将 PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]
#array = np.transpose(array, (2, 0, 1)) # 将 channel last 形式转化成 channel first 形式:[channel, height, width]
data = torch.from_numpy(array) # 将 np.ndarray 转化为 Tensor 形式
return data, label
def __len__(self):
return len(self.imgs)
补充: 对于三维矩阵的转置, 如 a.transpose(2,0,1), 意思是原矩阵a 中(aix 0, aix 1, aix 2) 处的值,现在成为了转置后矩阵 (aix 2, aix 0 , aix 1)处的值。
in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs')
in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img) # 若为 channel last 形式的 tensor, 可用 matplotlib 中 imshow() 方法
print(label, img.size(), img.float().mean())
out:
1 torch.Size([500, 282, 3]) 169.23073522458628
in:
# 显示第一张图片
img,label= dataset[0]
plt.imshow(img) # 若为 channel first 形式的 tensor, 可用 transforms 中 的 ToPILImage() 方法
print(label, img.size(), img.float().mean())
在前面文章中提到过:ToPILImage 可以将
-
shape为(C,H,W)的Tensor -
shape为(H,W,C)的numpy.ndarray
转化成PIL.Image,值不变,方便可视化。注意到 它只能转变channel first形式的Tensor 。而在上面的__getitem__中,array = np.asarray(pil_img) 将 PIL.image 转化为 np. ndarray 形式, 默认为 channel last 形式: [height, width, channel]。 所以如果我们要使用 ToPILImage 方法显示图片,在将 PIL.image 转化为 np. ndarray 形式后,还需要利用转置方法将 channel last 形式改成 channel first 形式:[channel, height, width]
1.2 数据预处理
torchvision.transforms模块提供了对 PILImage对象和Tensor对象的常用操作。
对PILImage的操作包括:
-
Scale: 调整图片尺寸,长宽比保持不变 -
CenterCrop、RandomCrop、RandomSizeCrop:裁剪图片 Pad-
ToTensor: 将PILImage对象转化成channel first的Tensor并归一至[0,1]
对 Tensor的操作包括:
-
Normalize: 标准化,减均值,除以标准差 -
ToPlLImage:将Tensor转化为PILImage对象
in:
trans = transforms.Resize((100,100))
image = Image.open('./dog.1.jpg')
print(image.size)
image = trans(image)
print(image.size)
out:
(327, 499)
(100, 100)
如果要对图片进行多个操作, 可通过Compose方法将这些操作拼接起来。
in:
transform = transforms.Compose([
transforms.Resize(224), # 缩放图片,保持长宽比不变,最短边为224像素
transforms.CenterCrop(224), # 从图片中间切出 224x224 的图片
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [1, 1, 1])
])
class DogCat(data.Dataset):
def __init__(self, root, transforms = None):
imgs = os.listdir(root) # root:图片所在文件夹路径
self.imgs = [os.path.join(root, img) for img in imgs] # imgs:图片文件路径列表
self.transforms = transforms
def __getitem__(self, index):
img_path = self.imgs[index]
if 'dog' in img_path.split('/')[-1]:
label =1
else:
label = 0
data = Image.open(img_path) # 利用 python 图像处理标准库的 open 方法打开图片
if self.transforms:
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
in:
dataset = DogCat('/data1/zhanglf/myDLStudying/myDataSet/dog_cat_data/train/dogs', transforms = transform)
img,label= dataset[0]
print(img.size())
show(img)
out:
torch.Size([3, 224, 224])
1.3 ImageFolder
torchvision预先实现了常用的DataSet,如CIFAR-10, 可通过 torchvision.datasets.CIFAR10来调用。这里介绍一个经常使用的 DataSet——ImageFolder. ImageFolder 假设所有文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名, ImageFolder会根据文件夹名顺序自动生成 label, 可以通过 class_to_idx查看 label 和 文件夹名的映射关系。其构造函数如下:
ImageFolder(root, transform = None, target_transform = None, Loader = default_loader)
- root :文件夹路径
- transform: 对
PILImage的转化操作,transform的输入是loader的返回对象 - target_transform: 对
label的转化 - loader : 读取图片函数,默认为读取
RGB格式的PILImage对象
in:
from torchvision.datasets import ImageFolder
dataset = ImageFolder('./myDataSet/dog_cat_data/train')
in:
dataset.class_to_idx
out:
{'cats': 0, 'dogs': 1}
in:
# 此时还没有任何 transform, 返回的是 PILImage 对象
# 第一维指示第几张图片,第二维为 1 返回 label, 为 0 返回 图片数据
print(dataset[0][1])
dataset[0][0]
out:
0
1.4 DataLoader
调用DataSet中的__getitem__只返回一个样本,而我们需要batch wise training ,Pytorch提供了 DataLoader帮助我们实现这些功能。其构造函数如下:
DataLoader(dataset,
batch_size=1,
shuffle = False,
sample =None,
sampler = None,
num_workers =0,
collate_fn = default_collate,
pin_memory =False,
drop_last = False)
- dataset: 加载的数据集对象(
DataSet对象) - shuffle: 是否将数据打乱
- sampler: 样本抽样
- num_workers: 使用多进程加载的进程数
- collate_fn: 如何将多个样本数据拼接成一个
patch - pin_memory: 是否将数据保存在
pin memory,pin memory中的数据转到gpu会快一些 - drop_last: 当
datast中的数据个数不是batchsize的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃。
2. torchvision
torchvision是Pytorch 团队开发的独立于 Pytorch的视觉工具包,通过pip install torchvision安装,主要包含三部分:
- models:提供一系列经典已经预训练好的模型,包括
AlexNet,VGG,ResNet,Inception等 - datasets: 提供常用的数据集对象(
DataSet对象), 包括MNIST,CIFAR10/100,ImageNet,COCO等 - transforms: 提供常用数据预处理工作,主要包括对
Tensor和PILImage对象的操作。
3. 可视化工具 Visdom
Visdom可以创造、组织和共享多种数据的可视化,包括数值、图像、文本、视频, 支持 Pytorch和Numpy。.
-
Visdom的安装:pip install visdom -
Visdom的启动:python -m visdom.server, 打开浏览器输入:http://localhost:8097, 8097 是默认端口号。
Visdom中的两个重要概念:
- env: 环境。不同用户、不同程序一般使用不同
env. 不同env相互独立,互不影响。使用时不指定env,则默认使用main - pane: 窗格。 一个
env中可以有多个不同的pane, 每个pane可视化或记录某一信息,可以拖动、缩放、保存或关闭
In:
import visdom
vis = visdom.Visdom(env=u'test1') # 构建一个客户端对象,创建一个名为' test1' 的 env
x = torch.arange(1, 30, 0.01)
y = torch.sin(x)
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'}) # win 是 pane 名字,opts 设置 pane 格式,如 title, xlabel,ylabel
在
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=sin(x)'}) 中, win 参数指定 pane 名字, 如果不指定,visdom将自动分配一个新的pane. 如果两次操做指定的win名字一样,新操作将覆盖当前 pane 的内容。如在 上面的 pane中画 y = x 函数,原来的 y = sin(x) 将被覆盖。
In:
y = x
vis.line(X=x, Y=y, win='sinx', opts={'title': 'y=x'})
如果不想覆盖原图,可以使用
updateTrace方法,如:
y = x + 1
vis.updateTrace(X=x, Y=y, win='sinx', name='this is a new Trace')










网友评论