main.py
暂时有几个疑问:
1.参数中split_num到底指什么
2.line 48参数output_name是什么
3.mainip是做什么的
1. 先调用95行if __name__ == '__main__':main主函数
-
line 96 ~ line 156规定参数并调用 -
line 160 ~ line 161-
os.environ[“CUDA_DEVICE_ORDER”] = “PCI_BUS_ID”按照PCI_BUS_ID顺序从0开始排列GPU设备 -
os.environ[“CUDA_VISIBLE_DEVICES”] = “0”设置当前使用的GPU设备仅为0号设备 设备名称为'/gpu:0'
os.environ[“CUDA_VISIBLE_DEVICES”] = “1”设置当前使用的GPU设备仅为1号设备 设备名称为'/gpu:0'
os.environ[“CUDA_VISIBLE_DEVICES”] = “0,1”设置当前使用的GPU设备为0,1号两个设备,名称依次为'/gpu:0'、'/gpu:1'
os.environ[“CUDA_VISIBLE_DEVICES”] = “1,0”设置当前使用的GPU设备为1,0号两个设备,名称依次为'/gpu:0'、'/gpu:1'。表示优先使用1号设备,然后使用0号设备
-
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Train on Medical Data')
parser.add_argument('--data_root_dir', type=str, required=True,
help='The root directory for your data.')
parser.add_argument('--weights_path', type=str, default='',
help='/path/to/trained_model.hdf5 from root. Set to "" for none.')
parser.add_argument('--split_num', type=int, default=0,
help='Which training split to train/test on.')
parser.add_argument('--net', type=str.lower, default='segcapsr3',
choices=['segcapsr3', 'segcapsr1', 'segcapsbasic', 'unet', 'tiramisu'],
help='Choose your network.')
parser.add_argument('--train', type=int, default=1, choices=[0,1],
help='Set to 1 to enable training.')
parser.add_argument('--test', type=int, default=1, choices=[0,1],
help='Set to 1 to enable testing.')
parser.add_argument('--manip', type=int, default=1, choices=[0,1],
help='Set to 1 to enable manipulation.')
parser.add_argument('--shuffle_data', type=int, default=1, choices=[0,1],
help='Whether or not to shuffle the training data (both per epoch and in slice order.')
parser.add_argument('--aug_data', type=int, default=1, choices=[0,1],
help='Whether or not to use data augmentation during training.')
parser.add_argument('--loss', type=str.lower, default='w_bce', choices=['bce', 'w_bce', 'dice', 'mar', 'w_mar'],
help='Which loss to use. "bce" and "w_bce": unweighted and weighted binary cross entropy'
'"dice": soft dice coefficient, "mar" and "w_mar": unweighted and weighted margin loss.')
parser.add_argument('--batch_size', type=int, default=1,
help='Batch size for training/testing.')
parser.add_argument('--initial_lr', type=float, default=0.0001,
help='Initial learning rate for Adam.')
parser.add_argument('--recon_wei', type=float, default=.1, #131.072,
help="If using capsnet: The coefficient (weighting) for the loss of decoder")
parser.add_argument('--slices', type=int, default=1,
help='Number of slices to include for training/testing.')
# 形成 3D 样本进行训练时要跳过的片数。输入-1,用于随机子取样,最多占总切片的5%。
parser.add_argument('--subsamp', type=int, default=-1,
help='Number of slices to skip when forming 3D samples for training. Enter -1 for random '
'subsampling up to 5% of total slices.')
parser.add_argument('--stride', type=int, default=1,
help='Number of slices to move when generating the next sample.')
# verbose是日志显示 0:每张切片都显示训练过程;1:每个iteration显示训练过程;2:每个epoch显示训练过程
parser.add_argument('--verbose', type=int, default=1, choices=[0, 1, 2],
help='Set the verbose value for training. 0: Silent, 1: per iteration, 2: per epoch.')
parser.add_argument('--save_raw', type=int, default=1, choices=[0,1],
help='Enter 0 to not save, 1 to save.')
parser.add_argument('--save_seg', type=int, default=1, choices=[0,1],
help='Enter 0 to not save, 1 to save.')
parser.add_argument('--save_prefix', type=str, default='',
help='Prefix to append to saved CSV.')
parser.add_argument('--thresh_level', type=float, default=0.,
help='Enter 0.0 for otsu thresholding, else set value')
parser.add_argument('--compute_dice', type=int, default=1,
help='0 or 1')
parser.add_argument('--compute_jaccard', type=int, default=1,
help='0 or 1')
parser.add_argument('--compute_assd', type=int, default=0,
help='0 or 1')
# 输入"-2 "表示只有CPU,"-1 "表示所有可用的GPU,或者输入一个以逗号分隔的GPU ID数字列表,例如:"0,1,4"。
parser.add_argument('--which_gpus', type=str, default="0",
help='Enter "-2" for CPU only, "-1" for all GPUs available, '
'or a comma separated list of GPU id numbers ex: "0,1,4".')
parser.add_argument('--gpus', type=int, default=-1,
help='Number of GPUs you have available for training. '
'If entering specific GPU ids under the --which_gpus arg or if using CPU, '
'then this number will be inferred, else this argument must be included.')
arguments = parser.parse_args()
if arguments.which_gpus == -2:
environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
environ["CUDA_VISIBLE_DEVICES"] = ""
elif arguments.which_gpus == '-1':
assert (arguments.gpus != -1), 'Use all GPUs option selected under --which_gpus, with this option the user MUST ' \
'specify the number of GPUs available with the --gpus option.'
else:
arguments.gpus = len(arguments.which_gpus.split(','))
environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
environ["CUDA_VISIBLE_DEVICES"] = str(arguments.which_gpus)
if arguments.gpus > 1:
assert arguments.batch_size >= arguments.gpus, 'Error: Must have at least as many items per batch as GPUs ' \
'for multi-GPU training. For model parallelism instead of ' \
'data parallelism, modifications must be made to the code.'
main(arguments)
2.调用def main(args)函数将相应参数传给此函数
-
line 33 ~ line 39try/except
python中try/except/else/finally语句的完整格式如下所示:
try:
Normal execution block
except A:
Exception A handle
except B:
Exception B handle
except:
Other exception handle
else:
if no exception,get here
finally:
print(“finally”)
说明:
正常执行的程序在try下面的Normal execution block执行块中执行,在执行过程中如果发生了异常,则中断当前在Normal execution block中的执行,跳转到对应的异常处理块中开始执行;
python从第一个except X处开始查找,如果找到了对应的exception类型则进入其提供的exception handle中进行处理,如果没有找到则直接进入except块处进行处理。except块是可选项,如果没有提供,该exception将会被提交给python进行默认处理,处理方式则是终止应用程序并打印提示信息;
如果在Normal execution block执行块中执行过程中没有发生任何异常,则在执行完Normal execution block后会进入else执行块中(如果存在的话)执行。
无论是否发生了异常,只要提供了finally语句,以上try/except/else/finally代码块执行的最后一步总是执行finally所对应的代码块。
需要注意的是:
1.在上面所示的完整语句中try/except/else/finally所出现的顺序必须是try –> except X –> except –> else –> finally,即所有的except必须在else和finally之前,else(如果有的话)必须在finally之前,而except X必须在except之前。否则会出现语法错误。
2.对于上面所展示的try/except完整格式而言,else和finally都是可选的,而不是必须的,但是如果存在的话else必须在finally之前,finally(如果存在的话)必须在整个语句的最后位置。
3.在上面的完整语句中,else语句的存在必须以except X或者except语句为前提,如果在没有except语句的try block中使用else语句会引发语法错误。也就是说else不能与try/finally配合使用。
4.except的使用要非常小心,慎用。
举个例子:
try:
a = int(input("输入被除数:"))
b = int(input("输入除数:"))
c = a / b
print("您输入的两个数相除的结果是:", c )
except (ValueError, ArithmeticError):
print("程序发生了数字格式异常、算术异常之一")
except :
print("未知异常")
print("程序继续运行")
程序运行结果为:
输入被除数:a
程序发生了数字格式异常、算术异常之一
程序继续运行
上面程序中,第 6 行代码使用了(ValueError, ArithmeticError)来指定所捕获的异常类型,这就表明该 except 块可以同时捕获这 2 种类型的异常;第 8 行代码只有 except 关键字,并未指定具体要捕获的异常类型,这种省略异常类的 except 语句也是合法的,它表示可捕获所有类型的异常,一般会作为异常捕获的最后一个 except 块。
除此之外,由于 try 块中引发了异常,并被 except 块成功捕获,因此程序才可以继续执行,才有了“程序继续运行”的输出结果。
line 55 ~ line 77一直在创建相应的文件夹
line 79 ~ line 92看是运行train 还是test 还是mainip
def main(args):
# 确保train、test、manip 三者没有同时全部关闭
assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'
# 加载train、val、test数据 load_data,split_data都在load_3D_data.py中
try:
train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
except:
# 如果没有找到,创建训练和测试分割
split_data(args.data_root_dir, num_splits=4)
train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
# 从第一个图像中获取图像属性。假设它们都是一样的。
img_shape = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))).shape
net_input_shape = (img_shape[1], img_shape[2], args.slices)
# 创建用于 train / val / test 的模型 create_model来自model_helper.py
model_list = create_model(args=args, input_shape=net_input_shape)
# print_summary 用于打印模型结构,现在好像是不用这个函数了,会报错,用pytorch改写的时候记得填上这一功能,用以纠错
print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])
args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
'_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
'_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
'_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
'_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
args.time = time
args.check_dir = join(args.data_root_dir,'saved_models', args.net)
try:
makedirs(args.check_dir)
except:
pass
args.log_dir = join(args.data_root_dir,'logs', args.net)
try:
makedirs(args.log_dir)
except:
pass
args.tf_log_dir = join(args.log_dir, 'tf_logs')
try:
makedirs(args.tf_log_dir)
except:
pass
args.output_dir = join(args.data_root_dir, 'plots', args.net)
try:
makedirs(args.output_dir)
except:
pass
if args.train:
from train import train
# Run training
train(args, train_list, val_list, model_list[0], net_input_shape)
if args.test:
from test import test
# Run testing
test(args, test_list, model_list, net_input_shape)
if args.manip:
from manip import manip
# Run manipulation of segcaps
manip(args, test_list, model_list, net_input_shape)











网友评论