From c6c8053ccce7d23814126be00d76d70e60047a21 Mon Sep 17 00:00:00 2001 From: "Laisky.Cai" Date: Wed, 8 Jan 2025 05:00:33 +0000 Subject: [PATCH] fix: audio transcription only charge for the length of audio duration --- Dockerfile | 4 +- common/helper/audio.go | 40 ++++++++++ common/helper/audio_test.go | 37 +++++++++ go.mod | 21 +++-- go.sum | 38 ++++----- relay/adaptor/openai/model.go | 24 +++++- relay/billing/ratio/model.go | 2 + relay/controller/audio.go | 146 ++++++++++++++++++++++++---------- router/web.go | 5 +- 9 files changed, 238 insertions(+), 79 deletions(-) create mode 100644 common/helper/audio.go create mode 100644 common/helper/audio_test.go diff --git a/Dockerfile b/Dockerfile index ade561e408..72fbd08d93 100644 --- a/Dockerfile +++ b/Dockerfile @@ -35,10 +35,10 @@ FROM alpine RUN apk update \ && apk upgrade \ - && apk add --no-cache ca-certificates tzdata \ + && apk add --no-cache ca-certificates tzdata ffmpeg \ && update-ca-certificates 2>/dev/null || true COPY --from=builder2 /build/one-api / EXPOSE 3000 WORKDIR /data -ENTRYPOINT ["/one-api"] \ No newline at end of file +ENTRYPOINT ["/one-api"] diff --git a/common/helper/audio.go b/common/helper/audio.go new file mode 100644 index 0000000000..9db62f42d1 --- /dev/null +++ b/common/helper/audio.go @@ -0,0 +1,40 @@ +package helper + +import ( + "bytes" + "context" + "io" + "os" + "os/exec" + "strconv" + + "github.com/pkg/errors" +) + +// SaveTmpFile saves data to a temporary file. The filename would be appended with a random string. +func SaveTmpFile(filename string, data io.Reader) (string, error) { + f, err := os.CreateTemp(os.TempDir(), filename) + if err != nil { + return "", errors.Wrapf(err, "failed to create temporary file %s", filename) + } + defer f.Close() + + _, err = io.Copy(f, data) + if err != nil { + return "", errors.Wrapf(err, "failed to copy data to temporary file %s", filename) + } + + return f.Name(), nil +} + +// GetAudioDuration returns the duration of an audio file in seconds. 
+func GetAudioDuration(ctx context.Context, filename string) (float64, error) { + // ffprobe -v error -show_entries format=duration -of default=noprint_wrappers=1:nokey=1 {{input}} + c := exec.CommandContext(ctx, "ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "default=noprint_wrappers=1:nokey=1", filename) + output, err := c.Output() + if err != nil { + return 0, errors.Wrap(err, "failed to get audio duration") + } + + return strconv.ParseFloat(string(bytes.TrimSpace(output)), 64) +} diff --git a/common/helper/audio_test.go b/common/helper/audio_test.go new file mode 100644 index 0000000000..90f334a31d --- /dev/null +++ b/common/helper/audio_test.go @@ -0,0 +1,37 @@ +package helper + +import ( + "context" + "io" + "net/http" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetAudioDuration(t *testing.T) { + t.Run("should return correct duration for a valid audio file", func(t *testing.T) { + tmpFile, err := os.CreateTemp("", "test_audio*.mp3") + require.NoError(t, err) + defer os.Remove(tmpFile.Name()) + + // download test audio file + resp, err := http.Get("https://s3.laisky.com/uploads/2025/01/audio-sample.m4a") + require.NoError(t, err) + defer resp.Body.Close() + + _, err = io.Copy(tmpFile, resp.Body) + require.NoError(t, err) + require.NoError(t, tmpFile.Close()) + + duration, err := GetAudioDuration(context.Background(), tmpFile.Name()) + require.NoError(t, err) + require.Equal(t, duration, 3.904) + }) + + t.Run("should return an error for a non-existent file", func(t *testing.T) { + _, err := GetAudioDuration(context.Background(), "non_existent_file.mp3") + require.Error(t, err) + }) +} diff --git a/go.mod b/go.mod index 2106cf0f79..136546c1f8 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.31.0 golang.org/x/image v0.18.0 + golang.org/x/sync v0.10.0 google.golang.org/api v0.187.0 gorm.io/driver/mysql v1.5.6 gorm.io/driver/postgres v1.5.7 @@ 
-38,29 +39,27 @@ require ( cloud.google.com/go/auth v0.6.1 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.2 // indirect cloud.google.com/go/compute/metadata v0.3.0 // indirect - filippo.io/edwards25519 v1.1.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.2 // indirect github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.7 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.7 // indirect github.com/aws/smithy-go v1.20.2 // indirect github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect - github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/dlclark/regexp2 v1.10.0 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect github.com/go-logr/logr v1.4.1 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect - github.com/go-sql-driver/mysql v1.8.1 // indirect - github.com/goccy/go-json v0.10.3 // indirect + github.com/go-sql-driver/mysql v1.7.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/protobuf v1.5.4 // indirect github.com/google/s2a-go v0.1.7 // indirect @@ -71,9 +70,8 @@ require ( github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/sessions v1.2.2 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile 
v0.0.0-20231201235250-de7065d80cb9 // indirect - github.com/jackc/pgx/v5 v5.5.5 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.4.3 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/json-iterator/go v1.1.12 // indirect @@ -82,7 +80,7 @@ require ( github.com/kr/text v0.2.0 // indirect github.com/leodido/go-urn v1.4.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-sqlite3 v1.14.22 // indirect + github.com/mattn/go-sqlite3 v1.14.17 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect @@ -99,13 +97,12 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/net v0.26.0 // indirect golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.5.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240617180043-68d350f18fd4 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20240624140628-dc46fd24d27d // indirect - google.golang.org/grpc v1.64.1 // indirect + google.golang.org/grpc v1.64.0 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index c98f19656c..e04bad1f90 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,6 @@ cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2Qx cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= cloud.google.com/go/iam v1.1.10 h1:ZSAr64oEhQSClwBL670MsJAW5/RLiC6kfw3Bqmd5ZDI= cloud.google.com/go/iam v1.1.10/go.mod h1:iEgMq62sg8zx446GCaijmA2Miwg5o3UbO+nI47WHJps= -filippo.io/edwards25519 v1.1.0 
h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= -filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/aws/aws-sdk-go-v2 v1.27.0 h1:7bZWKoXhzI+mMR/HjdMx8ZCC5+6fY0lS5tr0bbgiLlo= github.com/aws/aws-sdk-go-v2 v1.27.0/go.mod h1:ffIFB97e2yNsv4aTSGkqtHnppsIJzw7G7BReUZ3jCXM= @@ -29,8 +27,8 @@ github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1 github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM= github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y= github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= @@ -43,16 +41,15 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= -github.com/dlclark/regexp2 v1.11.0/go.mod 
h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dlclark/regexp2 v1.10.0 h1:+/GIL799phkJqYW+3YbOd8LCcbHzT0Pbo8zl70MHsq0= +github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= github.com/gin-contrib/cors v1.7.2 h1:oLDHxdg8W/XDoN/8zamqk/Drgt4oVZDvaV0YmvVICQw= @@ -81,11 +78,10 @@ github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBEx github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.7.0 h1:ueSltNNllEqE3qcWBTD0iQd3IpL/6U+mJxLkazJ7YPc= github.com/go-sql-driver/mysql v1.7.0/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= -github.com/go-sql-driver/mysql v1.8.1 
h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= -github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= -github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA= -github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -134,12 +130,10 @@ github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/ github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= -github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 
v5.4.3 h1:cxFyXhxlvAifxnkKKdlxv8XqUf59tDlYjnV5YYfsJJY= +github.com/jackc/pgx/v5 v5.4.3/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= github.com/jinzhu/copier v0.4.0 h1:w3ciUoD19shMCRargcpm0cm91ytaBhDvuRpz1ODO/U8= github.com/jinzhu/copier v0.4.0/go.mod h1:DfbEm0FYsaqBcKcFuvmOZb218JkPGtvSHsKg8S8hyyg= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= @@ -163,8 +157,8 @@ github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= -github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -282,8 +276,8 @@ google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyac google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.33.2/go.mod h1:JMHMWHQWaTccqQQlmk3MJZS+GWXOdAesneDmEnv2fbc= -google.golang.org/grpc v1.64.1 h1:LKtvyfbX3UGVPFcGqJ9ItpVWW6oN/2XqTxfAnwRRXiA= -google.golang.org/grpc v1.64.1/go.mod h1:hiQF4LFZelK2WKaP6W0L92zGHtiQdZxk8CrSdvyjeP0= 
+google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY= +google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/relay/adaptor/openai/model.go b/relay/adaptor/openai/model.go index 4c974de4ad..39e872626b 100644 --- a/relay/adaptor/openai/model.go +++ b/relay/adaptor/openai/model.go @@ -1,6 +1,10 @@ package openai -import "github.com/songquanpeng/one-api/relay/model" +import ( + "mime/multipart" + + "github.com/songquanpeng/one-api/relay/model" +) type TextContent struct { Type string `json:"type,omitempty"` @@ -71,6 +75,24 @@ type TextToSpeechRequest struct { ResponseFormat string `json:"response_format"` } +type AudioTranscriptionRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` + Model string `form:"model" binding:"required"` + Language string `form:"language"` + Prompt string `form:"prompt"` + ReponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"` + Temperature float64 `form:"temperature"` + TimestampGranularity []string `form:"timestamp_granularity"` +} + +type AudioTranslationRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` + Model string `form:"model" binding:"required"` + Prompt string `form:"prompt"` + ResponseFormat string `form:"response_format" binding:"oneof=json text srt verbose_json vtt"` + Temperature float64 `form:"temperature"` +} + type UsageOrResponseText struct { *model.Usage ResponseText string diff --git a/relay/billing/ratio/model.go b/relay/billing/ratio/model.go index f83aa70c11..d1720a990f 100644 --- a/relay/billing/ratio/model.go +++ 
b/relay/billing/ratio/model.go @@ -337,6 +337,8 @@ var CompletionRatio = map[string]float64{ // aws llama3 "llama3-8b-8192(33)": 0.0006 / 0.0003, "llama3-70b-8192(33)": 0.0035 / 0.00265, + // whisper + "whisper-1": 0, // only count input tokens } var ( diff --git a/relay/controller/audio.go b/relay/controller/audio.go index e3d57b1eb4..bc756f65c4 100644 --- a/relay/controller/audio.go +++ b/relay/controller/audio.go @@ -5,17 +5,20 @@ import ( "bytes" "context" "encoding/json" - "errors" "fmt" "io" + "math" + "mime/multipart" "net/http" + "os" "strings" "github.com/gin-gonic/gin" + "github.com/pkg/errors" "github.com/songquanpeng/one-api/common" "github.com/songquanpeng/one-api/common/client" - "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/common/helper" "github.com/songquanpeng/one-api/common/logger" "github.com/songquanpeng/one-api/model" "github.com/songquanpeng/one-api/relay/adaptor/openai" @@ -27,6 +30,53 @@ import ( "github.com/songquanpeng/one-api/relay/relaymode" ) +const ( + TokensPerSecond = 1000 / 20 // $0.006 / minute -> $0.002 / 20 seconds -> $0.002 / 1K tokens +) + +type commonAudioRequest struct { + File *multipart.FileHeader `form:"file" binding:"required"` +} + +func countAudioTokens(c *gin.Context) (int, error) { + body, err := common.GetRequestBody(c) + if err != nil { + return 0, errors.WithStack(err) + } + + reqBody := new(commonAudioRequest) + c.Request.Body = io.NopCloser(bytes.NewReader(body)) + if err = c.ShouldBind(reqBody); err != nil { + return 0, errors.WithStack(err) + } + + reqFp, err := reqBody.File.Open() + if err != nil { + return 0, errors.WithStack(err) + } + + tmpFp, err := os.CreateTemp("", "audio-*") + if err != nil { + return 0, errors.WithStack(err) + } + defer os.Remove(tmpFp.Name()) + + _, err = io.Copy(tmpFp, reqFp) + if err != nil { + return 0, errors.WithStack(err) + } + if err = tmpFp.Close(); err != nil { + return 0, 
errors.WithStack(err) + } + + duration, err := helper.GetAudioDuration(c.Request.Context(), tmpFp.Name()) + if err != nil { + return 0, errors.WithStack(err) + } + + return int(math.Ceil(duration)) * TokensPerSecond, nil +} + func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatusCode { ctx := c.Request.Context() meta := meta.GetByContext(c) @@ -63,9 +113,19 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus case relaymode.AudioSpeech: preConsumedQuota = int64(float64(len(ttsRequest.Input)) * ratio) quota = preConsumedQuota + case relaymode.AudioTranscription, + relaymode.AudioTranslation: + audioTokens, err := countAudioTokens(c) + if err != nil { + return openai.ErrorWrapper(err, "count_audio_tokens_failed", http.StatusInternalServerError) + } + + preConsumedQuota = int64(float64(audioTokens) * ratio) + quota = preConsumedQuota default: - preConsumedQuota = int64(float64(config.PreConsumedQuota) * ratio) + return openai.ErrorWrapper(errors.New("unexpected_relay_mode"), "unexpected_relay_mode", http.StatusInternalServerError) } + userQuota, err := model.CacheGetUserQuota(ctx, userId) if err != nil { return openai.ErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError) @@ -139,7 +199,7 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "new_request_body_failed", http.StatusInternalServerError) } c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody.Bytes())) - responseFormat := c.DefaultPostForm("response_format", "json") + // responseFormat := c.DefaultPostForm("response_format", "json") req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) if err != nil { @@ -172,47 +232,53 @@ func RelayAudioHelper(c *gin.Context, relayMode int) *relaymodel.ErrorWithStatus return openai.ErrorWrapper(err, "close_request_body_failed", http.StatusInternalServerError) } - if relayMode != relaymode.AudioSpeech { - 
responseBody, err := io.ReadAll(resp.Body) - if err != nil { - return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) - } - err = resp.Body.Close() - if err != nil { - return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) - } + // https://github.com/Laisky/one-api/pull/21 + // Commenting out the following code because Whisper's transcription + // only charges for the length of the input audio, not for the output. + // ------------------------------------- + // if relayMode != relaymode.AudioSpeech { + // responseBody, err := io.ReadAll(resp.Body) + // if err != nil { + // return openai.ErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) + // } + // err = resp.Body.Close() + // if err != nil { + // return openai.ErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError) + // } - var openAIErr openai.SlimTextResponse - if err = json.Unmarshal(responseBody, &openAIErr); err == nil { - if openAIErr.Error.Message != "" { - return openai.ErrorWrapper(fmt.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) - } - } + // var openAIErr openai.SlimTextResponse + // if err = json.Unmarshal(responseBody, &openAIErr); err == nil { + // if openAIErr.Error.Message != "" { + // return openai.ErrorWrapper(errors.Errorf("type %s, code %v, message %s", openAIErr.Error.Type, openAIErr.Error.Code, openAIErr.Error.Message), "request_error", http.StatusInternalServerError) + // } + // } + + // var text string + // switch responseFormat { + // case "json": + // text, err = getTextFromJSON(responseBody) + // case "text": + // text, err = getTextFromText(responseBody) + // case "srt": + // text, err = getTextFromSRT(responseBody) + // case "verbose_json": + // text, err = getTextFromVerboseJSON(responseBody) + // case "vtt": + // text, err = 
getTextFromVTT(responseBody) + // default: + // return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) + // } + // if err != nil { + // return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) + // } + // quota = int64(openai.CountTokenText(text, audioModel)) + // resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) + // } - var text string - switch responseFormat { - case "json": - text, err = getTextFromJSON(responseBody) - case "text": - text, err = getTextFromText(responseBody) - case "srt": - text, err = getTextFromSRT(responseBody) - case "verbose_json": - text, err = getTextFromVerboseJSON(responseBody) - case "vtt": - text, err = getTextFromVTT(responseBody) - default: - return openai.ErrorWrapper(errors.New("unexpected_response_format"), "unexpected_response_format", http.StatusInternalServerError) - } - if err != nil { - return openai.ErrorWrapper(err, "get_text_from_body_err", http.StatusInternalServerError) - } - quota = int64(openai.CountTokenText(text, audioModel)) - resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) - } if resp.StatusCode != http.StatusOK { return RelayErrorHandler(resp) } + succeed = true quotaDelta := quota - preConsumedQuota defer func(ctx context.Context) { diff --git a/router/web.go b/router/web.go index 3c9b4643a5..ebfc2ae1ac 100644 --- a/router/web.go +++ b/router/web.go @@ -3,6 +3,9 @@ package router import ( "embed" "fmt" + "net/http" + "strings" + "github.com/gin-contrib/gzip" "github.com/gin-contrib/static" "github.com/gin-gonic/gin" @@ -10,8 +13,6 @@ import ( "github.com/songquanpeng/one-api/common/config" "github.com/songquanpeng/one-api/controller" "github.com/songquanpeng/one-api/middleware" - "net/http" - "strings" ) func SetWebRouter(router *gin.Engine, buildFS embed.FS) {