From 1a5acf39d84fb4aedab3cb63c0bce16a1a02aa46 Mon Sep 17 00:00:00 2001 From: enchantee00 Date: Fri, 12 Apr 2024 15:51:33 +0900 Subject: [PATCH 1/3] fix bugs --- crslab/system/tgredial.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 3aaaa7b9..d6d84731 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -166,11 +166,12 @@ def step(self, batch, stage, mode): raise def train_recommender(self): + breakpoint() if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) else: - bert_param = list(self.rec_model.module.bert.named_parameters()) + bert_param = list(self.rec_model.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] else: bert_param = [] From 84630f848ec9f99dea79e9d7b1482f02f95ae372 Mon Sep 17 00:00:00 2001 From: enchantee00 Date: Fri, 12 Apr 2024 15:57:35 +0900 Subject: [PATCH 2/3] fix bugs --- crslab/system/tgredial.py | 1 - 1 file changed, 1 deletion(-) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index d6d84731..620bb609 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -166,7 +166,6 @@ def step(self, batch, stage, mode): raise def train_recommender(self): - breakpoint() if hasattr(self.rec_model, 'bert'): if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': bert_param = list(self.rec_model.bert.named_parameters()) From 982ce9108e4ea24632f41c89ebddb6064e3350a8 Mon Sep 17 00:00:00 2001 From: enchantee00 Date: Fri, 12 Apr 2024 16:13:38 +0900 Subject: [PATCH 3/3] remove duplicate code --- crslab/system/tgredial.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/crslab/system/tgredial.py b/crslab/system/tgredial.py index 620bb609..a424f3d5 100644 --- a/crslab/system/tgredial.py +++ b/crslab/system/tgredial.py @@ -167,10 +167,7 @@ def step(self, batch, stage, mode): def train_recommender(self): if hasattr(self.rec_model, 'bert'): - if os.environ["CUDA_VISIBLE_DEVICES"] == '-1': - bert_param = list(self.rec_model.bert.named_parameters()) - else: - bert_param = list(self.rec_model.bert.named_parameters()) + bert_param = list(self.rec_model.bert.named_parameters()) bert_param_name = ['bert.' + n for n, p in bert_param] else: bert_param = []