美文网首页
[代码解读]SegCaps-main.py

[代码解读]SegCaps-main.py

作者: zelda2333 | 来源:发表于2020-12-27 18:38 被阅读0次

main.py

暂时有几个疑问:
1.参数中split_num到底指什么
2.line 48参数output_name是什么
3.mainip是做什么的

1. 先调用95行if __name__ == '__main__':main主函数

  • line 96 ~ line 156规定参数并调用

  • line 160 ~ line 161

    • os.environ[“CUDA_DEVICE_ORDER”] = “PCI_BUS_ID” 按照PCI_BUS_ID顺序从0开始排列GPU设备
    • os.environ[“CUDA_VISIBLE_DEVICES”] = “0”设置当前使用的GPU设备仅为0号设备 设备名称为'/gpu:0'
      os.environ[“CUDA_VISIBLE_DEVICES”] = “1”设置当前使用的GPU设备仅为1号设备 设备名称为'/gpu:0'
      os.environ[“CUDA_VISIBLE_DEVICES”] = “0,1”设置当前使用的GPU设备为0,1号两个设备,名称依次为'/gpu:0'、'/gpu:1'
      os.environ[“CUDA_VISIBLE_DEVICES”] = “1,0”设置当前使用的GPU设备为1,0号两个设备,名称依次为'/gpu:0'、'/gpu:1'。表示优先使用1号设备,然后使用0号设备
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train on Medical Data')
    parser.add_argument('--data_root_dir', type=str, required=True,
                        help='The root directory for your data.')
    parser.add_argument('--weights_path', type=str, default='',
                        help='/path/to/trained_model.hdf5 from root. Set to "" for none.')
    parser.add_argument('--split_num', type=int, default=0,
                        help='Which training split to train/test on.')
    parser.add_argument('--net', type=str.lower, default='segcapsr3',
                        choices=['segcapsr3', 'segcapsr1', 'segcapsbasic', 'unet', 'tiramisu'],
                        help='Choose your network.')
    parser.add_argument('--train', type=int, default=1, choices=[0,1],
                        help='Set to 1 to enable training.')
    parser.add_argument('--test', type=int, default=1, choices=[0,1],
                        help='Set to 1 to enable testing.')
    parser.add_argument('--manip', type=int, default=1, choices=[0,1],
                        help='Set to 1 to enable manipulation.')
    parser.add_argument('--shuffle_data', type=int, default=1, choices=[0,1],
                        help='Whether or not to shuffle the training data (both per epoch and in slice order.')
    parser.add_argument('--aug_data', type=int, default=1, choices=[0,1],
                        help='Whether or not to use data augmentation during training.')
    parser.add_argument('--loss', type=str.lower, default='w_bce', choices=['bce', 'w_bce', 'dice', 'mar', 'w_mar'],
                        help='Which loss to use. "bce" and "w_bce": unweighted and weighted binary cross entropy'
                             '"dice": soft dice coefficient, "mar" and "w_mar": unweighted and weighted margin loss.')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='Batch size for training/testing.')
    parser.add_argument('--initial_lr', type=float, default=0.0001,
                        help='Initial learning rate for Adam.')
    parser.add_argument('--recon_wei', type=float, default=.1, #131.072,
                        help="If using capsnet: The coefficient (weighting) for the loss of decoder")
    parser.add_argument('--slices', type=int, default=1,
                        help='Number of slices to include for training/testing.')
    # 形成 3D 样本进行训练时要跳过的片数。输入-1,用于随机子取样,最多占总切片的5%。
    parser.add_argument('--subsamp', type=int, default=-1,
                        help='Number of slices to skip when forming 3D samples for training. Enter -1 for random '
                             'subsampling up to 5% of total slices.')
    parser.add_argument('--stride', type=int, default=1,
                        help='Number of slices to move when generating the next sample.')
    # verbose是日志显示 0:每张切片都显示训练过程;1:每个iteration显示训练过程;2:每个epoch显示训练过程
    parser.add_argument('--verbose', type=int, default=1, choices=[0, 1, 2],
                        help='Set the verbose value for training. 0: Silent, 1: per iteration, 2: per epoch.')
    parser.add_argument('--save_raw', type=int, default=1, choices=[0,1],
                        help='Enter 0 to not save, 1 to save.')
    parser.add_argument('--save_seg', type=int, default=1, choices=[0,1],
                        help='Enter 0 to not save, 1 to save.')
    parser.add_argument('--save_prefix', type=str, default='',
                        help='Prefix to append to saved CSV.')
    parser.add_argument('--thresh_level', type=float, default=0.,
                        help='Enter 0.0 for otsu thresholding, else set value')
    parser.add_argument('--compute_dice', type=int, default=1,
                        help='0 or 1')
    parser.add_argument('--compute_jaccard', type=int, default=1,
                        help='0 or 1')
    parser.add_argument('--compute_assd', type=int, default=0,
                        help='0 or 1')
    # 输入"-2 "表示只有CPU,"-1 "表示所有可用的GPU,或者输入一个以逗号分隔的GPU ID数字列表,例如:"0,1,4"。
    parser.add_argument('--which_gpus', type=str, default="0",
                        help='Enter "-2" for CPU only, "-1" for all GPUs available, '
                             'or a comma separated list of GPU id numbers ex: "0,1,4".')
    parser.add_argument('--gpus', type=int, default=-1,
                        help='Number of GPUs you have available for training. '
                             'If entering specific GPU ids under the --which_gpus arg or if using CPU, '
                             'then this number will be inferred, else this argument must be included.')

    arguments = parser.parse_args()


    if arguments.which_gpus == -2:
        environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        environ["CUDA_VISIBLE_DEVICES"] = ""
    elif arguments.which_gpus == '-1':
        assert (arguments.gpus != -1), 'Use all GPUs option selected under --which_gpus, with this option the user MUST ' \
                                  'specify the number of GPUs available with the --gpus option.'
    else:
        arguments.gpus = len(arguments.which_gpus.split(','))
        environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        environ["CUDA_VISIBLE_DEVICES"] = str(arguments.which_gpus)

    if arguments.gpus > 1:
        assert arguments.batch_size >= arguments.gpus, 'Error: Must have at least as many items per batch as GPUs ' \
                                                       'for multi-GPU training. For model parallelism instead of ' \
                                                       'data parallelism, modifications must be made to the code.'

    main(arguments)

