Skip to content

Commit

Permalink
fixed the bug for yanx27#78
Browse files Browse the repository at this point in the history
  • Loading branch information
yanx27 committed Mar 21, 2021
1 parent 63f7183 commit 48f6801
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
2 changes: 1 addition & 1 deletion train_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def log_string(str):

train_dataset = ModelNetDataLoader(root=data_path, args=args, split='train', process_data=args.process_data)
test_dataset = ModelNetDataLoader(root=data_path, args=args, split='test', process_data=args.process_data)
trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10)
trainDataLoader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
testDataLoader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False, num_workers=10)

'''MODEL LOADING'''
Expand Down
7 changes: 3 additions & 4 deletions train_partseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def to_categorical(y, num_classes):

def parse_args():
parser = argparse.ArgumentParser('Model')
parser.add_argument('--model', type=str, default='pointnet2_part_seg_msg', help='model name')
parser.add_argument('--model', type=str, default='pointnet_part_seg', help='model name')
parser.add_argument('--batch_size', type=int, default=16, help='batch Size during training')
parser.add_argument('--epoch', default=251, type=int, help='epoch to run')
parser.add_argument('--learning_rate', default=0.001, type=float, help='initial learning rate')
Expand Down Expand Up @@ -95,10 +95,9 @@ def log_string(str):
root = 'data/shapenetcore_partanno_segmentation_benchmark_v0_normal/'

TRAIN_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='trainval', normal_channel=args.normal)
trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True,
num_workers=4)
trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size, shuffle=True, num_workers=10, drop_last=True)
TEST_DATASET = PartNormalDataset(root=root, npoints=args.npoint, split='test', normal_channel=args.normal)
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=4)
testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size, shuffle=False, num_workers=10)
log_string("The number of training data is: %d" % len(TRAIN_DATASET))
log_string("The number of test data is: %d" % len(TEST_DATASET))

Expand Down

0 comments on commit 48f6801

Please sign in to comment.