diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 3aaaa7b9..a424f3d5 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -167,10 +167,7 @@ def step(self, batch, stage, mode): def train_recommender(self): if hasattr(self.rec_model, 'bert'): - if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': - bert_param = list(self.rec_model.bert.named_parameters()) - else: - bert_param = list(self.rec_model.module.bert.named_parameters()) + bert_param = list(self.rec_model.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] else: bert_param = []