-
Notifications
You must be signed in to change notification settings - Fork 2
/
clargs.py
209 lines (160 loc) · 6.23 KB
/
clargs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import argparse
from helpers import str2bool, str2none
def unify_namespaces(*namespaces):
    """Merge several argparse.Namespace objects into one.

    Later namespaces win on key collisions; a warning is printed whenever an
    existing value is actually replaced by a different one.

    Args:
        *namespaces: any number of argparse.Namespace instances.

    Returns:
        argparse.Namespace: a new namespace holding the merged attributes.
    """
    merged = {}
    for ns in namespaces:
        for key, value in vars(ns).items():
            # Only warn on a real override (same key, different value).
            if key in merged and merged[key] != value:
                print(f"Warning: Argument '{key}' from namespace {ns} is overriding previous value '{merged[key]}' with '{value}'.")
            merged[key] = value
    return argparse.Namespace(**merged)
# model related args
def model_args_parser():
    """Build the argument parser for model-loading options.

    Returns:
        argparse.ArgumentParser: parser exposing the model path, GPU count
        and random seed.
    """
    ap = argparse.ArgumentParser(description='LLM model arguments')
    # model loading params
    # nargs="?" makes the positional optional so its default is actually used;
    # without it the default is dead code because argparse requires positionals,
    # and the __main__ demo below would error when run with no arguments.
    ap.add_argument("model_name_or_path", nargs="?", default="facebook/opt-125m", type=str, help="Model name or path")
    ap.add_argument("--n_gpus", default=1, type=int, help="Number of GPUs to use for inference")
    ap.add_argument("--seed", default=42, type=int, help="Random seed for initialization")
    return ap
# api related args
def api_args_parser():
    """Build the argument parser for API client options (host, port, prompt)."""
    parser = argparse.ArgumentParser(description='api arguments')
    parser.add_argument("--host", default="localhost", type=str)
    parser.add_argument("--port", default=8000, type=int)
    parser.add_argument("--stream", action="store_true")
    parser.add_argument("--prompt", default="San Francisco is a", type=str)
    parser.add_argument("--model_url", default="http://localhost:8000/generate", type=str)
    return parser
# inference related args
def inference_args_parser():
    """Build the argument parser for generation/sampling options.

    Returns:
        argparse.ArgumentParser: parser with sampling, penalty, stopping and
        batching flags.
    """
    ap = argparse.ArgumentParser(description='inference arguments')
    # inference params
    ap.add_argument("--num_return_sequences", default=1, type=int, help="Number of samples to generate for each prompt")
    ap.add_argument("--best_of", default=1, type=int, help="Number of samples to generate for each prompt")
    ap.add_argument("--presence_penalty", default=0.0, type=float, help="Presence penalty")
    ap.add_argument("--frequency_penalty", default=0.2, type=float, help="Frequency penalty")
    ap.add_argument("--temperature", default=0.8, type=float, help="Temperature")
    ap.add_argument("--top_p", default=0.9, type=float, help="Top p")
    ap.add_argument("--top_k", default=50, type=int, help="Top k")
    # type=bool is an argparse pitfall: bool("False") is True, so any non-empty
    # command-line value parsed as True. Use str2bool (imported at module top)
    # instead, matching --truncate_from_start in data_args_parser.
    ap.add_argument("--use_beam_search", default=False, type=str2bool, help="Use beam search")
    ap.add_argument("--stop_tokens", nargs="+", default=None, type=str, help="Stop tokens")
    ap.add_argument("--ignore_eos", default=False, type=str2bool, help="Ignore EOS")
    ap.add_argument("--max_tokens", default=1024, type=int, help="Max output tokens before forcing EOS")
    # NOTE(review): str type here looks odd — inference backends usually expect
    # an int number of logprobs; confirm against the consumer before changing.
    ap.add_argument("--logprobs", default=None, type=str, help="Logprobs")
    ap.add_argument("--batch_size", default=1, type=int, help="Batch size")
    ap.add_argument("--force", default=False, type=str2bool, help="Overwrite output file if it already exists. Otherwise, skip inference.")
    return ap
def data_args_parser():
    """Build the argument parser for dataset and file I/O options.

    Returns:
        argparse.ArgumentParser: parser for input/output locations, the key
        names expected in the input file, prompt formatting and truncation.
    """
    parser = argparse.ArgumentParser(description='data arguments')
    # File locations.
    parser.add_argument("--input_file", default=None, type=str, help="Input file for inference")
    parser.add_argument("--output_file", default=None, type=str, help="Output file for inference")
    parser.add_argument("--output_path", default="llm_dqa/resources/outputs/", type=str, help="Output path for inference. Full file path will be inferred from args")
    parser.add_argument("--log_path", default=None, type=str, help="Log path for inference. If not specified, logs are written to models' outputs' subdir.")
    # Expected keys in the input file.
    parser.add_argument("--src_key", default="question", type=str, help="Source column for inference")  # QuestionText for MSQA
    parser.add_argument("--orig_src_key", default=None, type=str, help="Original source column for inference before prompt formatting if available")
    parser.add_argument("--tgt_key", default="answer", type=str, help="Target column for inference")  # ProcessedAnswerText for MSQA
    parser.add_argument("--ctx_key", default=None, type=str, help="Context column for inference")
    parser.add_argument("--instruction_prefix", default=None, type=str, help="Instruction prefix for inference (experimental)")
    parser.add_argument("--index_path", default=None, type=str, help="Path to index for RAG model")
    # Prompting and processing controls.
    parser.add_argument("--prompt_format", default="prompts/dummy", type=str, help="Prompt format template")
    parser.add_argument("--verbose", action="store_true", help="Verbose")
    parser.add_argument("--limit", default=-1, type=int, help="Limit number of examples to process")
    # Bare "--truncate_from_start" (no value) means True via const=True.
    parser.add_argument("--truncate_from_start", type=str2bool, nargs="?", const=True, default=True, help="When handling model inputs that exceed model size, truncate from start instead of end")
    parser.add_argument("--max_input_length", default=4096, type=int, help="Max input length that the model accepts")
    return parser
# retrieval related args
def retrieval_args_parser():
    """Build the argument parser for retrieval (RAG) options."""
    parser = argparse.ArgumentParser(description='retrieval arguments')
    parser.add_argument("-k", "--k", default=3, type=int, help="Number of retrieved contexts")
    parser.add_argument("--fetch_k", default=3, type=int, help="Number of retrieved documents before applying filtering (must be >= k)")
    return parser
if __name__ == "__main__":
    # Smoke-test each parser factory in turn, printing what it parsed, then
    # show the merged result. parse_known_args ignores flags owned by the
    # other parsers.
    collected = []
    for make_parser in (model_args_parser, data_args_parser,
                        inference_args_parser, retrieval_args_parser):
        ns, _unknown = make_parser().parse_known_args()
        print(ns)
        collected.append(ns)
    combined = unify_namespaces(*collected)
    print(combined)