From 00f36756b5ab3a2cef88aa90da2d0a3342f01733 Mon Sep 17 00:00:00 2001 From: Benny <775410794@qq.com> Date: Sun, 21 Mar 2021 13:32:48 +0800 Subject: [PATCH] change inplace=True in ReLU --- train_classification.py | 7 +++++++ train_partseg.py | 13 ++++++++++--- train_semseg.py | 5 +++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/train_classification.py b/train_classification.py index c58195056..229f49db4 100644 --- a/train_classification.py +++ b/train_classification.py @@ -43,6 +43,12 @@ def parse_args(): return parser.parse_args() +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True + + def test(model, loader, num_class=40): mean_correct = [] class_acc = np.zeros((num_class, 3)) @@ -126,6 +132,7 @@ def log_string(str): classifier = model.get_model(num_class, normal_channel=args.use_normals) criterion = model.get_loss() + classifier.apply(inplace_relu) if not args.use_cpu: classifier = classifier.cuda() diff --git a/train_partseg.py b/train_partseg.py index 9621de70a..bb85e434e 100644 --- a/train_partseg.py +++ b/train_partseg.py @@ -4,18 +4,19 @@ """ import argparse import os -from data_utils.ShapeNetDataLoader import PartNormalDataset import torch import datetime import logging -from pathlib import Path import sys import importlib import shutil -from tqdm import tqdm import provider import numpy as np +from pathlib import Path +from tqdm import tqdm +from data_utils.ShapeNetDataLoader import PartNormalDataset + BASE_DIR = os.path.dirname(os.path.abspath(__file__)) ROOT_DIR = BASE_DIR sys.path.append(os.path.join(ROOT_DIR, 'models')) @@ -30,6 +31,11 @@ seg_label_to_cat[label] = cat +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True + def to_categorical(y, num_classes): """ 1-hot encodes a tensor """ new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] @@ -111,6 +117,7 @@ def log_string(str): classifier = MODEL.get_model(num_part, normal_channel=args.normal).cuda() criterion = MODEL.get_loss().cuda() + classifier.apply(inplace_relu) def weights_init(m): classname = m.__class__.__name__ diff --git a/train_semseg.py b/train_semseg.py index 0f8c10f6e..c3bc91812 100644 --- a/train_semseg.py +++ b/train_semseg.py @@ -29,6 +29,10 @@ for i, cat in enumerate(seg_classes.keys()): seg_label_to_cat[i] = cat +def inplace_relu(m): + classname = m.__class__.__name__ + if classname.find('ReLU') != -1: + m.inplace=True def parse_args(): parser = argparse.ArgumentParser('Model') @@ -111,6 +115,7 @@ def log_string(str): classifier = MODEL.get_model(NUM_CLASSES).cuda() criterion = MODEL.get_loss().cuda() + classifier.apply(inplace_relu) def weights_init(m): classname = m.__class__.__name__