2.调用def main(args)函数将相应参数传给此函数

  • line 33 ~ line 39 try/except
    python中try/except/else/finally语句的完整格式如下所示:
try:
      Normal execution block
except A:
      Exception A handle
except B:
      Exception B handle
except:
      Other exception handle
else:
      if no exception,get here
finally:
      print(“finally”)

说明:
正常执行的程序在try下面的Normal execution block执行块中执行,在执行过程中如果发生了异常,则中断当前在Normal execution block中的执行,跳转到对应的异常处理块中开始执行;
python从第一个except X处开始查找,如果找到了对应的exception类型则进入其提供的exception handle中进行处理,如果没有找到则直接进入except块处进行处理。except块是可选项,如果没有提供,该exception将会被提交给python进行默认处理,处理方式则是终止应用程序并打印提示信息;
如果在Normal execution block执行块中执行过程中没有发生任何异常,则在执行完Normal execution block后会进入else执行块中(如果存在的话)执行
无论是否发生了异常,只要提供了finally语句,以上try/except/else/finally代码块执行的最后一步总是执行finally所对应的代码块

需要注意的是:
1.在上面所示的完整语句中try/except/else/finally所出现的顺序必须是try –> except X –> except –> else –> finally,即所有的except必须在else和finally之前,else(如果有的话)必须在finally之前,而except X必须在except之前。否则会出现语法错误。
2.对于上面所展示的try/except完整格式而言,else和finally都是可选的,而不是必须的,但是如果存在的话else必须在finally之前,finally(如果存在的话)必须在整个语句的最后位置。
3.在上面的完整语句中,else语句的存在必须以except X或者except语句为前提,如果在没有except语句的try block中使用else语句会引发语法错误。也就是说else不能与try/finally配合使用。
4.except的使用要非常小心,慎用。

举个例子:

try:
    a = int(input("输入被除数:"))
    b = int(input("输入除数:"))
    c = a / b
    print("您输入的两个数相除的结果是:", c )
except (ValueError, ArithmeticError):
    print("程序发生了数字格式异常、算术异常之一")
except :
    print("未知异常")
print("程序继续运行")

程序运行结果为:

输入被除数:a
程序发生了数字格式异常、算术异常之一
程序继续运行

上面程序中,第 6 行代码使用了(ValueError, ArithmeticError)来指定所捕获的异常类型,这就表明该 except 块可以同时捕获这 2 种类型的异常;第 8 行代码只有 except 关键字,并未指定具体要捕获的异常类型,这种省略异常类的 except 语句也是合法的,它表示可捕获所有类型的异常,一般会作为异常捕获的最后一个 except 块。

