-
Notifications
You must be signed in to change notification settings - Fork 0
/
config.py
77 lines (67 loc) · 2.72 KB
/
config.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""
This file contains the configuration for the training process.
"""
from pathlib import Path
def get_config():
    """
    Build the settings dictionary consumed by the training process.

    A new dictionary is assembled on every call, so the caller is free to
    modify the returned object.
    :return: the configuration dictionary
    """
    optimisation = {
        # How many samples go into one training batch
        "batch_size": 8,
        # How many epochs the training runs for
        "num_epochs": 40,
        # Learning rate
        "lr": 10**-4,
    }
    model_shape = {
        # Sequence length; should exceed the longest sentence in the dataset
        # (that length is printed at the beginning)
        "seq_len": 350,
        # Model dimension; 512 is the default mentioned in the paper
        "d_model": 512,
    }
    dataset = {
        # Language the dataset is translated from
        "lang_src": "en",
        # Language the dataset is translated to
        "lang_tgt": "it",
        "datasource": 'opus_books',
    }
    checkpoints = {
        "model_basename": "tmodel_",
        "model_folder": "weights",
        # 'latest' -> use the latest weights file in the weights folder
        # None     -> use no weights and start training from scratch
        "preload": "latest",
    }
    bookkeeping = {
        # Where the tokenizer will be saved; "{0}" is a format placeholder
        # (presumably the language code - confirm against the caller)
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel",
        "validation_each_step": False,
    }
    # Merge the groups in this order so the resulting key order is stable.
    return {
        **optimisation,
        **model_shape,
        **dataset,
        **checkpoints,
        **bookkeeping,
    }
def get_weights_file_path(config, epoch: str):
    """
    Build the path of the weights file that belongs to one epoch.

    The file lives in "<datasource>_<model_folder>" relative to the current
    directory and is named "<model_basename><epoch>.pt".
    :param config: the configuration dictionary
    :param epoch: the epoch number, inserted verbatim into the file name
    :return: the path to the weights file, as a string
    """
    folder_name = "{}_{}".format(config['datasource'], config['model_folder'])
    file_name = "{}{}.pt".format(config['model_basename'], epoch)
    weights_path = Path('.').joinpath(folder_name, file_name)
    return str(weights_path)
# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    """
    Finds the latest weights file in the weights folder, when preload is set to latest
    this function will be used to find the latest weights file.

    NOTE: the lookup is currently hard-coded to files matching "betsi*" inside
    the "weights" folder (relative to the current directory); the values in
    ``config`` are not consulted.
    :param config: the configuration dictionary (currently unused)
    :return: the path to the latest weights file, or None if there is none
    """
    # Path.glob yields matches in no guaranteed order, so sort them: the
    # "last" file is then always the one whose path sorts highest.
    weights_files = sorted(Path("weights").glob('betsi*'))
    print(f"weights_files: {weights_files}")
    # If there are no weights files, return None
    if not weights_files:
        return None
    # Return the last weights file
    return str(weights_files[-1])