-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathmain.py
66 lines (53 loc) · 1.93 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from reckit import Configurator
from importlib.util import find_spec
from importlib import import_module
from reckit import typeassert
import os
import sys
import numpy as np
import random
import torch
def _set_random_seed(seed=2020):
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
print("set pytorch seed")
@typeassert(recommender=str)
def find_recommender(recommender):
    """Locate and return the recommender named ``recommender``.

    Searches each sub-directory ``model/<dir>/`` (``model`` is resolved
    relative to the current working directory), except ``model/base``, for
    an importable module ``model.<dir>.<recommender>``. It then returns the
    attribute of that same name from the first module found.

    Args:
        recommender: Name shared by the module and the attribute inside it,
            e.g. ``"SGL"`` for ``model/<dir>/SGL.py`` defining ``SGL``.

    Returns:
        The attribute named ``recommender`` from the matched module.

    Raises:
        ImportError: If no ``model.<dir>.<recommender>`` module is found, or
            the module has no attribute named ``recommender``.
    """
    model_root = "model"
    # Only sub-directories can be sub-packages. find_spec() imports the
    # parent package first, so a plain file such as ``__init__.py`` (or a
    # name that is not an identifier, e.g. ".ipynb_checkpoints") would make
    # it raise ModuleNotFoundError instead of returning None. Sorting makes
    # the search order deterministic; iterating a set of strings is not.
    model_dirs = sorted(
        name
        for name in os.listdir(model_root)
        if name != "base"
        and name.isidentifier()
        and os.path.isdir(os.path.join(model_root, name))
    )

    module = None
    for tdir in model_dirs:
        spec_path = ".".join([model_root, tdir, recommender])
        if find_spec(spec_path):
            module = import_module(spec_path)
            break

    if module is None:
        raise ImportError(f"Recommender: {recommender} not found")
    if not hasattr(module, recommender):
        raise ImportError(f"Import {recommender} failed from {module.__file__}!")
    return getattr(module, recommender)
if __name__ == "__main__":
    # Project checkout root and dataset directory, used on every platform.
    # Replace the placeholders with real paths. Keep the trailing slash:
    # the strings are passed verbatim to Configurator, which presumably
    # relies on it (unverified).
    root_dir = 'XXXXXXXX/PythonProjects/SGL-torch/'
    data_dir = 'XXXXXXXX/PythonProjects/SGL-torch/dataset/'

    # Global settings from NeuRec.ini, then command-line overrides.
    config = Configurator(root_dir, data_dir)
    config.add_config(os.path.join(root_dir, "NeuRec.ini"), section="NeuRec")
    config.parse_cmd()

    # CUDA_VISIBLE_DEVICES only takes effect if it is set before CUDA is
    # initialised in this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(config["gpu_id"])
    _set_random_seed(config["seed"])

    # Resolve the model class, then load its per-model hyperparameter file
    # conf/<recommender>.ini.
    Recommender = find_recommender(config.recommender)
    model_cfg = os.path.join(root_dir, "conf", config.recommender + ".ini")
    config.add_config(model_cfg, section="hyperparameters", used_as_summary=True)

    recommender = Recommender(config)
    recommender.train_model()