以下代码选自链接: https://github.com/jwyang/faster-rcnn.pytorch
trainval_net.py:
from __future__ import absolute_import
from __future__ import absolute_import 的作用:https://blog.csdn.net/caiqiiqi/article/details/51050800
from __future__ import division
from __future__ import division 的作用:https://blog.csdn.net/feixingfei/article/details/7081446
from __future__ import print_function
from __future__ import print_function 的作用:https://blog.csdn.net/xiaotao_1/article/details/79460365
import os
import numpy as np
import argparse
import pprint
import pdb
调试 Python 代码的另一种手段:pdb(https://www.jianshu.com/p/fb5f791fcb18)
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data.sampler import Sampler
from roi_data_layer.roidb import combined_roidb
from roi_data_layer.roibatchLoader import roibatchLoader
from model.utils.config import cfg, cfg_from_file, cfg_from_list, get_output_dir
from model.utils.net_utils import weights_normal_init, save_net, load_net, \
adjust_learning_rate, save_checkpoint, clip_gradient
from model.faster_rcnn.vgg16 import vgg16
from model.faster_rcnn.resnet import resnet
def parse_args():
parser = argparse.ArgumentParser(description='Train a Fast R-CNN network')
parser.add_argument('--dataset', dest='dataset',
help='training dataset',
default='pascal_voc', type=str)
parser.add_argument('--net', dest='net',
help='vgg16, res101',
default='vgg16', type=str)
parser.add_argument('--start_epoch', dest='start_epoch',
help='starting epoch',
default=1, type=int)
# ...... #
args = parser.parse_args()
return args
每一个 add_argument 用三行书写,add_argument 之间没有空行。
class sampler(Sampler):
def __init__(self, train_size, batch_size):
self.num_data = train_size
self.num_batches = int(train_size / batch_size)
self.batch_size = batch_size
self.range = torch.arange(0, batch_size).view(1, batch_size).long()
self.leftover_flag = False
if train_size % batch_size:
self.leftover = torch.arange(self.num_batches*batch_size, train_size, train_size).long()
self.leftover_flag = True
def __iter__(self):
rand_num = torch.randperm(self.num_batches).view(-1, 1) * self.batch_size
self.rand_num = rand_num.expand(self.num_batches, self.batch_size) + self.range
self.rand_num_view = self.rand_num.view(-1)
if self.leftover_flag:
self.rand_num_view = torch.cat((self.rand_num_view, self.leftover), 0)
return iter(self.rand_num_view)
def __len__(self):
return self.num_data
torch.arange 与 torch.range 的区别: https://blog.csdn.net/m0_37586991/article/details/88830026
torch.randperm() : https://blog.csdn.net/u013066730/article/details/93844306
Pytorch expand() 和 repeat() : https://zhuanlan.zhihu.com/p/58109107
view(-1) 的作用是展成一行。
torch.cat() : https://blog.csdn.net/qq_39709535/article/details/80803003
if __name__ == '__main__':
args = parse_args()
print('Called with args:')
print(args)
if args.dataset == "pascal_voc":
args.imdb_name = "voc_2007_trainval"
args.imdbval_name = "voc_2007_test"
args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '20']
elif args.dataset == "pascal_voc_0712":
args.imdb_name = "voc_2007_trainval+voc_2012_trainval"
args.imdbval_name = "voc_2007_test"
args.set_cfgs = ['ANCHOR_SCALES', '[8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '20']
elif args.dataset == "coco":
args.imdb_name = "coco_2014_train+coco_2014_valminusminival"
args.imdbval_name = "coco_2014_minival"
args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '50']
elif args.dataset == "imagenet":
args.imdb_name = "imagenet_train"
args.imdbval_name = "imagenet_val"
args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '30']
elif args.dataset == "vg":
# train sizes: train, smalltrain, minitrain
# train scale: ['150-50-20', '150-50-50', '500-150-80', '750-250-150', '1750-700-450', '1600-400-20']
args.imdb_name = "vg_150-50-50_minitrain"
args.imdbval_name = "vg_150-50-50_minival"
args.set_cfgs = ['ANCHOR_SCALES', '[4, 8, 16, 32]', 'ANCHOR_RATIOS', '[0.5,1,2]', 'MAX_NUM_GT_BOXES', '50']
args.cfg_file = "cfgs/{}_ls.yml".format(args.net) if args.large_scale else "cfgs/{}.yml".format(args.net)
if args.cfg_file is not None:
cfg_from_file(args.cfg_file)
if args.set_cfgs is not None:
cfg_from_list(args.set_cfgs)
print('Using config:')
pprint.pprint(cfg)
np.random.seed(cfg.RNG_SEED)
if torch.cuda.is_available() and not args.cuda:
print("WARNING: You have a CUDA device, so you should probably run with --cuda")
# train set
# -- Note: Use validation set and disable the flipped to enable faster loading.
cfg.TRAIN.USE_FLIPPED = True
cfg.USE_GPU_NMS = args.cuda
imdb, roidb, ratio_list, ratio_index = combined_roidb(args.imdb_name)
train_size = len(roidb)
print('{:d} roidb entries'.format(len(roidb)))
output_dir = args.save_dir + "/" + args.net + "/" + args.dataset
if not os.path.exists(output_dir):
os.makedirs(output_dir)
sampler_batch = sampler(train_size, args.batch_size)
dataset = roibatchLoader(roidb, ratio_list, ratio_index, args.batch_size, \
imdb.num_classes, training=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size,
sampler=sampler_batch, num_workers=args.num_workers)
# initialize the tensor holder here
im_data = torch.FloatTensor(1)
im_info = torch.FloatTensor(1)
num_boxes = torch.LongTensor(1)
gt_boxes = torch.FloatTensor(1)
# ship to cuda
if args.cuda:
im_data = im_data.cuda()
im_info = im_info.cuda()
num_boxes = num_boxes.cuda()
gt_boxes = gt_boxes.cuda()
if args.cuda:
cfg.CUDA = True
# initialize the network here
if args.net == 'vgg16':
fasterRCNN = vgg16(imdb.classes, pretrained=True, class_agnostic=args.class_agnostic)
elif args.net == 'res101':
fasterRCNN = resnet(imdb.classes, 101, pretrained=True, class_agnostic=args.class_agnostic)
elif args.net == 'res50':
fasterRCNN = resnet(imdb.classes, 50, pretrained=True, class_agnostic=args.class_agnostic)
elif args.net == 'res152':
fasterRCNN = resnet(imdb.classes, 152, pretrained=True, class_agnostic=args.class_agnostic)
else:
print("network is not defined")
pdb.set_trace()
fasterRCNN.create_architecture()
lr = cfg.TRAIN.LEARNING_RATE
lr = args.lr
params = []
for key, value in dict(fasterRCNN.named_parameters()).items():
if value.requires_grad:
if 'bias' in key:
params += [{'params':[value], 'lr':lr*(cfg.TRAIN.DOUBLE_BIAS + 1), \
'weight_decay':cfg.TRAIN_BIAS_DECAY and cfg.TRAIN_WEIGHT_DECAY or 0}]
else:
params += [{'params':[value], 'lr':lr, 'weight_decay':cfg.TRAIN_WEIGHT_DECAY}]
if args.optimizer == "adam":
lr = lr * 0.1
optimizer = torch.optim.Adam(params)
elif args.optimizer == "sgd":
optimizer = torch.optim.SGD(params, momentum=cfg.TRAIN.MOMENTUM)
if args.cuda:
fasterRCNN.cuda()
if args.resume:
load_name = os.path.join(output_dir,
'faster_rcnn_{}_{}_{}.pth'.format(args.checksession, args.checkepoch, args.checkpoint))
print("loading checkpoint %s" % (load_name))
checkpoint = torch.load(load_name)
args.session = checkpoint['session']
args.start_epoch = checkpoint['epoch']
fasterRCNN.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr = optimizer.param_groups[0]['lr']
if 'pooling_mode' in checkpoint.keys():
cfg.POOLING_MODE = checkpoint['pooling_mode']
print("loaded checkpoint %s" % (load_name))
if args.mGPUS:
fasterRCNN = nn.DataParallel(fasterRCNN)
iters_per_epoch = int(train_size / args.batch_size)
if args.use_tfboard:
from tensorboardX import SummaryWriter
logger = SummaryWriter("logs")
for epoch in range(args.start_epoch, args.max_epochs+1):
# setting to train mode
fasterRCNN.train()
loss_temp = 0
start = time.time()
if epoch % (args.lr_decay_step + 1) == 0:
adjust_learning_rate(optimizer, args.lr_decay_gamma)
lr *= args.lr_decay_gamma
data_iter = iter(dataloader)
for step in range(iters_per_epoch):
data = next(data_iter)
im_data.data.resize_(data[0].size()).copy_(data[0])
im_info.data.resize_(data[1].size()).copy_(data[1])
gt_boxes.data.resize_(data[2].size()).copy_(data[2])
num_boxes.data.resize_(data[3].size()).copy_(data[3])
fasterRCNN.zero_grad()
rois, cls_prob, bbox_pred, \
rpn_loss_cls, rpn_loss_box, \
RCNN_loss_cls, RCNN_loss_bbox, \
rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
loss = rpn_loss_cls.mean() + rpn_loss_box.mean() \
+ RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()
loss_temp += loss.item()
# backward
optimizer.zero_grad()
loss.backward()
if args.net == "vgg16":
clip_gradient(fasterRCNN, 10.)
optimizer.step()
if step % args.disp_interval == 0:
end = time.time()
if step > 0:
loss_temp /= (args.disp_interval + 1)
if args.mGPUS:
loss_rpn_cls = rpn_loss_cls.mean().item()
loss_rpn_box = rpn_loss_box.mean().item()
loss_rcnn_cls = RCNN_loss_cls.mean().item()
loss_rcnn_box = RCNN_loss_bbox.mean().item()
fg_cnt = torch.sum(rois_label.data.ne(0))
bg_cnt = rois_label.data.numel() - fg_cnt
else:
loss_rpn_cls = rpn_loss_cls.item()
loss_rpn_box = rpn_loss_box.item()
loss_rcnn_cls = RCNN_loss_cls.item()
loss_rcnn_box = RCNN_loss_bbox.item()
fg_cnt = torch.sum(rois_label.data.ne(0))
bg_cnt = rois_label.data.numel() - fg_cnt
print("[session %d][epoch %2d][iter %4d/%4d] loss: %.4f, lr: %.2e" \
% (args.session, epoch, step, iters_per_epoch, loss_temp, lr))
print("\t\t\tfg/bg=(%d/%d), time cost: %f" % (fg_cnt, bg_cnt, end-start))
print("\t\t\trpn_cls: %.4f, rpn_box: %.4f, rcnn_cls: %.4f, rcnn_box: %.4f" \
% (loss_rpn_cls, loss_rpn_box, loss_rcnn_cls, loss_rcnn_box))
if args.use_tfboard:
info = {
'loss': loss_temp,
'loss_rpn_cls': loss_rpn_cls,
'loss_rpn_box': loss_rpn_box,
'loss_rcnn_cls': loss_rcnn_cls,
'loss_rcnn_box': loss_rcnn_box
}
logger.add_scalars("logs_s_{}/losses".format(args.session), info, (epoch-1)*iters_per_epoch+step)
loss_temp = 0
start = time.time()
save_name = os.path.join(output_dir, 'faster_rcnn_{}_{}_{}.pth'.format(args.session, epoch, step))
save_checkpoint({
'session': args.session,
'epoch': epoch + 1,
'model': fasterRCNN.module.state_dict() if args.mGPUS else fasterRCNN.state_dict(),
'optimizer': optimizer.state_dict(),
'pooling_mode': cfg.POOLING_MODE,
'class_agnostic': args.class_agnostic,
}, save_name)
print('save model: {}'.format(save_name))
if args.use_tfboard:
logger.close()






网友评论