Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ratio #2038

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion common/ctxkey/key.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package ctxkey

import "github.com/gin-gonic/gin"

const (
Config = "config"
Id = "id"
Expand All @@ -19,6 +21,6 @@ const (
TokenName = "token_name"
BaseURL = "base_url"
AvailableModels = "available_models"
KeyRequestBody = "key_request_body"
KeyRequestBody = gin.BodyBytesKey
SystemPrompt = "system_prompt"
)
7 changes: 4 additions & 3 deletions common/gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ package common
import (
"bytes"
"encoding/json"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
"io"
"strings"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/ctxkey"
)

func GetRequestBody(c *gin.Context) ([]byte, error) {
Expand All @@ -31,7 +32,6 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
contentType := c.Request.Header.Get("Content-Type")
if strings.HasPrefix(contentType, "application/json") {
err = json.Unmarshal(requestBody, &v)
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
} else {
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
err = c.ShouldBind(&v)
Expand All @@ -40,6 +40,7 @@ func UnmarshalBodyReusable(c *gin.Context, v any) error {
return err
}
// Reset request body
c.Request.Body = io.NopCloser(bytes.NewBuffer(requestBody))
return nil
}

Expand Down
4 changes: 3 additions & 1 deletion controller/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"github.com/songquanpeng/one-api/common/ctxkey"
"github.com/songquanpeng/one-api/model"
relay "github.com/songquanpeng/one-api/relay"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/apitype"
"github.com/songquanpeng/one-api/relay/channeltype"
Expand Down Expand Up @@ -86,7 +87,8 @@ func init() {
if channelType == channeltype.Azure {
continue
}
channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
channelName, channelRatioMap := openai.GetCompatibleChannelMeta(channelType)
channelModelList := adaptor.GetModelListHelper(channelRatioMap)
for _, modelName := range channelModelList {
models = append(models, OpenAIModels{
Id: modelName,
Expand Down
34 changes: 29 additions & 5 deletions model/option.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package model

import (
"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
"strconv"
"strings"
"time"

"github.com/songquanpeng/one-api/common/config"
"github.com/songquanpeng/one-api/common/logger"
billingratio "github.com/songquanpeng/one-api/relay/billing/ratio"
)

type Option struct {
Expand Down Expand Up @@ -70,6 +71,7 @@ func InitOptionMap() {
config.OptionMap["ModelRatio"] = billingratio.ModelRatio2JSONString()
config.OptionMap["GroupRatio"] = billingratio.GroupRatio2JSONString()
config.OptionMap["CompletionRatio"] = billingratio.CompletionRatio2JSONString()
config.OptionMap["Ratio"] = billingratio.Ratio2JSONString()
config.OptionMap["TopUpLink"] = config.TopUpLink
config.OptionMap["ChatLink"] = config.ChatLink
config.OptionMap["QuotaPerUnit"] = strconv.FormatFloat(config.QuotaPerUnit, 'f', -1, 64)
Expand All @@ -81,15 +83,35 @@ func InitOptionMap() {

// loadOptionsFromDatabase reads every stored option, applies it to the
// in-memory configuration via updateOptionMap, and then rebuilds the unified
// "Ratio" option from the deprecated "ModelRatio" and "CompletionRatio" values.
//
// The work is done in two passes because the merge needs both legacy values
// and they may appear at any position in the option list:
//
//  1. Every option is applied with updateOptionMap. The raw stored values of
//     "ModelRatio" and "CompletionRatio" are remembered; "ModelRatio" is passed
//     through billingratio.AddNewMissingRatio before it is applied.
//  2. Each stored "Ratio" option gets its value replaced with
//     billingratio.AddOldRatio(oldModelRatio, oldCompletionRatio), is applied
//     with updateOptionMap, and is written back with UpdateOption.
//
// Note that the merge in pass 2 only happens when a "Ratio" option is already
// present in the loaded list; nothing is created here.
//
// Every failure is logged and the load continues (best effort).
func loadOptionsFromDatabase() {
	options, err := AllOption()
	if err != nil {
		// Continue with whatever was returned, as before, but do not hide the failure.
		logger.SysError("failed to load options from database: " + err.Error())
	}

	// Pass 1: apply all options and capture the legacy ratio strings.
	var oldModelRatio, oldCompletionRatio string
	for _, option := range options {
		switch option.Key {
		case "ModelRatio":
			// Keep the value exactly as stored for the merge; apply the augmented one.
			oldModelRatio = option.Value
			option.Value = billingratio.AddNewMissingRatio(option.Value)
		case "CompletionRatio":
			oldCompletionRatio = option.Value
		}
		if err := updateOptionMap(option.Key, option.Value); err != nil {
			logger.SysError("failed to update option map: " + err.Error())
		}
	}

	// Pass 2: fold the legacy ratios into the unified "Ratio" option.
	for _, option := range options {
		if option.Key != "Ratio" {
			continue
		}
		option.Value = billingratio.AddOldRatio(oldModelRatio, oldCompletionRatio)
		if err := updateOptionMap(option.Key, option.Value); err != nil {
			logger.SysError("failed to update option map: " + err.Error())
		}
		// Persisting is a different operation from updating the in-memory map,
		// so it gets its own message to keep the two failures distinguishable.
		if err := UpdateOption(option.Key, option.Value); err != nil {
			logger.SysError("failed to persist merged ratio option: " + err.Error())
		}
		logger.SysLog("ratio merged")
	}
}

func SyncOptions(frequency int) {
Expand Down Expand Up @@ -223,12 +245,14 @@ func updateOptionMap(key string, value string) (err error) {
config.PreConsumedQuota, _ = strconv.ParseInt(value, 10, 64)
case "RetryTimes":
config.RetryTimes, _ = strconv.Atoi(value)
case "ModelRatio":
case "ModelRatio": // Deprecated
err = billingratio.UpdateModelRatioByJSONString(value)
case "GroupRatio":
err = billingratio.UpdateGroupRatioByJSONString(value)
case "CompletionRatio":
case "CompletionRatio": // Deprecated
err = billingratio.UpdateCompletionRatioByJSONString(value)
case "Ratio":
err = billingratio.UpdateRatioByJSONString(value)
case "TopUpLink":
config.TopUpLink = value
case "ChatLink":
Expand Down
12 changes: 7 additions & 5 deletions relay/adaptor/ai360/constants.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package ai360

var ModelList = []string{
"360GPT_S2_V9",
"embedding-bert-512-v1",
"embedding_s1_v1",
"semantic_similarity_s1_v1",
import "github.com/songquanpeng/one-api/relay/billing/ratio"

// RatioMap maps each model name handled by package ai360 to its ratio.Ratio,
// which carries separate Input and Output values.
// Each non-zero value is written as a numeric literal multiplied by ratio.RMB;
// the three embedding/similarity entries set Output to 0.
// NOTE(review): the literals look like RMB prices per some token unit —
// confirm the unit against the ratio package.
var RatioMap = map[string]ratio.Ratio{
"360GPT_S2_V9": {Input: 0.012 * ratio.RMB, Output: 0.012 * ratio.RMB},
"embedding-bert-512-v1": {Input: 0.0001 * ratio.RMB, Output: 0},
"embedding_s1_v1": {Input: 0.0001 * ratio.RMB, Output: 0},
"semantic_similarity_s1_v1": {Input: 0.0001 * ratio.RMB, Output: 0},
}
12 changes: 9 additions & 3 deletions relay/adaptor/aiproxy/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@ package aiproxy
import (
"errors"
"fmt"
"io"
"net/http"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"io"
"net/http"
)

type Adaptor struct {
Expand Down Expand Up @@ -58,8 +60,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
return
}

// GetRatio returns whatever adaptor.GetRatioHelper yields for the given
// request meta and this package's RatioMap.
// NOTE(review): presumably the lookup is keyed by the model named in meta —
// confirm against GetRatioHelper.
func (a *Adaptor) GetRatio(meta *meta.Meta) *ratio.Ratio {
return adaptor.GetRatioHelper(meta, RatioMap)
}

// GetModelList returns this adaptor's model list as a []string.
// NOTE(review): two return statements appear back to back below. As written
// the first one wins and the second is unreachable; this looks like an old
// line shown together with its replacement — confirm which one is intended.
func (a *Adaptor) GetModelList() []string {
return ModelList
// Unreachable as written: derives the list from RatioMap via adaptor.GetModelListHelper.
return adaptor.GetModelListHelper(RatioMap)
}

func (a *Adaptor) GetChannelName() string {
Expand Down
6 changes: 1 addition & 5 deletions relay/adaptor/aiproxy/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,4 @@ package aiproxy

import "github.com/songquanpeng/one-api/relay/adaptor/openai"

// ModelList starts out as a one-element slice holding an empty string and is
// overwritten in init.
var ModelList = []string{""}

// init points ModelList at openai.ModelList.
func init() {
ModelList = openai.ModelList
}
// RatioMap is bound to openai.RatioMap at package initialization.
// NOTE(review): if openai.RatioMap is a map, this shares the same underlying
// map rather than copying it — confirm that sharing is intended.
var RatioMap = openai.RatioMap
12 changes: 9 additions & 3 deletions relay/adaptor/ali/adaptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package ali
import (
"errors"
"fmt"
"io"
"net/http"

"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/billing/ratio"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"net/http"
)

// https://help.aliyun.com/zh/dashscope/developer-reference/api-details
Expand Down Expand Up @@ -96,8 +98,12 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *meta.Met
return
}

// GetRatio returns whatever adaptor.GetRatioHelper yields for the given
// request meta and this package's RatioMap.
// NOTE(review): presumably the lookup is keyed by the model named in meta —
// confirm against GetRatioHelper.
func (a *Adaptor) GetRatio(meta *meta.Meta) *ratio.Ratio {
return adaptor.GetRatioHelper(meta, RatioMap)
}

// GetModelList returns this adaptor's model list as a []string.
// NOTE(review): two return statements appear back to back below. As written
// the first one wins and the second is unreachable; this looks like an old
// line shown together with its replacement — confirm which one is intended.
func (a *Adaptor) GetModelList() []string {
return ModelList
// Unreachable as written: derives the list from RatioMap via adaptor.GetModelListHelper.
return adaptor.GetModelListHelper(RatioMap)
}

func (a *Adaptor) GetChannelName() string {
Expand Down
Loading