From 48f68016d402d0c84b832d5b7cc176293b98d9e0 Mon Sep 17 00:00:00 2001 From: Benny <775410794@qq.com> Date: Sun, 21 Mar 2021 12:04:31 +0800 Subject: [PATCH] fixed the bug for https://github.com/yanx27/Pointnet_Pointnet2_pytorch/issues/78 --- train_classification.py | 2 +- train_partseg.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/train_classification.py b/train_classification.py index 0cdd3e715..c58195056 100644 --- a/train_classification.py +++ b/train_classification.py @@ -114,7 +114,7 @@ def log_string(str): train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data) test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data) - trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10) + trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True) testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10) '''MODEL LOADING''' diff --git a/train_partseg.py b/train_partseg.py index 6924fa857..9621de70a 100644 --- a/train_partseg.py +++ b/train_partseg.py @@ -40,7 +40,7 @@ def to_categorical(y, num_classes): def parse_args(): parser = argparse.ArgumentParser('Model') - parser.add_argument('--model', type=str, default='pointnet2_part_seg_msg', help='model name') + parser.add_argument('--model', type=str, default='pointnet_part_seg', help='model name') parser.add_argument('--batch_size', type=int, default=16, help='batch Size during training') parser.add_argument('--epoch', default=251, type=int, help='epoch to run') parser.add_argument('--learning_rate', default=0.001, type=float, help='initial learning rate') @@ -95,10 +95,9 @@ def log_string(str): root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/' TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='trainval', normal_channel=args.normal) - trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, - num_workers=4) + trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True) TEST_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='test', normal_channel=args.normal) - testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4) + testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10) log_string("The number of training data is: %d" % len(TRAIN_DATASET)) log_string("The number of test data is: %d" % len(TEST_DATASET))