-
-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
fine_tuning_job.go
159 lines (135 loc) · 4.25 KB
/
fine_tuning_job.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
package openai
import (
"context"
"fmt"
"net/http"
"net/url"
)
// FineTuningJob is the fine-tuning job object decoded from the responses of
// CreateFineTuningJob, CancelFineTuningJob and RetrieveFineTuningJob.
type FineTuningJob struct {
	ID string `json:"id"`
	Object string `json:"object"`
	// CreatedAt and FinishedAt are decoded as integers; presumably Unix
	// timestamps in seconds — TODO confirm against the API reference.
	// A JSON null (e.g. for an unfinished job) leaves the field at 0.
	CreatedAt int64 `json:"created_at"`
	FinishedAt int64 `json:"finished_at"`
	Model string `json:"model"`
	// FineTunedModel is omitted when this struct is marshaled and the value is empty.
	FineTunedModel string `json:"fine_tuned_model,omitempty"`
	OrganizationID string `json:"organization_id"`
	Status string `json:"status"`
	Hyperparameters Hyperparameters `json:"hyperparameters"`
	TrainingFile string `json:"training_file"`
	ValidationFile string `json:"validation_file,omitempty"`
	ResultFiles []string `json:"result_files"`
	TrainedTokens int `json:"trained_tokens"`

	// httpHeader is declared elsewhere in the package; presumably it records
	// the HTTP response headers of the call that produced this value — verify.
	httpHeader
}
// Hyperparameters holds the tuning hyperparameters of a fine-tuning job. It is
// sent inside FineTuningJobRequest and decoded as part of FineTuningJob.
//
// The fields are typed any so they can carry whatever JSON value the API uses
// (presumably a number or a string such as "auto" — TODO confirm). When
// unmarshaled into any, JSON numbers become float64. A nil field is omitted
// when marshaling because of omitempty.
type Hyperparameters struct {
	Epochs any `json:"n_epochs,omitempty"`
	LearningRateMultiplier any `json:"learning_rate_multiplier,omitempty"`
	BatchSize any `json:"batch_size,omitempty"`
}
// FineTuningJobRequest is the request body sent by CreateFineTuningJob.
//
// TrainingFile is always serialized. Every other field is omitted from the JSON
// body when empty (or nil, in the case of Hyperparameters).
type FineTuningJobRequest struct {
	TrainingFile string `json:"training_file"`
	ValidationFile string `json:"validation_file,omitempty"`
	Model string `json:"model,omitempty"`
	// Hyperparameters is a pointer so that "not specified" (nil) can be
	// distinguished from a zero-valued struct and omitted entirely.
	Hyperparameters *Hyperparameters `json:"hyperparameters,omitempty"`
	// Suffix is sent as "suffix"; presumably appended to the resulting
	// fine-tuned model name — TODO confirm against the API reference.
	Suffix string `json:"suffix,omitempty"`
}
// FineTuningJobEventList is the list of events decoded from the response of
// ListFineTuningJobEvents.
type FineTuningJobEventList struct {
	Object string `json:"object"`
	// NOTE(review): Data uses FineTuneEvent (declared elsewhere in the
	// package), not the FineTuningJobEvent type declared in this file. The
	// "id", "data" and "type" keys of each event are therefore kept only if
	// FineTuneEvent declares them — verify. Changing the element type would
	// be a breaking change for callers of this exported field.
	Data []FineTuneEvent `json:"data"`
	// HasMore is decoded from "has_more"; presumably true when further events
	// can be fetched (e.g. via ListFineTuningJobEventsWithAfter) — TODO confirm.
	HasMore bool `json:"has_more"`

	// httpHeader is declared elsewhere; presumably records the response
	// headers — verify.
	httpHeader
}
// FineTuningJobEvent describes a single event of a fine-tuning job.
//
// NOTE(review): nothing in this file references this type;
// FineTuningJobEventList.Data uses FineTuneEvent instead. CreatedAt is an int
// here, while FineTuningJob uses int64 for its timestamps. Aligning the two
// would change an exported field type, which is a breaking change.
type FineTuningJobEvent struct {
	Object string `json:"object"`
	ID string `json:"id"`
	CreatedAt int `json:"created_at"`
	Level string `json:"level"`
	Message string `json:"message"`
	// Data is decoded into any, so its concrete shape is whatever JSON the
	// API sends (objects become map[string]any, numbers float64).
	Data any `json:"data"`
	Type string `json:"type"`
}
// CreateFineTuningJob starts a new fine-tuning job.
//
// The request is attached as the body of a POST to /fine_tuning/jobs, and the
// reply is decoded by sendRequest into the returned FineTuningJob. If the
// HTTP request cannot be built, that error is returned and no call is made.
func (c *Client) CreateFineTuningJob(
	ctx context.Context,
	request FineTuningJobRequest,
) (response FineTuningJob, err error) {
	const endpoint = "/fine_tuning/jobs"

	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(endpoint), withBody(request))
	if err == nil {
		err = c.sendRequest(req, &response)
	}
	return
}
// CancelFineTuningJob asks the API to cancel the fine-tuning job identified by
// fineTuningJobID. The reply is decoded by sendRequest into the returned
// FineTuningJob.
//
// The ID is path-escaped so that it always occupies exactly one URL path
// segment. Without escaping, characters such as '/', '?', '#' or '%' in a
// caller-supplied ID would redirect the POST to a different endpoint or
// produce an unparsable URL. Ordinary IDs are unchanged by the escaping.
func (c *Client) CancelFineTuningJob(ctx context.Context, fineTuningJobID string) (response FineTuningJob, err error) {
	urlSuffix := "/fine_tuning/jobs/" + url.PathEscape(fineTuningJobID) + "/cancel"
	req, err := c.newRequest(ctx, http.MethodPost, c.fullURL(urlSuffix))
	if err != nil {
		return
	}

	err = c.sendRequest(req, &response)
	return
}
// RetrieveFineTuningJob fetches the fine-tuning job identified by
// fineTuningJobID with a GET request. The reply is decoded by sendRequest
// into the returned FineTuningJob.
//
// The ID is path-escaped so that it always occupies exactly one URL path
// segment. Reserved characters in it cannot change which endpoint is hit;
// ordinary IDs are unchanged by the escaping.
func (c *Client) RetrieveFineTuningJob(
	ctx context.Context,
	fineTuningJobID string,
) (response FineTuningJob, err error) {
	urlSuffix := fmt.Sprintf("/fine_tuning/jobs/%s", url.PathEscape(fineTuningJobID))
	req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
	if err != nil {
		return
	}

	err = c.sendRequest(req, &response)
	return
}
// listFineTuningJobEventsParameters collects the optional query parameters of
// ListFineTuningJobEvents. A nil field means "not set", and the corresponding
// query parameter is left out of the request URL.
type listFineTuningJobEventsParameters struct {
	after *string // value of the "after" query parameter, if set
	limit *int    // value of the "limit" query parameter, if set
}
// ListFineTuningJobEventsParameter is a functional option for
// ListFineTuningJobEvents. Each option mutates the collected query parameters
// before the request URL is built.
type ListFineTuningJobEventsParameter func(*listFineTuningJobEventsParameters)
// ListFineTuningJobEventsWithAfter returns an option that makes
// ListFineTuningJobEvents send the given value as the "after" query parameter.
func ListFineTuningJobEventsWithAfter(after string) ListFineTuningJobEventsParameter {
	setAfter := func(p *listFineTuningJobEventsParameters) {
		p.after = &after
	}
	return setAfter
}
// ListFineTuningJobEventsWithLimit returns an option that makes
// ListFineTuningJobEvents send the given value as the "limit" query parameter.
func ListFineTuningJobEventsWithLimit(limit int) ListFineTuningJobEventsParameter {
	setLimit := func(p *listFineTuningJobEventsParameters) {
		p.limit = &limit
	}
	return setLimit
}
// ListFineTuningJobEvents lists the events of the fine-tuning job identified
// by fineTuningJobID.
//
// Optional query parameters are supplied through setters such as
// ListFineTuningJobEventsWithAfter and ListFineTuningJobEventsWithLimit.
// Parameters that are not set are omitted from the URL, and no '?' is appended
// when none are set. The job ID is path-escaped so that it always occupies a
// single path segment. The reply is decoded by sendRequest into the returned
// FineTuningJobEventList.
func (c *Client) ListFineTuningJobEvents(
	ctx context.Context,
	fineTuningJobID string,
	setters ...ListFineTuningJobEventsParameter,
) (response FineTuningJobEventList, err error) {
	// The zero value (all nil pointers) means "no optional parameters".
	parameters := &listFineTuningJobEventsParameters{}
	for _, setter := range setters {
		setter(parameters)
	}

	urlValues := url.Values{}
	if parameters.after != nil {
		urlValues.Add("after", *parameters.after)
	}
	if parameters.limit != nil {
		urlValues.Add("limit", fmt.Sprintf("%d", *parameters.limit))
	}

	urlSuffix := "/fine_tuning/jobs/" + url.PathEscape(fineTuningJobID) + "/events"
	if len(urlValues) > 0 {
		urlSuffix += "?" + urlValues.Encode()
	}

	req, err := c.newRequest(ctx, http.MethodGet, c.fullURL(urlSuffix))
	if err != nil {
		return
	}

	err = c.sendRequest(req, &response)
	return
}