之前利用PIL把dicom的slice保存为了16位灰度图, 用torchvision.transform做图像增强时发现会报错.
Dataset的__getitem__函数如下
def __getitem__(self, idx):
pth = self.dataset[idx]
img = Image.open(pth) # 范围为 [0-2048] 的16位tiff图片
img = torchvision.transform.ToTensor()(img)
return img
output:
RuntimeError: shape '[64, 64, 5]' is invalid for input of size 8192
查询了一下torchvision.transform.ToTensor()函数, 发现对输入值域要求为[0-255]. 估计是我[0-2048]的范围出发了某种判断, 使得该函数以为输入图片为某种其他格式.
将函数改为如下, 解决了问题
def __getitem__(self, idx):
pth = self.dataset[idx]
img = np.array(Image.open(pth), dtype='float32') / 2048 # 范围为 [0-1] 的单精度`numpy`数组
img = Image.fromarray(img)
img = torchvision.transform.ToTensor()(img)
return img
注意, 若不显式注明dtype='float32', 会自动转换为float64的tensor, 不确定对训练结果和速度有何影响 (pytorch的默认数据类型为float32).
总之, 下次直接把图片保存为numpy格式会更方便些.










网友评论