除此之外,由于 try 块中引发了异常,并被 except 块成功捕获,因此程序才可以继续执行,才有了“程序继续运行”的输出结果。


line 55 ~ line 77一直在创建相应的文件夹
line 79 ~ line 92看是运行train 还是test 还是mainip

def main(args):
    # 确保train、test、manip 三者没有同时全部关闭
    assert (args.train or args.test or args.manip), 'Cannot have train, test, and manip all set to 0, Nothing to do.'

    # 加载train、val、test数据  load_data,split_data都在load_3D_data.py中
    try:
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)
    except:
        # 如果没有找到,创建训练和测试分割
        split_data(args.data_root_dir, num_splits=4)
        train_list, val_list, test_list = load_data(args.data_root_dir, args.split_num)

    # 从第一个图像中获取图像属性。假设它们都是一样的。
    img_shape = sitk.GetArrayFromImage(sitk.ReadImage(join(args.data_root_dir, 'imgs', train_list[0][0]))).shape
    net_input_shape = (img_shape[1], img_shape[2], args.slices)

    # 创建用于 train / val / test 的模型 create_model来自model_helper.py
    model_list = create_model(args=args, input_shape=net_input_shape)
    # print_summary 用于打印模型结构,现在好像是不用这个函数了,会报错,用pytorch改写的时候记得填上这一功能,用以纠错
    print_summary(model=model_list[0], positions=[.38, .65, .75, 1.])

    args.output_name = 'split-' + str(args.split_num) + '_batch-' + str(args.batch_size) + \
                       '_shuff-' + str(args.shuffle_data) + '_aug-' + str(args.aug_data) + \
                       '_loss-' + str(args.loss) + '_slic-' + str(args.slices) + \
                       '_sub-' + str(args.subsamp) + '_strid-' + str(args.stride) + \
                       '_lr-' + str(args.initial_lr) + '_recon-' + str(args.recon_wei)
    args.time = time

    args.check_dir = join(args.data_root_dir,'saved_models', args.net)
    try:
        makedirs(args.check_dir)
    except:
        pass

    args.log_dir = join(args.data_root_dir,'logs', args.net)
    try:
        makedirs(args.log_dir)
    except:
        pass

    args.tf_log_dir = join(args.log_dir, 'tf_logs')
    try:
        makedirs(args.tf_log_dir)
    except:
        pass

    args.output_dir = join(args.data_root_dir, 'plots', args.net)
    try:
        makedirs(args.output_dir)
    except:
        pass

    if args.train:
        from train import train
        # Run training
        train(args, train_list, val_list, model_list[0], net_input_shape)

    if args.test:
        from test import test
        # Run testing
        test(args, test_list, model_list, net_input_shape)

    if args.manip:
        from manip import manip
        # Run manipulation of segcaps
        manip(args, test_list, model_list, net_input_shape)

参考链接:
try/except介绍
os.environ
Python try except异常处理详解(入门必读)

相关文章

  • [代码解读]SegCaps-main.py

    main.py 暂时有几个疑问:1.参数中split_num到底指什么2.line 48参数output_name...

  • 代码解读

    DeltaresHydro类 public函数 DeltaresHydro(int argc, char* arg...

  • 异或 算法

    解读:网上的一段代码

  • Unet 论文解读 代码解读

    论文地址:http://www.arxiv.org/pdf/1505.04597.pdf 论文解读 Archite...

  • LayoutInflater源码浅读

    LayoutInflater代码解读 本文主要解读LayoutInflater加载XML的过程。 我们经常在Lis...

  • vdsm代码解读

    代码目录:

  • AdvSemiSeg代码解读

    model = Res_Deeplab(num_classes=args.num_classes) Res_Dee...

  • ppp代码解读

    1.相关rfc rfc1661 2.pppd 2.1.数据结构 protocols,所在文件main.c中初始化,...

  • Masonry代码解读

    一.前言 Masonry是非常有名的布局框架,今天我们就分析它的具体实现。通读了一边源码,写的非常的好,有很多值得...

  • VINS代码解读

    VINS_estimator 摘抄我们初始化的原因是单目惯性紧耦合系统是一个非线性程度很高的系统,首先单目是无法获...

网友评论

      本文标题:[代码解读]SegCaps-main.py

      本文链接:https://www.haomeiwen.com/subject/tpronktx.html