Skip to content

Commit

Permalink
🔖 chore: Optimize user balance API (#330)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartialBE committed Aug 19, 2024
1 parent e047dac commit 0b707f9
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 56 deletions.
1 change: 0 additions & 1 deletion common/config/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ var ChatLink = ""
var ChatLinks = ""
var QuotaPerUnit = 500 * 1000.0 // $0.002 / 1K tokens
var DisplayInCurrencyEnabled = true
var DisplayTokenStatEnabled = true

// Any options with "Secret", "Token" in its key won't be return by GetOptions

Expand Down
89 changes: 44 additions & 45 deletions controller/billing.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package controller

import (
"fmt"
"net/http"
"one-api/common"
"one-api/common/config"
"one-api/model"
"one-api/types"

"github.com/gin-gonic/gin"
)
Expand All @@ -14,48 +16,41 @@ func GetSubscription(c *gin.Context) {
var err error
var token *model.Token
var expiredTime int64
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
} else {

tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
if err != nil {
common.APIRespondWithError(c, http.StatusOK, fmt.Errorf("获取信息失败: %v", err))
return
}

if token.UnlimitedQuota {
userId := c.GetInt("id")
remainQuota, err = model.GetUserQuota(userId)
userData, err := model.GetUserFields(userId, []string{"quota", "used_quota"})
if err != nil {
openAIError := types.OpenAIError{
Message: err.Error(),
Type: "upstream_error",
}
c.JSON(200, gin.H{
"error": openAIError,
})
common.APIRespondWithError(c, http.StatusOK, fmt.Errorf("获取用户信息失败: %v", err))

return
}
usedQuota, err = model.GetUserUsedQuota(userId)

remainQuota = userData["quota"].(int)
usedQuota = userData["used_quota"].(int)
} else {
expiredTime = token.ExpiredTime
remainQuota = token.RemainQuota
usedQuota = token.UsedQuota
}

if expiredTime <= 0 {
expiredTime = 0
}
if err != nil {
openAIError := types.OpenAIError{
Message: err.Error(),
Type: "upstream_error",
}
c.JSON(200, gin.H{
"error": openAIError,
})
return
}

quota := remainQuota + usedQuota
amount := float64(quota)
if config.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit
}
if token != nil && token.UnlimitedQuota {
amount = 100000000
}

subscription := OpenAISubscriptionResponse{
Object: "billing_subscription",
HasPaymentMethod: true,
Expand All @@ -71,24 +66,28 @@ func GetUsage(c *gin.Context) {
var quota int
var err error
var token *model.Token
if config.DisplayTokenStatEnabled {
tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
quota = token.UsedQuota
} else {
userId := c.GetInt("id")
quota, err = model.GetUserUsedQuota(userId)
}

tokenId := c.GetInt("token_id")
token, err = model.GetTokenById(tokenId)
if err != nil {
openAIError := types.OpenAIError{
Message: err.Error(),
Type: "one_api_error",
}
c.JSON(200, gin.H{
"error": openAIError,
})
common.APIRespondWithError(c, http.StatusOK, fmt.Errorf("获取信息失败: %v", err))
return
}

if token.UnlimitedQuota {
userId := c.GetInt("id")
userData, err := model.GetUserFields(userId, []string{"used_quota"})
if err != nil {
common.APIRespondWithError(c, http.StatusOK, fmt.Errorf("获取用户信息失败: %v", err))

return
}

quota = userData["used_quota"].(int)
} else {
quota = token.UsedQuota
}

amount := float64(quota)
if config.DisplayInCurrencyEnabled {
amount /= config.QuotaPerUnit
Expand Down
5 changes: 5 additions & 0 deletions model/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,8 @@ func RecordExists(table interface{}, fieldName string, fieldValue interface{}, e
query.Count(&count)
return count > 0
}

func GetFieldsByID(model interface{}, fieldNames []string, id int, result interface{}) error {
err := DB.Model(model).Where("id = ?", id).Select(fieldNames).Find(result).Error
return err
}
2 changes: 0 additions & 2 deletions model/option.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ func InitOptionMap() {
config.OptionMap["ApproximateTokenEnabled"] = strconv.FormatBool(config.ApproximateTokenEnabled)
config.OptionMap["LogConsumeEnabled"] = strconv.FormatBool(config.LogConsumeEnabled)
config.OptionMap["DisplayInCurrencyEnabled"] = strconv.FormatBool(config.DisplayInCurrencyEnabled)
config.OptionMap["DisplayTokenStatEnabled"] = strconv.FormatBool(config.DisplayTokenStatEnabled)
config.OptionMap["ChannelDisableThreshold"] = strconv.FormatFloat(config.ChannelDisableThreshold, 'f', -1, 64)
config.OptionMap["EmailDomainRestrictionEnabled"] = strconv.FormatBool(config.EmailDomainRestrictionEnabled)
config.OptionMap["EmailDomainWhitelist"] = strings.Join(config.EmailDomainWhitelist, ",")
Expand Down Expand Up @@ -157,7 +156,6 @@ var optionBoolMap = map[string]*bool{
"ApproximateTokenEnabled": &config.ApproximateTokenEnabled,
"LogConsumeEnabled": &config.LogConsumeEnabled,
"DisplayInCurrencyEnabled": &config.DisplayInCurrencyEnabled,
"DisplayTokenStatEnabled": &config.DisplayTokenStatEnabled,
"MjNotifyEnabled": &config.MjNotifyEnabled,
"ChatCacheEnabled": &config.ChatCacheEnabled,
}
Expand Down
6 changes: 6 additions & 0 deletions model/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,12 @@ func ValidateAccessToken(token string) (user *User) {
return nil
}

func GetUserFields(id int, fields []string) (map[string]interface{}, error) {
result := make(map[string]interface{})
err := GetFieldsByID(&User{}, fields, id, &result)
return result, err
}

func GetUserQuota(id int) (quota int, err error) {
err = DB.Model(&User{}).Where("id = ?", id).Select("quota").Find(&quota).Error
return quota, err
Expand Down
8 changes: 0 additions & 8 deletions web/src/views/Setting/component/OperationSetting.js
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ const OperationSetting = () => {
ChannelDisableThreshold: 0,
LogConsumeEnabled: '',
DisplayInCurrencyEnabled: '',
DisplayTokenStatEnabled: '',
ApproximateTokenEnabled: '',
RetryTimes: 0,
RetryCooldownSeconds: 0,
Expand Down Expand Up @@ -317,13 +316,6 @@ const OperationSetting = () => {
}
/>

<FormControlLabel
label={t('setting_index.operationSettings.generalSettings.displayTokenStat')}
control={
<Checkbox checked={inputs.DisplayTokenStatEnabled === 'true'} onChange={handleInputChange} name="DisplayTokenStatEnabled" />
}
/>

<FormControlLabel
label={t('setting_index.operationSettings.generalSettings.approximateToken')}
control={
Expand Down

0 comments on commit 0b707f9

Please sign in to comment.