From a1f61384c557da7855897bf9ae405f1e88e96945 Mon Sep 17 00:00:00 2001 From: JustSong Date: Mon, 15 May 2023 17:34:09 +0800 Subject: [PATCH] feat: automatically disable channel when error occurred (#59) --- common/constants.go | 5 +++++ controller/channel.go | 31 +++++++++++++++++++---------- controller/relay.go | 10 ++++++++++ middleware/distributor.go | 2 ++ model/channel.go | 14 ++++++------- model/option.go | 6 ++++++ web/src/components/SystemSetting.js | 29 ++++++++++++++++++++++++++- 7 files changed, 78 insertions(+), 19 deletions(-) diff --git a/common/constants.go b/common/constants.go index 3ebe233b0b..f9595ee021 100644 --- a/common/constants.go +++ b/common/constants.go @@ -52,6 +52,11 @@ var TurnstileSecretKey = "" var QuotaForNewUser = 100 +var ChannelDisableThreshold = 5.0 +var AutomaticDisableChannelEnabled = false + +var RootUserEmail = "" + const ( RoleGuestUser = 0 RoleCommonUser = 1 diff --git a/controller/channel.go b/controller/channel.go index d95882f95c..c123de41eb 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -263,6 +263,20 @@ func TestChannel(c *gin.Context) { var testAllChannelsLock sync.Mutex var testAllChannelsRunning bool = false +// disable & notify +func disableChannel(channelId int, channelName string, err error) { + if common.RootUserEmail == "" { + common.RootUserEmail = model.GetRootUserEmail() + } + model.UpdateChannelStatusById(channelId, common.ChannelStatusDisabled) + subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channelName, channelId) + content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channelName, channelId, err.Error()) + err = common.SendEmail(subject, common.RootUserEmail, content) + if err != nil { + common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) + } +} + func testAllChannels(c *gin.Context) error { testAllChannelsLock.Lock() if testAllChannelsRunning { @@ -280,8 +294,10 @@ func testAllChannels(c *gin.Context) error { return err } testRequest := buildTestRequest(c) - var disableThreshold int64 = 5000 // TODO: make it configurable - email := model.GetRootUserEmail() + var disableThreshold = int64(common.ChannelDisableThreshold * 1000) + if disableThreshold == 0 { + disableThreshold = 10000000 // a impossible value + } go func() { for _, channel := range channels { if channel.Status != common.ChannelStatusEnabled { @@ -295,18 +311,11 @@ func testAllChannels(c *gin.Context) error { if milliseconds > disableThreshold { err = errors.New(fmt.Sprintf("响应时间 %.2fs 超过阈值 %.2fs", float64(milliseconds)/1000.0, float64(disableThreshold)/1000.0)) } - // disable & notify - channel.UpdateStatus(common.ChannelStatusDisabled) - subject := fmt.Sprintf("通道「%s」(#%d)已被禁用", channel.Name, channel.Id) - content := fmt.Sprintf("通道「%s」(#%d)已被禁用,原因:%s", channel.Name, channel.Id, err.Error()) - err = common.SendEmail(subject, email, content) - if err != nil { - common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) - } + disableChannel(channel.Id, channel.Name, err) } channel.UpdateResponseTime(milliseconds) } - err := common.SendEmail("通道测试完成", email, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") + err := common.SendEmail("通道测试完成", common.RootUserEmail, "通道测试完成,如果没有收到禁用通知,说明所有通道都正常") if err != nil { common.SysError(fmt.Sprintf("发送邮件失败:%s", err.Error())) } diff --git a/controller/relay.go b/controller/relay.go index 2be8a82555..5da9c19f53 100644 --- a/controller/relay.go +++ b/controller/relay.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "fmt" "github.com/gin-gonic/gin" "github.com/pkoukk/tiktoken-go" @@ -74,6 +75,11 @@ func Relay(c *gin.Context) { "type": "one_api_error", }, }) + if common.AutomaticDisableChannelEnabled { + channelId := c.GetInt("channel_id") + channelName := c.GetString("channel_name") + disableChannel(channelId, channelName, err) + } } } @@ -256,6 +262,10 @@ func relayHelper(c *gin.Context) error { if err != nil { return err } + if textResponse.Error.Type != "" { + return errors.New(fmt.Sprintf("type %s, code %s, message %s", + textResponse.Error.Type, textResponse.Error.Code, textResponse.Error.Message)) + } // Reset response body resp.Body = io.NopCloser(bytes.NewBuffer(responseBody)) } diff --git a/middleware/distributor.go b/middleware/distributor.go index 65fcbd3daf..357849e7bb 100644 --- a/middleware/distributor.go +++ b/middleware/distributor.go @@ -62,6 +62,8 @@ func Distribute() func(c *gin.Context) { } } c.Set("channel", channel.Type) + c.Set("channel_id", channel.Id) + c.Set("channel_name", channel.Name) c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key)) if channel.Type == common.ChannelTypeCustom || channel.Type == common.ChannelTypeAzure { c.Set("base_url", channel.BaseURL) diff --git a/model/channel.go b/model/channel.go index 7b1c9ec206..0335207be8 100644 --- a/model/channel.go +++ b/model/channel.go @@ -86,15 +86,15 @@ func (channel *Channel) UpdateResponseTime(responseTime int64) { } } -func (channel *Channel) UpdateStatus(status int) { - err := DB.Model(channel).Update("status", status).Error - if err != nil { - common.SysError("failed to update response time: " + err.Error()) - } -} - func (channel *Channel) Delete() error { var err error err = DB.Delete(channel).Error return err } + +func UpdateChannelStatusById(id int, status int) { + err := DB.Model(&Channel{}).Where("id = ?", id).Update("status", status).Error + if err != nil { + common.SysError("failed to update channel status: " + err.Error()) + } +} diff --git a/model/option.go b/model/option.go index e5ddb60c8d..5807ff03b1 100644 --- a/model/option.go +++ b/model/option.go @@ -32,6 +32,8 @@ func InitOptionMap() { common.OptionMap["WeChatAuthEnabled"] = strconv.FormatBool(common.WeChatAuthEnabled) common.OptionMap["TurnstileCheckEnabled"] = strconv.FormatBool(common.TurnstileCheckEnabled) common.OptionMap["RegisterEnabled"] = strconv.FormatBool(common.RegisterEnabled) + common.OptionMap["AutomaticDisableChannelEnabled"] = strconv.FormatBool(common.AutomaticDisableChannelEnabled) + common.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(common.ChannelDisableThreshold, 'f', -1, 64) common.OptionMap["SMTPServer"] = "" common.OptionMap["SMTPFrom"] = "" common.OptionMap["SMTPPort"] = strconv.Itoa(common.SMTPPort) @@ -114,6 +116,8 @@ func updateOptionMap(key string, value string) (err error) { common.TurnstileCheckEnabled = boolValue case "RegisterEnabled": common.RegisterEnabled = boolValue + case "AutomaticDisableChannelEnabled": + common.AutomaticDisableChannelEnabled = boolValue } } switch key { @@ -156,6 +160,8 @@ func updateOptionMap(key string, value string) (err error) { err = common.UpdateModelRatioByJSONString(value) case "TopUpLink": common.TopUpLink = value + case "ChannelDisableThreshold": + common.ChannelDisableThreshold, _ = strconv.ParseFloat(value, 64) } return err } diff --git a/web/src/components/SystemSetting.js b/web/src/components/SystemSetting.js index b7c2a6e805..405303a962 100644 --- a/web/src/components/SystemSetting.js +++ b/web/src/components/SystemSetting.js @@ -28,7 +28,9 @@ const SystemSetting = () => { RegisterEnabled: '', QuotaForNewUser: 0, ModelRatio: '', - TopUpLink: '' + TopUpLink: '', + AutomaticDisableChannelEnabled: '', + ChannelDisableThreshold: 0, }); let originInputs = {}; let [loading, setLoading] = useState(false); @@ -62,6 +64,7 @@ const SystemSetting = () => { case 'WeChatAuthEnabled': case 'TurnstileCheckEnabled': case 'RegisterEnabled': + case 'AutomaticDisableChannelEnabled': value = inputs[key] === 'true' ? 'false' : 'true'; break; default: @@ -298,6 +301,30 @@ const SystemSetting = () => { 保存运营设置 +
+ 监控设置 +
+ + + + + + +
配置 SMTP 用以支持系统的邮件发送