While working through Kaggle exercises recently, I came across a very nice notebook. I decided to turn it into a blog post of my own, both as a summary for myself and as a guide for others who want to start competing on Kaggle. This post is fairly long: reading the original notebook and typing out the code took me about five days, but if you follow along and write the code yourself from start to finish, it should give your coding and analysis skills a substantial boost.
Original author's notebook: https://www.kaggle.com/martinpiotte/whale-recognition-model-with-score-0-78563#Training-data-construction
Original author's 2018 competition data: https://www.kaggle.com/c/whale-categorization-playground
2019 competition data used in this post: https://www.kaggle.com/c/humpback-whale-identification
Data Overview
The format of the competition is simple: you are given a set of whale images whose identities are already labeled (the training set), plus a set of unlabeled whale images (the test set). Your task is to predict the identity of each test image based on the training set and submit the predictions to the Kaggle platform (accessing it may require a proxy from mainland China). Kaggle returns a score, and the higher the score, the more accurate your predictions.
Summary
The model used here is mainly a Siamese network (if you are unfamiliar with Siamese networks, see https://www.jianshu.com/p/92d7f6eaacf5). The decisive factor for model accuracy is the image pairs generated during training. Each training epoch's dataset consists of a series of (A, B) image pairs, such that:
1. 50% of the pairs show the same whale and 50% show different whales. For example, out of 10 pairs, the A and B images of 5 pairs come from the same whale identity, while the A and B of the other 5 belong to different whales.
2. In each epoch, every image in the training set is used exactly four times: as image A and image B of a matched pair, and as image A and image B of an unmatched pair. (This is where the count of four comes from: each image plays each of the four roles once.)
3. Pick pairs of different whales that the network finds hard to tell apart at its current stage of training. This is inspired by adversarial training: find pairs of images of different whales that the model nevertheless considers very similar. It is a bit like the "mistake notebook" from high school: collect the problems you tend to get wrong and deliberately drill on them.
Applying the above strategy while training the Siamese network is what improves model accuracy the most. Other details help accuracy to some degree, but their impact is much smaller.
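To make the 50/50 pairing strategy concrete, here is a minimal sketch of building one epoch's pairs. This is not the author's actual implementation (which also enforces the "four uses per image" constraint and hard-pair mining); the function name and the toy data are invented for illustration:

```python
import random

def build_epoch_pairs(whale_to_pics, seed=0):
    """Build one epoch of (A, B) pairs: 50% same-whale, 50% different-whale."""
    rng = random.Random(seed)
    matched, unmatched = [], []
    classes = list(whale_to_pics)
    for whale, pics in whale_to_pics.items():
        if len(pics) < 2:
            continue  # cannot form a matched pair from a single image
        a, b = rng.sample(pics, 2)
        matched.append((a, b))  # same whale -> label 1
        other = rng.choice([w for w in classes if w != whale])
        unmatched.append((a, rng.choice(whale_to_pics[other])))  # different whales -> label 0
    return matched, unmatched

toy = {'w_1': ['a.jpg', 'b.jpg'], 'w_2': ['c.jpg', 'd.jpg'], 'w_3': ['e.jpg', 'f.jpg']}
m, u = build_epoch_pairs(toy)
print(len(m), len(u))  # -> 3 3, i.e. the 50/50 split
```

In the real model, the unmatched pairs are not sampled uniformly like this but chosen to be the ones the network currently confuses, as described in point 3 above.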
Contents
1. Finding duplicate images (there are very few duplicates; this part only demonstrates the approach)
2. Image preprocessing (routine operations)
3. Building the Siamese network (the original uses Keras; I rewrote it in PyTorch following the same structure)
4. Generating the training dataset (most of the tricks live in this part; as the summary notes, the training pairs have a large impact on model accuracy)
5. The training loop
6. Predicting on the test set and building the submission file
Part 1: Finding Duplicate Images
This section describes a heuristic for identifying duplicate images. The training and test sets may contain duplicates. Note: the 2018 data had many duplicates, while the 2019 data has few. Some images are exact binary copies, while others were altered: changed contrast and brightness, resizing, masked legends, and so on.
Two images are considered duplicates if either of the following holds:
1. They have the same phash value (for background on phash, see https://blog.csdn.net/zouxy09/article/details/17471401).
2. They satisfy all of the following: their phash values differ by at most 6 bits; they have the same size; and after normalization, the per-pixel mean squared error between the two images is below a given threshold.
Variables: picture_2_hash maps each image to its phash value, and hash_2_picture maps each phash value back to the images that share it. Because different pictures may have the same phash (one hash can correspond to several pictures), we ultimately keep the picture with the highest resolution for each hash.
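The "differ by at most 6 bits" test is simply a Hamming distance on the 64-bit phash; the imagehash library exposes it as `h1 - h2`. A pure-Python sketch of the same idea (the second hex string is made up for illustration):

```python
def hamming(hex1, hex2):
    """Number of differing bits between two same-length hex hash strings."""
    return bin(int(hex1, 16) ^ int(hex2, 16)).count('1')

h1 = 'd26698c3271c757c'
h2 = 'd26698c3271c757d'   # last hex digit differs: c (1100) vs d (1101)
print(hamming(h1, h2))    # -> 1, within the threshold of 6, so candidate duplicates
```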
Step 1: Inspecting the dataset size
if __name__ == '__main__':
    TRAIN_DF = 'data/train.csv'
    SUBMIT_DF = 'data/sample_submission.csv'
    tagged = dict([(p, w) for _, p, w in pd.read_csv(TRAIN_DF).to_records()])
    submit = [p for _, p, _ in pd.read_csv(SUBMIT_DF).to_records()]
    join = list(tagged.keys()) + submit  # all image names from the training and test sets
    print('Training set size: {}'.format(len(tagged)))
    print('Test set size: {}'.format(len(submit)))
    print('Total number of images: {}'.format(len(join)))
    print('First 5 training samples: {}'.format(list(tagged.items())[:5]))
    print('First 5 test samples: {}'.format(submit[:5]))
Training set size: 25361
Test set size: 7960
Total number of images: 33321
First 5 training samples:
[('0000e88ab.jpg', 'w_f48451c'), ('0001f9222.jpg', 'w_c3d896a'),
('00029d126.jpg', 'w_20df2c5'), ('00050a15a.jpg', 'new_whale'),
('0005c1ef8.jpg', 'new_whale')]
First 5 test samples:
['00028a005.jpg', '000dcf7d8.jpg', '000e7c7df.jpg',
'0019c34f4.jpg', '001a4d292.jpg']
Step 2: Getting each image's size
# Determine the size of each image
from os.path import isfile
from PIL import Image
import tqdm

TRAIN = 'data/train/'
TEST = 'data/test/'

def expand_path(path):
    if isfile(TRAIN + path):
        return TRAIN + path
    if isfile(TEST + path):
        return TEST + path
    return path

picture_2_size = {}
for file in tqdm.tqdm(join, desc='all samples of train set and test set'):
    file_path = expand_path(file)
    if isfile(file_path):  # the image exists
        img_size = Image.open(file_path).size
        picture_2_size[file] = img_size
picture_2_size size: 33321
[('0000e88ab.jpg', (1050, 700)), ('0001f9222.jpg', (758, 325)),
('00029d126.jpg', (1050, 497)), ('00050a15a.jpg', (1050, 525)),
('0005c1ef8.jpg', (1050, 525))]
Step 3: Computing each image's phash value
import pickle
import numpy as np
from imagehash import phash

# Two phash values are considered duplicates if, for all associated image pairs:
# 1) They have the same mode and size;
# 2) After normalizing the pixels to zero mean and unit variance, the mean squared error does not exceed 0.1
def match(h1, h2, h2ps):
    for p1 in h2ps[h1]:
        for p2 in h2ps[h2]:
            img_1 = Image.open(expand_path(p1))
            img_2 = Image.open(expand_path(p2))
            if img_1.mode != img_2.mode or img_1.size != img_2.size:
                return False
            img_1 = np.array(img_1)
            norm_1 = (img_1 - np.mean(img_1)) / np.std(img_1)
            img_2 = np.array(img_2)
            norm_2 = (img_2 - np.mean(img_2)) / np.std(img_2)
            diff = ((norm_1 - norm_2) ** 2).mean()
            if diff > 0.1:
                return False
    return True
# Read or generate p2h, a dictionary from image name to image id (picture to phash).
# Perceptual hashing is used to decide whether two images are near-duplicates.
data_path = 'data/p2h.pickle'
if isfile(data_path):
    print('the file picture_2_hash exists')
    with open(data_path, 'rb') as f:
        picture_2_hash = pickle.load(f)
else:
    print('the file picture_2_hash does not exist')
    # Compute phash for each image in the training and test set.
    picture_2_hash = {}  # picture -> phash
    for file in tqdm.tqdm(join, desc='Compute phash for each image in the training and test set.'):
        if isfile(expand_path(file)):
            img = Image.open(expand_path(file))
            picture_2_hash[file] = phash(img)  # compute the image's phash
    # Find all images associated with a given phash value.
    hash_2_picture = {}  # phash -> pictures
    for picture, hash in picture_2_hash.items():  # the same hash may cover several pictures
        if hash not in hash_2_picture:
            hash_2_picture[hash] = []
        if picture not in hash_2_picture[hash]:
            hash_2_picture[hash].append(picture)
    # Find all distinct phash values.
    hash_list = list(hash_2_picture.keys())
    # If the images are close enough, associate the two phash values (this is the slow part: an O(n^2) algorithm)
    hash_2_hash = {}
    for index, h1 in enumerate(tqdm.tqdm(hash_list, desc='match close images')):
        for h2 in hash_list[:index]:
            if h1 - h2 <= 6 and match(h1, h2, hash_2_picture):  # the two hashes are close enough; link them
                s1 = str(h1)
                s2 = str(h2)
                if s1 < s2:
                    s1, s2 = s2, s1
                hash_2_hash[s1] = s2
    # Group together images with equivalent phash, and replace by the string form of the phash (faster and more readable)
    for picture, hash in picture_2_hash.items():
        hash = str(hash)
        if hash in hash_2_hash:
            hash = hash_2_hash[hash]
        picture_2_hash[picture] = hash
    with open(data_path, mode='wb') as f:
        pickle.dump(picture_2_hash, f)
    print('pickle dump finished!')
picture_2_hash size: 33321
[('0000e88ab.jpg', 'd26698c3271c757c'), ('0001f9222.jpg', 'ba8cc231ad489b77'),
('00029d126.jpg', 'bbcad234a52d0f0b'), ('00050a15a.jpg', 'c09ae7dc09f33a29'),
('0005c1ef8.jpg', 'd02f65ba9f74a08a')]
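The hash-merging step above can be illustrated on toy data: for each pair of "close" hashes, the lexicographically larger string is mapped to the smaller (canonical) one, and every picture's hash is then replaced by its canonical form. The two-character hash strings here are invented for brevity (real phash strings are 16 hex digits):

```python
# pairs of phash strings that the match() test judged close
close_pairs = [('bb', 'aa'), ('cc', 'aa')]

hash_2_hash = {}
for h1, h2 in close_pairs:
    s1, s2 = str(h1), str(h2)
    if s1 < s2:
        s1, s2 = s2, s1
    hash_2_hash[s1] = s2  # larger string -> smaller (canonical) string

picture_2_hash = {'x.jpg': 'bb', 'y.jpg': 'cc', 'z.jpg': 'aa'}
for p, h in picture_2_hash.items():
    picture_2_hash[p] = hash_2_hash.get(h, h)

print(picture_2_hash)  # all three pictures now share the canonical hash 'aa'
```

Note that a single level of substitution is applied; longer chains of near-duplicates would need repeated substitution, but in practice the duplicate groups here are tiny.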
Step 4: Mapping each phash value back to its images
hash_2_picture = {}
for picture, hash in picture_2_hash.items():  # the same hash may cover several pictures
    if hash not in hash_2_picture:
        hash_2_picture[hash] = []
    if picture not in hash_2_picture[hash]:
        hash_2_picture[hash].append(picture)
hash_2_picture size: 33317
[('d26698c3271c757c', ['0000e88ab.jpg']),('ba8cc231ad489b77', ['0001f9222.jpg']),
('bbcad234a52d0f0b', ['00029d126.jpg']),('c09ae7dc09f33a29', ['00050a15a.jpg']),
('b1685bb3742cc372', ['01f66ca26.jpg', 'd37179fd1.jpg'])]
Whenever one hash value corresponds to several images, we display the pictures side by side. Apart from a few small details, the images are nearly identical, so it is no surprise that they share the same phash. We therefore need to pick the best image of each group, where "best" means the one with the higher resolution.
Different images sharing the same phash value
# For each image id, select the preferred image: among pictures with similar hashes, return the one with the highest resolution
def prefer(pictures, picture_size):
    if len(pictures) == 1:
        return pictures[0]
    best_p = pictures[0]
    best_s = picture_size[best_p]
    for i in range(1, len(pictures)):
        p = pictures[i]
        s = picture_size[p]
        if s[0] * s[1] > best_s[0] * best_s[1]:  # select the image with the highest resolution
            best_p = p
            best_s = s
    return best_p
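A quick sanity check of the selection logic with a hypothetical size table (the filenames and sizes below are invented; the function is a compact restatement of `prefer` so the snippet is self-contained):

```python
def prefer(pictures, picture_size):
    """Return the picture with the largest area (width * height)."""
    best_p = pictures[0]
    for p in pictures[1:]:
        if picture_size[p][0] * picture_size[p][1] > picture_size[best_p][0] * picture_size[best_p][1]:
            best_p = p
    return best_p

sizes = {'small.jpg': (758, 325), 'big.jpg': (1050, 700)}
print(prefer(['small.jpg', 'big.jpg'], sizes))  # -> big.jpg
```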
Finally, printing the groups of different images that share a phash shows there are not many of them; the 2019 dataset is clearly much cleaner in this respect.
Images: ['01f66ca26.jpg', 'd37179fd1.jpg']
Images: ['579886448.jpg', 'f50529c53.jpg']
Images: ['60a3f2422.jpg', '7f7a63b8a.jpg']
Images: ['b95d73a55.jpg', 'fb3879dc7.jpg']
Refactoring the Code
This completes the duplicate-image search. The original author's code is not very well structured, however, so I created a DataSource class that holds data such as hash_2_picture, picture_2_size, and picture_2_hash, and moved utility functions such as prefer, expand_path, and match into utils.py. The previously generated picture_2_hash and picture_2_size data can be downloaded below, or you can generate and save them yourself.
Link: https://pan.baidu.com/s/1GIhWZlW84idbFgKUOE1Yww  Access code: f2bl
The file structure is:
Project structure
read_data.py
import pandas as pd
import tqdm
from utils import expand_path, match, prefer
from PIL import Image
from os.path import isfile
import pickle
from imagehash import phash
class DataSource():
    def __init__(self, TRAIN_DF, SUB_DF, HASH_PATH, SIZE_PATH):
        super(DataSource, self).__init__()
        # Read the dataset description
        self.tagged = dict([(p, w) for _, p, w in pd.read_csv(TRAIN_DF).to_records()])
        self.submit = [p for _, p, _ in pd.read_csv(SUB_DF).to_records()]
        self.join = list(self.tagged.keys()) + self.submit  # all image names from the training and test sets
        self.picture_2_size = self.get_picture_2_size(data_path=SIZE_PATH)  # the size of every image
        self.picture_2_hash = self.get_picture_2_hash(data_path=HASH_PATH)  # picture -> phash
        self.hash_2_picture = self.get_hash_2_picture()  # phash -> picture

    # The size of each image
    def get_picture_2_size(self, data_path='../data/p2size.pickle'):
        # Determine the size of each image
        if isfile(data_path):
            print('the file picture_2_size exists')
            with open(data_path, 'rb') as f:
                picture_size = pickle.load(f)
        else:
            picture_size = {}
            for file in tqdm.tqdm(self.join, desc='all samples of train set and test set'):
                file_path = expand_path(file)
                if isfile(file_path):  # the image exists
                    img_size = Image.open(file_path).size
                    picture_size[file] = img_size
        return picture_size

    # The phash value of each image
    def get_picture_2_hash(self, data_path='../data/p2h.pickle'):
        # Read or generate p2h, a dictionary from image name to image id (picture to phash).
        # Perceptual hashing is used to decide whether two images are near-duplicates.
        if isfile(data_path):
            print('the file picture_2_hash exists')
            with open(data_path, 'rb') as f:
                picture_2_hash = pickle.load(f)
        else:
            print('the file picture_2_hash does not exist')
            # Compute phash for each image in the training and test set.
            picture_2_hash = {}  # picture -> phash
            for file in tqdm.tqdm(self.join, desc='Compute phash for each image in the training and test set.'):
                if isfile(expand_path(file)):
                    img = Image.open(expand_path(file))
                    picture_2_hash[file] = phash(img)  # compute the image's phash
            # Find all images associated with a given phash value.
            hash_2_picture = {}  # phash -> pictures
            for picture, hash in picture_2_hash.items():  # the same hash may cover several pictures
                if hash not in hash_2_picture:
                    hash_2_picture[hash] = []
                if picture not in hash_2_picture[hash]:
                    hash_2_picture[hash].append(picture)
            # Find all distinct phash values.
            hash_list = list(hash_2_picture.keys())
            # If the images are close enough, associate the two phash values (this is the slow part: an O(n^2) algorithm)
            hash_2_hash = {}
            for index, h1 in enumerate(tqdm.tqdm(hash_list, desc='match close images')):
                for h2 in hash_list[:index]:
                    if h1 - h2 <= 6 and match(h1, h2, hash_2_picture):  # the two hashes are close enough; link them
                        s1 = str(h1)
                        s2 = str(h2)
                        if s1 < s2:
                            s1, s2 = s2, s1
                        hash_2_hash[s1] = s2
            # Group together images with equivalent phash, and replace by the string form of the phash (faster and more readable)
            for picture, hash in picture_2_hash.items():
                hash = str(hash)
                if hash in hash_2_hash:
                    hash = hash_2_hash[hash]
                picture_2_hash[picture] = hash
            with open(data_path, mode='wb') as f:
                pickle.dump(picture_2_hash, f)
            print('pickle dump finished!')
        return picture_2_hash

    # Map each phash back to its pictures; one hash may correspond to several pictures
    def get_hash_2_picture(self):
        # For each image id, determine the list of pictures
        hash_2_picture = {}
        for picture, hash in self.picture_2_hash.items():  # the same hash may cover several pictures
            if hash not in hash_2_picture:
                hash_2_picture[hash] = []
            if picture not in hash_2_picture[hash]:
                hash_2_picture[hash].append(picture)
        # For each image id, select the preferred image
        for h, ps in hash_2_picture.items():
            if len(ps) >= 2:
                print('Images with the same phash:', ps)
        print('-' * 100)
        temp = {}
        for hash, pictures in hash_2_picture.items():
            temp[hash] = prefer(pictures, self.picture_2_size)
        hash_2_picture = temp
        return hash_2_picture
utils.py
from os.path import isfile
from PIL import Image
import numpy as np

TRAIN = 'data/train/'
TEST = 'data/test/'

def expand_path(path):
    if isfile(TRAIN + path):
        return TRAIN + path
    if isfile(TEST + path):
        return TEST + path
    return path

# Two phash values are considered duplicates if, for all associated image pairs:
# 1) They have the same mode and size;
# 2) After normalizing the pixels to zero mean and unit variance, the mean squared error does not exceed 0.1
def match(h1, h2, h2ps):
    for p1 in h2ps[h1]:
        for p2 in h2ps[h2]:
            img_1 = Image.open(expand_path(p1))
            img_2 = Image.open(expand_path(p2))
            if img_1.mode != img_2.mode or img_1.size != img_2.size:
                return False
            img_1 = np.array(img_1)
            norm_1 = (img_1 - np.mean(img_1)) / np.std(img_1)
            img_2 = np.array(img_2)
            norm_2 = (img_2 - np.mean(img_2)) / np.std(img_2)
            diff = ((norm_1 - norm_2) ** 2).mean()
            if diff > 0.1:
                return False
    return True

# For each image id, select the preferred image: among pictures with similar hashes, return the one with the highest resolution
def prefer(pictures, picture_size):
    if len(pictures) == 1:
        return pictures[0]
    best_p = pictures[0]
    best_s = picture_size[best_p]
    for i in range(1, len(pictures)):
        p = pictures[i]
        s = picture_size[p]
        if s[0] * s[1] > best_s[0] * best_s[1]:  # select the image with the highest resolution
            best_p = p
            best_s = s
    return best_p
main.py: building the DataSource
from read_data import DataSource

if __name__ == '__main__':
    TRAIN_DF = 'data/train.csv'
    SUBMIT_DF = 'data/sample_submission.csv'
    HASH_PATH = 'data/p2h.pickle'
    SIZE_PATH = 'data/p2size.pickle'
    data_source = DataSource(TRAIN_DF, SUBMIT_DF, HASH_PATH, SIZE_PATH)
    print('finish')
main.py output
the file picture_2_size exists
the file picture_2_hash exists
Images with the same phash: ['01f66ca26.jpg', 'd37179fd1.jpg']
Images with the same phash: ['579886448.jpg', 'f50529c53.jpg']
Images with the same phash: ['60a3f2422.jpg', '7f7a63b8a.jpg']
Images with the same phash: ['b95d73a55.jpg', 'fb3879dc7.jpg']
--------------------------------------------------------------------
finish
Since this post is already quite long, the next one will cover Part 2: image preprocessing.