diff --git a/custom/conf/app.ini.sample b/custom/conf/app.ini.sample old mode 100755 new mode 100644 index 5cf6188609..99ad8b4c7a --- a/custom/conf/app.ini.sample +++ b/custom/conf/app.ini.sample @@ -394,12 +394,12 @@ MAX_OPEN_CONNS = 0 DB_TYPE = postgres HOST = 127.0.0.1:5432 NAME = statistic -USER = -PASSWD = +USER = +PASSWD = SCHEMA = SSL_MODE = disable CHARSET = utf8 -PATH = +PATH = [indexer] ; Issue indexer type, currently support: bleve, db or elasticsearch, default is bleve @@ -1150,5 +1150,5 @@ growth_comments=0.2 [grampus] USERNAME = -PASSWORD = -SERVER_HOST = +PASSWORD = +SERVER_HOST = diff --git a/entity/ai_task.go b/entity/ai_task.go index 5e4a2d8f0c..da78e29abf 100644 --- a/entity/ai_task.go +++ b/entity/ai_task.go @@ -143,6 +143,7 @@ type AITaskDetailInfo struct { FailedReason string `json:"failed_reason"` UserId int64 `json:"-"` AppName string `json:"app_name"` + HasInternet int `json:"has_internet"` } func (a *AITaskDetailInfo) Tr(language string) { diff --git a/models/cloudbrain.go b/models/cloudbrain.go index aecc026c58..46669d1e1c 100755 --- a/models/cloudbrain.go +++ b/models/cloudbrain.go @@ -310,6 +310,7 @@ type Cloudbrain struct { Spec *Specification `xorm:"-"` Config *CloudbrainConfig `xorm:"-"` AppName string //超算任务的应用类型 + HasInternet int } type CloudbrainShow struct { @@ -3452,6 +3453,34 @@ func GetCloudbrainByIDs(ids []int64) ([]*Cloudbrain, error) { Find(&cloudbrains) } +type CountPerUserID struct { + Count int64 + UserID int64 +} + +func GetNotebookCountGreaterThanN(n int) ([]CountPerUserID, error) { + cpuis := []CountPerUserID{} + err := x. + Table("cloudbrain"). + GroupBy("user_id").Having("count(*)>"+strconv.Itoa(n)). + Select("user_id, count(*) AS count"). + Where("job_type=? and (deleted_at=? or deleted_at is NULL)", "DEBUG", "0001-01-01 00:00:00").OrderBy("count(*) desc"). + Find(&cpuis) + return cpuis, err + +} +func GetNotebooksByUser(uid int64, offset int) ([]int64, error) { + var ints []int64 + err := x.Table("cloudbrain").Cols("id").Where("job_type=? and user_id=? and (deleted_at=? or deleted_at is NULL)", "DEBUG", uid, "0001-01-01 00:00:00").Desc("id").Limit(1000, offset).Find(&ints) + return ints, err +} + +func GetNotebooksCountByUser(uid int64) (int64, error) { + cloudbrain := new(Cloudbrain) + return x.Where("user_id=? and job_type=?", uid, "DEBUG").Count(cloudbrain) + +} + func GetCloudbrainWithDeletedByIDs(ids []int64) ([]*Cloudbrain, error) { cloudbrains := make([]*Cloudbrain, 0) return cloudbrains, x. diff --git a/models/file_chunk.go b/models/file_chunk.go index 642872a3c2..7aca886c80 100755 --- a/models/file_chunk.go +++ b/models/file_chunk.go @@ -145,6 +145,26 @@ func getModelFileChunkByUUID(e Engine, uuid string) (*ModelFileChunk, error) { return fileChunk, nil } +func GetModelFileChunksByUserId(userId int64, lastTime int64, isUploadFinished bool) ([]*ModelFileChunk, error) { + return getModelFileChunksByUserId(x, userId, lastTime, isUploadFinished) +} + +func getModelFileChunksByUserId(e Engine, userId int64, lastTime int64, isUploadFinished bool) ([]*ModelFileChunk, error) { + fileChunks := make([]*ModelFileChunk, 0) + cond := builder.NewCond() + cond = cond.And(builder.Eq{"user_id": userId}) + if lastTime > 0 { + cond = cond.And(builder.Gte{"created_unix": lastTime}) + } + if !isUploadFinished { + cond = cond.And(builder.Eq{"is_uploaded": 0}) + } + if err := e.Where(cond).Find(&fileChunks); err != nil { + return nil, err + } + return fileChunks, nil +} + // InsertFileChunk insert a record into file_chunk. 
func InsertFileChunk(fileChunk *FileChunk) (_ *FileChunk, err error) {
	if _, err := x.Insert(fileChunk); err != nil {
diff --git a/models/llm_chat.go b/models/llm_chat.go
new file mode 100644
index 0000000000..1f3caf13e4
--- /dev/null
+++ b/models/llm_chat.go
@@ -0,0 +1,94 @@
+package models
+
+import (
+	"code.gitea.io/gitea/modules/log"
+	"code.gitea.io/gitea/modules/timeutil"
+)
+
+type LlmChat struct {
+	ID                string `xorm:"pk"`
+	UserId            int64  `xorm:"INDEX"`
+	Count             int
+	Prompt            string `xorm:"text"`
+	Answer            string `xorm:"text"`
+	InvalidCount      int
+	InvalidType       string
+	InvalidTool       string
+	InvalidDetail     string `xorm:"text"`
+	ChatStatus        int
+	ModelName         string
+	ChatType          string
+	ChatId            string
+	KnowledgeBaseName string
+	VectorStoreType   string
+	EmbeddingModel    string
+	Endpoint          string
+	CreatedUnix       timeutil.TimeStamp `xorm:"created"`
+	UpdatedUnix       timeutil.TimeStamp `xorm:"updated"`
+}
+
+func SaveChat(llmChat *LlmChat) error {
+	sess := xStatistic.NewSession()
+	defer sess.Close()
+	re, err := sess.Insert(llmChat)
+	if err != nil {
+		log.Error("insert llmChat error %s", err.Error())
+		return err
+	}
+	log.Info("success to save llmChat db.re=%d", re)
+	return nil
+}
+
+func QueryChatCount(userId int64, modelName string) int64 {
+	sess := xStatistic.NewSession()
+	defer sess.Close()
+	query := "SELECT SUM(count) AS count FROM public.llm_chat WHERE chat_status = 1 AND user_id = ? AND model_name = ?"
+	sumList, err := sess.QueryInterface(query, userId, modelName)
+	if err == nil {
+		if len(sumList) == 1 {
+			return convertInterfaceToInt64(sumList[0]["count"])
+		}
+	}
+	return 0
+}
+
+func QueryInvalidPromptCount(userId int64) int64 {
+	sess := xStatistic.NewSession()
+	defer sess.Close()
+	query := "SELECT SUM(invalid_count) AS count FROM public.llm_chat WHERE invalid_type='prompt' and user_id = ?"
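+	// Note: SUM() returns SQL NULL when the user has no matching rows; convertInterfaceToInt64 (a models helper defined outside this diff) is expected to map that to 0.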
+ sumList, err := sess.QueryInterface(query, userId) + if err == nil { + if len(sumList) == 1 { + return convertInterfaceToInt64(sumList[0]["count"]) + } + } + return 0 +} + +func QueryChatStatistics() ([]map[string]interface{}, error) { + sess := xStatistic.NewSession() + query := ` + SELECT + COALESCE(model_name, 'total') as model_name, + COUNT(DISTINCT id) as chat, + COUNT(DISTINCT user_id) as chat_user, + COUNT(DISTINCT id) / COUNT(DISTINCT user_id) as chat_per_user_avg, + COUNT(DISTINCT CASE WHEN chat_type = 'llm' THEN id END) as chat_llm, + COUNT(DISTINCT CASE WHEN chat_type = 'llm' THEN user_id END) as chat_llm_user, + COUNT(DISTINCT CASE WHEN chat_type = 'kb' THEN id END) as chat_kb, + COUNT(DISTINCT CASE WHEN chat_type = 'kb' THEN user_id END) as chat_kb_user, + COALESCE(SUM(invalid_count),0) as chat_illegal, + COALESCE(SUM(CASE WHEN invalid_type = 'prompt' THEN 1 END),0) as chat_illegal_prompt, + COALESCE(SUM(CASE WHEN invalid_type = 'answer' THEN 1 END),0) as chat_illegal_answer + FROM + llm_chat + GROUP BY + ROLLUP(model_name)` + + results, err := sess.SQL(query).QueryInterface() + if err != nil { + return nil, err + } + + return results, nil +} diff --git a/models/llm_chat_visit.go b/models/llm_chat_visit.go new file mode 100644 index 0000000000..2c35f02d1b --- /dev/null +++ b/models/llm_chat_visit.go @@ -0,0 +1,97 @@ +package models + +import ( + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/timeutil" + "fmt" +) + +type LlmChatVisit struct { + ID int64 `xorm:"pk autoincr"` + UserId int64 `xorm:"INDEX"` + ChatId string + ModelName string + Agreement int + ExpiredTime string + ExpiredUnix int64 + CreatedUnix timeutil.TimeStamp `xorm:"created"` +} + +func SaveVisit(llmChatVisit *LlmChatVisit) error { + sess := xStatistic.NewSession() + defer sess.Close() + re, err := sess.Insert(llmChatVisit) + if err != nil { + log.Info("insert llmChatVisit error %s\n", err.Error()) + return err + } + log.Info("success to save llmChatVisit db.re=%+v\n", fmt.Sprint(re)) + return nil +} + +func QueryFirstVisit(userId int64) int64 { + sess := xStatistic.NewSession() + defer sess.Close() + query := "SELECT SUM(agreement) AS count FROM public.llm_chat_visit WHERE user_id = ?" + sumList, err := sess.QueryInterface(query, userId) + if err == nil { + if len(sumList) == 1 { + val := convertInterfaceToInt64(sumList[0]["count"]) + return val + } + } + return 0 +} + +func QueryRunningChat(userId int64, modelName string, currentTime int64) (*LlmChatVisit, error) { + sess := xStatistic.NewSession() + defer sess.Close() + re := new(LlmChatVisit) + isExist, err := sess.Table(new(LlmChatVisit)).Where("user_id = ? AND model_name = ? AND ? > created_unix AND ? 
< expired_unix", userId, modelName, currentTime, currentTime).Get(re) + if err == nil && isExist { + return re, nil + } + return nil, err +} + +func QueryByChatId(chatId string) (*LlmChatVisit, error) { + sess := xStatistic.NewSession() + defer sess.Close() + re := new(LlmChatVisit) + isExist, err := sess.Table(new(LlmChatVisit)).Where("chat_id = ?", chatId).Get(re) + if err == nil && isExist { + return re, nil + } + return nil, err +} + +func UpdateChat(llmChatVisit *LlmChatVisit) error { + sess := xStatistic.ID(llmChatVisit.ID) + defer sess.Close() + re, err := sess.Cols("agreement").Update(llmChatVisit) + if err != nil { + return err + } + log.Info("update llmChatVisit db.re=" + fmt.Sprint(re)) + return nil +} + +func QueryChatVisitStatistics() ([]map[string]interface{}, error) { + sess := xStatistic.NewSession() + query := ` + SELECT + COALESCE(model_name, 'total') as model_name, + COUNT(DISTINCT chat_id) AS visit, + COUNT(DISTINCT user_id) AS visit_user + FROM + llm_chat_visit + GROUP BY + ROLLUP(model_name)` + + results, err := sess.SQL(query).QueryInterface() + if err != nil { + return nil, err + } + + return results, nil +} diff --git a/models/models.go b/models/models.go index df9265decb..29f75e12f8 100755 --- a/models/models.go +++ b/models/models.go @@ -198,6 +198,8 @@ func init() { new(CloudbrainDurationStatistic), new(UserSummaryCurrentYear), new(ModelApp), + new(LlmChat), + new(LlmChatVisit), ) gonicNames := []string{"SSL", "UID"} diff --git a/models/resource_queue.go b/models/resource_queue.go index 7da02bedec..6a97fddd6d 100644 --- a/models/resource_queue.go +++ b/models/resource_queue.go @@ -100,11 +100,13 @@ type ResourceQueueListRes struct { } type ResourceQueueCodesRes struct { - ID int64 - QueueCode string - Cluster string - AiCenterCode string - AiCenterName string + ID int64 + QueueCode string + Cluster string + AiCenterCode string + AiCenterName string + ComputeResource string + AccCardType string } func (ResourceQueueCodesRes) TableName() string { diff --git a/modules/baiduai/baiduai.go b/modules/baiduai/baiduai.go new file mode 100644 index 0000000000..803ee4d40e --- /dev/null +++ b/modules/baiduai/baiduai.go @@ -0,0 +1,32 @@ +package baiduai + +type LegalTextResponse struct { + Conclusion string `json:"conclusion"` + LogId string `json:"log_id"` + IsHitMd5 bool `json:"isHitMd5"` + ConclusionType int `json:"conclusionType"` + Data []Data `json:"data"` +} + +type Data struct { + Msg string `json:"msg"` + Conclusion string `json:"conclusion"` + SubType int `json:"subType"` + ConclusionType int `json:"conclusionType"` + Type int `json:"type"` + Hits []Hit `json:"hits"` +} + +type Hit struct { + Probability int `json:"probability"` + DatasetName string `json:"datasetName"` + Words []string `json:"words"` + ModelHitPositions [][]float64 `json:"modelHitPositions"` + WordHitPositions []WordHitPosition `json:"wordHitPositions"` +} + +type WordHitPosition struct { + Positions [][]int `json:"positions"` + Label string `json:"label"` + Keyword string `json:"keyword"` +} diff --git a/modules/baiduai/resty.go b/modules/baiduai/resty.go new file mode 100644 index 0000000000..5d04c4f5c9 --- /dev/null +++ b/modules/baiduai/resty.go @@ -0,0 +1,72 @@ +package baiduai + +import ( + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "strings" +) + +/** + * 使用 AK,SK 生成鉴权签名(Access Token) + * @return string 鉴权签名信息(Access Token) + */ +func getAccessToken() string { + postData := 
fmt.Sprintf("grant_type=client_credentials&client_id=%s&client_secret=%s", setting.BAIDU_AI.API_KEY, setting.BAIDU_AI.SECRET_KEY) + resp, err := http.Post(setting.BAIDU_AI.URL, "application/x-www-form-urlencoded", strings.NewReader(postData)) + if err != nil { + fmt.Println(err) + return "" + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + fmt.Println(err) + return "" + } + accessTokenObj := map[string]string{} + json.Unmarshal([]byte(body), &accessTokenObj) + return accessTokenObj["access_token"] +} + +/** + * 百度api文本内容合规检测 + */ +func CheckLegalText(query string) (*LegalTextResponse, error) { + var result LegalTextResponse + + url := setting.BAIDU_AI.LEGAL_TEXT_URL + getAccessToken() + payload := strings.NewReader("text=" + query) + log.Info("resty CheckLegalText() payload %+v", payload) + client := &http.Client{} + + req, err := http.NewRequest("POST", url, payload) + if err != nil { + log.Error("resty CheckLegalText() Request error: %s", err.Error()) + return &result, fmt.Errorf("resty CheckLegalText(): %s", err) + } + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Add("Accept", "application/json") + + res, err := client.Do(req) + if err != nil { + log.Error("resty CheckLegalText() Response error: %s", err.Error()) + return &result, fmt.Errorf("resty CheckLegalText(): %s", err) + } + defer res.Body.Close() + log.Error("resty CheckLegalText() Response status: %s\n", res.Status) + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + log.Error("resty CheckLegalText() Response body error: %s", err.Error()) + return &result, fmt.Errorf("resty CheckLegalText(): %s", err) + } + + response := string(body) + json.Unmarshal([]byte(response), &result) + log.Info("resty CheckLegalText() results: %+v\n", result) + return &result, nil +} diff --git a/modules/cron/tasks_basic.go b/modules/cron/tasks_basic.go index b7c0e21b76..39d6c21a72 100755 --- a/modules/cron/tasks_basic.go +++ b/modules/cron/tasks_basic.go @@ -209,6 +209,17 @@ func registerHandleClearCloudbrainResult() { }) } +func registerHandleClearNotebook() { + RegisterTaskFatal("handle_notebook_clear", &BaseConfig{ + Enabled: true, + RunAtStart: setting.NotebookStrategy.RunAtStart, + Schedule: setting.NotebookStrategy.Cron, + }, func(ctx context.Context, _ *models.User, _ Config) error { + task.ClearNotebook() + return nil + }) +} + func registerHandleSummaryStatistic() { RegisterTaskFatal("handle_summary_statistic", &BaseConfig{ Enabled: true, @@ -379,6 +390,7 @@ func initBasicTasks() { registerHandleRepoAndUserStatistic() registerHandleSummaryStatistic() registerHandleClearCloudbrainResult() + registerHandleClearNotebook() registerSyncCloudbrainStatus() registerHandleOrgStatistic() diff --git a/modules/llm_chat/llm_chat.go b/modules/llm_chat/llm_chat.go new file mode 100644 index 0000000000..9c9fbdc539 --- /dev/null +++ b/modules/llm_chat/llm_chat.go @@ -0,0 +1,56 @@ +package llm_chat + +type LLMChatResponse struct { + Answer string `json:"answer"` +} + +type KBChatResponse struct { + Answer string `json:"answer"` + Docs []string `json:"docs"` +} + +type SearchDocResponse struct { + Results []SearchDocResult `json:"results"` +} + +type SearchDocResult struct { + PageContent string `json:"page_content"` + Metadata struct { + } `json:"metadata"` + Score float64 `json:"score"` +} + +type RecreateVectorStoreResponse struct { + Code int `json:"code"` + Msg string `json:"msg"` + Total int `json:"total"` + Finished int `json:"finished"` + Doc string `json:"doc"` +} + 
+type LLMBasicMsgWithData struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data []string `json:"data"` +} + +type LLMDeleteDocMsg struct { + Code int `json:"code"` + Msg string `json:"msg"` + Data struct { + FailedFiles map[string]string `json:"failed_files"` + } `json:"data"` +} + +type LLMBasicMsg struct { + Code int `json:"code"` + Msg string `json:"msg"` +} + +type LLMErrorMsg struct { + Detail []struct { + Loc []interface{} `json:"loc"` + Msg string `json:"msg"` + Type string `json:"type"` + } `json:"detail"` +} diff --git a/modules/llm_chat/resty.go b/modules/llm_chat/resty.go new file mode 100644 index 0000000000..013f7babe6 --- /dev/null +++ b/modules/llm_chat/resty.go @@ -0,0 +1,578 @@ +package llm_chat + +import ( + "bufio" + "bytes" + "code.gitea.io/gitea/modules/log" + constants "code.gitea.io/gitea/modules/setting" + api "code.gitea.io/gitea/modules/structs" + "crypto/tls" + "encoding/json" + "fmt" + "github.com/go-resty/resty/v2" + "io" + "mime/multipart" + "net/http" + "unicode/utf8" +) + +var ( + restyClient *resty.Client +) + +const ( + urlLLMChat = "/chat/chat" + urlKnowledgeBaseChat = "/chat/knowledge_base_chat" + urlKnowledgeBaseList = "/knowledge_base/list_knowledge_bases" + urlKnowledgeBaseCreate = "/knowledge_base/create_knowledge_base" + urlKnowledgeBaseDelete = "/knowledge_base/delete_knowledge_base" + urlKnowledgeBaseListFiles = "/knowledge_base/list_files" + urlKnowledgeBaseSearchDoc = "/knowledge_base/search_docs" + urlKnowledgeBaseUploadDoc = "/knowledge_base/upload_docs" + urlKnowledgeBaseDeleteDoc = "/knowledge_base/delete_docs" + urlKnowledgeBaseUpdateDoc = "/knowledge_base/update_docs" + urlKnowledgeBaseDownload = "/knowledge_base/download_doc" + urlKnowledgeBaseRecreate = "/knowledge_base/recreate_vector_store" +) + +func GetEndpoint(modelName string) string { + //endpoint := constants.LLM_CHAT_API.HOST + ":" + var endpoint string + switch modelName { + case "chatglm2-6b": + endpoint = constants.LLM_CHAT_API.CHATGLM2_HOST + case "llama2-7b-chat-hf": + endpoint = constants.LLM_CHAT_API.LLAMA2_HOST + default: + endpoint = constants.LLM_CHAT_API.CHATGLM2_HOST + } + + return endpoint +} + +func getRestyClient() *resty.Client { + if restyClient == nil { + restyClient = resty.New() + restyClient.SetTLSClientConfig(&tls.Config{InsecureSkipVerify: true}) + } + return restyClient +} + +// Custom function to split by character +func scanByCharacter(data []byte, atEOF bool) (int, []byte, error) { + if len(data) == 0 { + return 0, nil, nil + } + return 1, data[:1], nil +} + +func StreamLLMChat(params api.LLMChatMessage, resultChan chan string, errChan chan error, done chan struct{}) { + client := http.Client{} + endpoint := GetEndpoint(params.ModelName) + requestBody, _ := json.Marshal(params) + log.Info("Request body: %s\n", requestBody) + request, err := http.NewRequest("POST", endpoint+urlLLMChat, bytes.NewBuffer(requestBody)) + if err != nil { + log.Error("Error creating request: %v", err) + errChan <- err + return + } + request.Header.Set("Content-Type", "application/json") + resp, err := client.Do(request) + if err != nil { + log.Error("Error sending request: %v", err) + errChan <- err + return + } + defer resp.Body.Close() + log.Info("Response status: %s\n", resp.Status) + + scanner := bufio.NewScanner(resp.Body) + scanner.Split(scanByCharacter) + var invalidCharBuffer string + for scanner.Scan() { + char := scanner.Text() + if len(invalidCharBuffer) > 0 { + char = invalidCharBuffer + char + invalidCharBuffer = "" + } + if utf8.ValidString(char) { 
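+			// char now holds a complete UTF-8 rune. Byte-at-a-time scanning can split multi-byte runes, so bytes that do not yet form a valid sequence are buffered in invalidCharBuffer below and retried with the next read.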
+ //runes := []rune(char) + //log.Info("%s -> %U \n", char, runes[0]) + resultChan <- char + } else { + invalidCharBuffer += char + } + } + if len(invalidCharBuffer) > 0 { + log.Info("Unprocessed invalid UTF-8 characters: %s\n", invalidCharBuffer) + } + close(done) + + if scanner.Err() != nil { + errChan <- scanner.Err() + } +} + +func StreamKBChat(params api.KBChatMessage, resultChan chan string, errChan chan error, done chan struct{}) { + client := http.Client{} + endpoint := GetEndpoint(params.ModelName) + requestBody, _ := json.Marshal(params) + log.Info("Request body: %s\n", requestBody) + request, err := http.NewRequest("POST", endpoint+urlKnowledgeBaseChat, bytes.NewBuffer(requestBody)) + if err != nil { + log.Error("Error creating request: %v", err) + errChan <- err + return + } + request.Header.Set("Content-Type", "application/json") + resp, err := client.Do(request) + if err != nil { + log.Error("Error sending request: %v", err) + errChan <- err + return + } + defer resp.Body.Close() + log.Info("Response status: %s\n", resp.Status) + + //Create a buffer to read 2048-byte blocks + buffer := make([]byte, 4096) + for { + // Read a 4096-byte block from the response body + n, err := resp.Body.Read(buffer) + if err != nil { + if err != io.EOF { + errChan <- err + } + break + } + resultChan <- string(buffer[:n]) + } + close(done) +} + +func SendLLMChat(params api.LLMChatMessage) (*LLMChatResponse, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint(params.ModelName) + + request, _ := json.Marshal(params) + log.Info("resty request body: %s", request) + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetBody(params). + Post(endpoint + urlLLMChat) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + + result := LLMChatResponse{ + Answer: res.String(), + } + log.Info("resty response: %+v", result) + + if err != nil { + return &result, fmt.Errorf("resty SendLLMChat(): %s", err) + } + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &result, nil +} + +func SendKBChat(params api.KBChatMessage) (*KBChatResponse, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint(params.ModelName) + var result KBChatResponse + + request, _ := json.Marshal(params) + log.Info("resty request body: %s", request) + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetBody(params). + Post(endpoint + urlKnowledgeBaseChat) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + + response := res.String() + json.Unmarshal([]byte(response), &result) + log.Info("resty response: %+v", result) + + if err != nil { + return &result, fmt.Errorf("resty SendLLMChat(): %s", err) + } + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &result, nil +} + +func ListKnowledgeBase() (*LLMBasicMsgWithData, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint("") + var result LLMBasicMsgWithData + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetResult(&result). 
+ Get(endpoint + urlKnowledgeBaseList) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + + response, _ := json.Marshal(result) + log.Info("resty response: %s", response) + + if err != nil { + return &result, fmt.Errorf("resty ListKnowledgeBase(): %s", err) + } + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &result, nil +} + +func CreateKnowledgeBase(params api.CreateKnowledgeBaseParams) (*LLMBasicMsg, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint("") + var result LLMBasicMsg + + request, _ := json.Marshal(params) + log.Info("resty request body: %s", request) + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetBody(params). + SetResult(&result). + Post(endpoint + urlKnowledgeBaseCreate) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + + response, _ := json.Marshal(result) + log.Info("resty response: %s", response) + + if err != nil { + return &result, fmt.Errorf("resty CreateKnowledgeBase(): %s", err) + } + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &result, nil +} + +func DeleteKnowledgeBase(knowledgeBaseName string) (*LLMBasicMsgWithData, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint("") + var result LLMBasicMsgWithData + + log.Info("resty request body: %s", knowledgeBaseName) + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/text"). + SetBody(knowledgeBaseName). + SetResult(&result). + Post(endpoint + urlKnowledgeBaseDelete) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + + response, _ := json.Marshal(result) + log.Info("resty response: %s", response) + + if err != nil { + log.Error("resty DeleteKnowledgeBase(): %s", err) + return &result, fmt.Errorf("resty DeleteKnowledgeBase(): %s", err) + } + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &result, nil +} + +func KBListFiles(knowledgeBaseName string) (*LLMBasicMsgWithData, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint("") + var result LLMBasicMsgWithData + + log.Info("resty request body: %s", knowledgeBaseName) + +sendjob: + res, err := client.R(). + SetQueryParams(map[string]string{ + "knowledge_base_name": knowledgeBaseName, + }). + SetHeader("Content-Type", "application/text"). + SetBody(knowledgeBaseName). + SetResult(&result). + Get(endpoint + urlKnowledgeBaseListFiles) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + + response, _ := json.Marshal(result) + log.Info("resty response: %s", response) + + if err != nil { + log.Error("resty KBListFiles(): %s", err) + return &result, fmt.Errorf("resty KBListFiles(): %s", err) + } + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &result, nil +} + +func KBSearchDoc(params api.SearchDocParams) (*SearchDocResponse, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint("") + var result []SearchDocResult + + log.Info("resty request body: %+v", params) + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetBody(params). + SetResult(&result). 
+ Post(endpoint + urlKnowledgeBaseSearchDoc) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + + resultAPI := SearchDocResponse{ + Results: result, + } + response, _ := json.Marshal(resultAPI) + log.Info("resty response: %s", response) + + if err != nil { + log.Error("resty KBListFiles(): %s", err) + return &resultAPI, fmt.Errorf("resty KBListFiles(): %s", err) + } + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &resultAPI, nil +} + +func KBDeleteDoc(params api.DeleteDocParams) (interface{}, error) { // *LLMDeleteDocMsg, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint("") + var result LLMDeleteDocMsg + + request, _ := json.Marshal(params) + log.Info("resty request body: %s", request) + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetBody(params). + SetResult(&result). + Post(endpoint + urlKnowledgeBaseDeleteDoc) + + log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL) + if err != nil { + var errResult LLMErrorMsg + json.Unmarshal([]byte(res.String()), &errResult) + return &errResult, fmt.Errorf("resty KBDeleteDoc(): %s", err) + } + + response, _ := json.Marshal(result) + log.Info("resty response: %s", response) + + if res.StatusCode() == http.StatusUnauthorized && retry < 1 { + retry++ + goto sendjob + } + return &result, nil +} + +func KBUpdateDoc(params api.UpdateDocParams) (*LLMBasicMsg, error) { + client := getRestyClient() + retry := 0 + endpoint := GetEndpoint("") + var result LLMBasicMsg + + request, _ := json.Marshal(params) + log.Info("resty request body: %s", request) + +sendjob: + res, err := client.R(). + SetHeader("Content-Type", "application/json"). + SetBody(params). + SetResult(&result). 
+		Post(endpoint + urlKnowledgeBaseUpdateDoc)
+
+	log.Info("resty status: %+v, route: %+v", res.StatusCode(), res.Request.URL)
+
+	response, _ := json.Marshal(result)
+	log.Info("resty response: %s", response)
+
+	if err != nil {
+		return &result, fmt.Errorf("resty KBUpdateDoc(): %s", err)
+	}
+
+	if res.StatusCode() == http.StatusUnauthorized && retry < 1 {
+		retry++
+		goto sendjob
+	}
+	return &result, nil
+}
+
+func KBRecreateVectorStore(params api.RecreateVectorStoreParams, resultChan chan string, errChan chan error, done chan struct{}) {
+	client := http.Client{}
+	endpoint := GetEndpoint("")
+	requestBody, _ := json.Marshal(params)
+	log.Info("Request body: %s\n", requestBody)
+	request, err := http.NewRequest("POST", endpoint+urlKnowledgeBaseRecreate, bytes.NewBuffer(requestBody))
+	if err != nil {
+		log.Error("Error creating request: %v", err)
+		errChan <- err
+		return
+	}
+
+	request.Header.Set("Content-Type", "application/json")
+	resp, err := client.Do(request)
+	if err != nil {
+		log.Error("Error sending request: %v", err)
+		errChan <- err
+		return
+	}
+	defer resp.Body.Close()
+	log.Info("Response status: %s\n", resp.Status)
+
+	// Create a buffer to read 4096-byte blocks
+	buffer := make([]byte, 4096)
+	for {
+		// Read a 4096-byte block from the response body
+		n, err := resp.Body.Read(buffer)
+		if err != nil {
+			if err != io.EOF {
+				errChan <- err
+			}
+			break
+		}
+		resultChan <- string(buffer[:n])
+	}
+	close(done)
+}
+
+func GetUploadDocUrl() (string, error) {
+	endpoint := GetEndpoint("") + urlKnowledgeBaseUploadDoc
+	log.Info("resty GetUploadDocUrl: %s", endpoint)
+	return endpoint, nil
+}
+
+func writeDocs(fileHeader *multipart.FileHeader, writer *multipart.Writer) error {
+	filename := fileHeader.Filename
+	file, err := fileHeader.Open()
+	if err != nil {
+		log.Error(err.Error())
+		return err
+	}
+	defer file.Close()
+	part, err := writer.CreateFormFile("files", filename)
+	if err != nil {
+		log.Error("Error creating form file: %v", err)
+		return err
+	}
+	_, err = io.Copy(part, file)
+	return err
+}
+
+func writeBody(requestBody *bytes.Buffer, form api.LLMChatUploadForm) (string, error) {
+	writer := multipart.NewWriter(requestBody)
+	defer writer.Close()
+	err := writer.WriteField("knowledge_base_name", form.KnowledgeBaseName)
+	if err != nil {
+		log.Error("failed to create upload_doc() writer")
+		return "", err
+	}
+	for _, fileHeader := range form.Files {
+		err = writeDocs(fileHeader, writer)
+		if err != nil {
+			log.Error("Error getting doc content: %s", err)
+			return "", err
+		}
+	}
+	return writer.FormDataContentType(), nil
+}
+
+func UploadDocs(modelName string, form api.LLMChatUploadForm) (*map[string]interface{}, error) {
+	log.Info("######### received by resty\n")
+
+	var requestBody bytes.Buffer
+	headerValue, err := writeBody(&requestBody, form)
+	if err != nil {
+		log.Error("upload docs write body failed.")
+		return nil, err
+	}
+
+	endpoint := GetEndpoint(modelName)
+	req, err := http.NewRequest("POST", endpoint+urlKnowledgeBaseUploadDoc, &requestBody)
+	if err != nil {
+		log.Info("Error creating request: %v", err)
+		return nil, err
+	}
+
+	req.Header.Set("Content-Type", headerValue)
+
+	client := &http.Client{}
+	resp, err := client.Do(req)
+	if err != nil {
+		log.Info("Error making request: %v", err)
+		return nil, err
+	}
+	defer resp.Body.Close()
+	log.Info("############## Response Status: %s", resp.Status)
+
+	var errResult map[string]interface{}
+	//if resp.StatusCode == http.StatusUnprocessableEntity {
+	if resp.StatusCode != http.StatusOK {
+		bodyBytes, err := io.ReadAll(resp.Body)
+		if err != nil {
+			log.Info("Error reading response body: %v", err)
+			return nil, err
+		}
+		err = json.Unmarshal(bodyBytes, &errResult)
+		log.Error("##############upload_docs() errResult: %+v\n", errResult)
+		return &errResult, nil
+	}
+	log.Info("############## Response Body: %+v\n", resp.Body)
+
+	// Parse the response
+	var result map[string]interface{}
+	err = json.NewDecoder(resp.Body).Decode(&result)
+	if err != nil {
+		log.Info("Error decoding response: %v", err)
+		return nil, err
+	}
+	return &result, nil
+}
+
+func GetDownloadDocUrl(knowledgeBaseName string, fileName string) (string, error) {
+	endpoint := GetEndpoint("") + urlKnowledgeBaseDownload
+	params := "?knowledge_base_name=" + knowledgeBaseName + "&file_name=" + fileName
+	log.Info("resty GetDownloadDocUrl: %s", endpoint+params)
+	return endpoint + params, nil
+}
diff --git a/modules/modelappservice/modelsevice.go b/modules/modelappservice/modelsevice.go
index 8b90be3499..7997b88da2 100644
--- a/modules/modelappservice/modelsevice.go
+++ b/modules/modelappservice/modelsevice.go
@@ -14,9 +14,15 @@ import (
)

var wenxinChannel = make(chan *models.ModelApp, 10000)
+var isCD bool

func Init() {
	urls := setting.BaiduWenXin.ModelArtsWenXinURL
+	if strings.Contains(urls, "cdzs.cn") {
+		isCD = true
+	} else {
+		isCD = false
+	}
	urlarray := strings.Split(urls, ",")
	urlNums := len(urlarray)
	log.Info("url nums=" + fmt.Sprint(urlNums))
@@ -45,7 +51,13 @@ func consumerOrder(in <-chan *models.ModelApp, url string) {
			continue
		}
		log.Info("goroutine id=" + fmt.Sprint(goroutine_id) + " wenxin text=" + modelApp.Desc)
-		result, err := modelarts.CreateWenXinJobToCD(modelApp, url)
+		var result *modelarts.WenXinResult
+		var err error
+		if isCD {
+			result, err = modelarts.CreateWenXinJobToCD(modelApp, url)
+		} else {
+			result, err = modelarts.CreateWenXinJob(modelApp, url)
+		}
		if err == nil {
			if !modelarts.SendPictureReivew(result.Result) {
				modelApp.Status = -1
diff --git a/modules/setting/setting.go b/modules/setting/setting.go
index d6d337a568..545997d299 100755
--- a/modules/setting/setting.go
+++ b/modules/setting/setting.go
@@ -652,6 +652,13 @@ var (
		Cron       string
		RunAtStart bool
	}{}
+	NotebookStrategy = struct {
+		ClearEnabled     bool
+		ClearBatchSize   int
+		MaxNumberPerUser int
+		Cron             string
+		RunAtStart       bool
+	}{}

	C2NetInfos  *C2NetSqInfos
	CenterInfos *AiCenterInfos
@@ -856,6 +863,25 @@ var (
		ATTACHEMENT_SIZE_A_USER int64 //G
		ALL_ATTACHEMENT_NUM_SDK int
	}{}
+
+	LLM_CHAT_API = struct {
+		CHATGLM2_HOST        string
+		CHATGLM2_MAX_LENGTH  int
+		LLAMA2_HOST          string
+		LLAMA2_MAX_LENGTH    int
+		COMMON_KB            string
+		MAX_FREE_TRIES       int64
+		LEGAL_CHECK          bool
+		LEGAL_MAX_COUNT      int64
+		CHAT_EXPIRED_MINUTES int64
+	}{}
+
+	BAIDU_AI = struct {
+		API_KEY        string
+		SECRET_KEY     string
+		URL            string
+		LEGAL_TEXT_URL string
+	}{}
)

// DateLang transforms standard language locale name to corresponding value in datetime plugin.
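The hunk below wires these new settings to three app.ini sections. For reference, a sketch of the matching config entries: hosts and secrets are placeholders, and the remaining values are the defaults encoded in the parsing code.

```ini
[llm_chat_api]
CHATGLM2_HOST =            ; placeholder, e.g. the chatglm2-6b service endpoint
CHATGLM2_MAX_LENGTH = 8192
LLAMA2_HOST =              ; placeholder, e.g. the llama2-7b-chat-hf service endpoint
LLAMA2_MAX_LENGTH = 4096
COMMON_KNOWLEDGE_BASE =
MAX_FREE_TRIES = 200
LEGAL_CHECK = false
LEGAL_MAX_COUNT = 5
CHAT_EXPIRED_MINUTES = 30

[baidu_ai]
API_KEY =
SECRET_KEY =
URL = https://aip.baidubce.com/oauth/2.0/token
LEGAL_TEXT_URL = https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=

[notebook_strategy]
CLEAR_ENABLED = false
CLEAR_BATCH_SIZE = 300
MAX_NUMBER = 5
CRON = * 0,0 2-8 * * ?
RUN_AT_START = false
```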
@@ -1735,12 +1761,30 @@ func NewContext() { BaiduWenXin.RUN_WORKERS = sec.Key("RUN_WORKERS").MustInt(1) BaiduWenXin.MODEL_SERVERS = sec.Key("MODEL_SERVERS").MustInt(1) + sec = Cfg.Section("llm_chat_api") + LLM_CHAT_API.CHATGLM2_HOST = sec.Key("CHATGLM2_HOST").MustString("") + LLM_CHAT_API.CHATGLM2_MAX_LENGTH = sec.Key("CHATGLM2_MAX_LENGTH").MustInt(8192) + LLM_CHAT_API.LLAMA2_HOST = sec.Key("LLAMA2_HOST").MustString("") + LLM_CHAT_API.LLAMA2_MAX_LENGTH = sec.Key("LLAMA2_MAX_LENGTH").MustInt(4096) + LLM_CHAT_API.COMMON_KB = sec.Key("COMMON_KNOWLEDGE_BASE").MustString("") + LLM_CHAT_API.MAX_FREE_TRIES = sec.Key("MAX_FREE_TRIES").MustInt64(200) + LLM_CHAT_API.LEGAL_CHECK = sec.Key("LEGAL_CHECK").MustBool(false) + LLM_CHAT_API.LEGAL_MAX_COUNT = sec.Key("LEGAL_MAX_COUNT").MustInt64(5) + LLM_CHAT_API.CHAT_EXPIRED_MINUTES = sec.Key("CHAT_EXPIRED_MINUTES").MustInt64(30) + + sec = Cfg.Section("baidu_ai") + BAIDU_AI.API_KEY = sec.Key("API_KEY").MustString("") + BAIDU_AI.SECRET_KEY = sec.Key("SECRET_KEY").MustString("") + BAIDU_AI.URL = sec.Key("URL").MustString("https://aip.baidubce.com/oauth/2.0/token") + BAIDU_AI.LEGAL_TEXT_URL = sec.Key("LEGAL_TEXT_URL").MustString("https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token=") + GetGrampusConfig() GetModelartsCDConfig() getModelConvertConfig() getModelSafetyConfig() getModelAppConfig() getClearStrategy() + getNotebookStrategy() NewScreenMapConfig() } @@ -1850,6 +1894,16 @@ func getClearStrategy() { ClearStrategy.RunAtStart = sec.Key("RUN_AT_START").MustBool(false) } +func getNotebookStrategy() { + + sec := Cfg.Section("notebook_strategy") + NotebookStrategy.ClearEnabled = sec.Key("CLEAR_ENABLED").MustBool(false) + NotebookStrategy.ClearBatchSize = sec.Key("CLEAR_BATCH_SIZE").MustInt(300) + NotebookStrategy.MaxNumberPerUser = sec.Key("MAX_NUMBER").MustInt(5) + NotebookStrategy.Cron = sec.Key("CRON").MustString("* 0,0 2-8 * * ?") + NotebookStrategy.RunAtStart = sec.Key("RUN_AT_START").MustBool(false) +} + func GetGrampusConfig() { sec := Cfg.Section("grampus") diff --git a/modules/structs/llm_chat.go b/modules/structs/llm_chat.go new file mode 100644 index 0000000000..624069bf84 --- /dev/null +++ b/modules/structs/llm_chat.go @@ -0,0 +1,82 @@ +package structs + +import "mime/multipart" + +type LLMChatHistory struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type LLMChatMessage struct { + ModelName string `json:"model_name" binding:"Required"` + Query string `json:"query" binding:"Required"` + History []LLMChatHistory `json:"history"` + Stream bool `json:"stream" ` +} + +type KBChatMessage struct { + ModelName string `json:"model_name" binding:"Required"` + Query string `json:"query" binding:"Required"` + KnowledgeBaseName string `json:"knowledge_base_name" binding:"Required"` + History []LLMChatHistory `json:"history"` + Stream bool `json:"stream"` + TopK int `json:"top_k"` + ScoreThreshold float64 `json:"score_threshold"` +} + +type CreateKnowledgeBaseParams struct { + KnowledgeBaseName string `json:"knowledge_base_name"` + VectorStoreType string `json:"vector_store_type"` + EmbedModel string `json:"embed_model"` +} + +type SearchDocParams struct { + Query string `json:"query"` + KnowledgeBaseName string `json:"knowledge_base_name"` + TopK int `json:"top_k"` + ScoreThreshold float64 `json:"score_threshold"` +} + +type DeleteDocParams struct { + KnowledgeBaseName string `json:"knowledge_base_name" binding:"Required"` + FileNames []string `json:"file_names" binding:"Required"` + DeleteContent 
bool `json:"delete_content"` + NotRefreshVsCache bool `json:"not_refresh_vs_cache"` +} + +type UpdateDocParams struct { + KnowledgeBaseName string `json:"knowledge_base_name"` + FileNames string `json:"file_names"` + NotRefreshVsCache bool `json:"not_refresh_vs_cache"` +} + +type RecreateVectorStoreParams struct { + KnowledgeBaseName string `json:"knowledge_base_name"` + AllowEmptyKb bool `json:"allow_empty_kb"` + VsType string `json:"vs_type"` + EmbedModel string `json:"embed_model"` +} + +type LLMChatCountsResults struct { + MaxTries int64 `json:"max_tries"` + Counts int64 `json:"counts"` + CanChat bool `json:"can_chat"` + //FirstVisit bool `json:"first_visit"` +} + +type KBChatAnswer struct { + Answer string `json:"answer"` +} + +type KBChatDocs struct { + Docs []string `json:"docs"` +} +type LegalTextParams struct { + Text string `json:"text"` +} + +type LLMChatUploadForm struct { + KnowledgeBaseName string `form:"knowledge_base_name"` + Files []*multipart.FileHeader `form:"files"` + Override bool `form:"override"` +} diff --git a/modules/templates/helper.go b/modules/templates/helper.go index 18b92c02dc..205ef2b9ba 100755 --- a/modules/templates/helper.go +++ b/modules/templates/helper.go @@ -107,6 +107,15 @@ func NewFuncMap() []template.FuncMap { "DebugAttachSize": func() int { return setting.DebugAttachSize * 1000 * 1000 * 1000 }, + "LlmCommonKB": func() string { + return setting.LLM_CHAT_API.COMMON_KB + }, + "LlmMaxCounts": func() string { + return strconv.FormatInt(setting.LLM_CHAT_API.MAX_FREE_TRIES, 10) + }, + "LlmExpireMinutes": func() string { + return strconv.FormatInt(setting.LLM_CHAT_API.CHAT_EXPIRED_MINUTES, 10) + }, "AvatarLink": models.AvatarLink, "Safe": Safe, "SafeJS": SafeJS, @@ -260,7 +269,7 @@ func NewFuncMap() []template.FuncMap { return dict, nil }, "Printf": fmt.Sprintf, - "ToLower": strings.ToLower, + "ToLower": strings.ToLower, "Escape": Escape, "Sec2Time": models.SecToTime, "ParseDeadline": func(deadline string) []string { @@ -423,7 +432,7 @@ func NewTextFuncMap() []texttmpl.FuncMap { return dict, nil }, "Printf": fmt.Sprintf, - "ToLower": strings.ToLower, + "ToLower": strings.ToLower, "Escape": Escape, "Sec2Time": models.SecToTime, "ParseDeadline": func(deadline string) []string { diff --git a/modules/urfs_client/objectstorage/mocks/objectstorage_mock.go b/modules/urfs_client/objectstorage/mocks/objectstorage_mock.go index baa34f437e..0b09403dd9 100644 --- a/modules/urfs_client/objectstorage/mocks/objectstorage_mock.go +++ b/modules/urfs_client/objectstorage/mocks/objectstorage_mock.go @@ -1,5 +1,9 @@ // Code generated by MockGen. DO NOT EDIT. // Source: objectstorage.go -// Package mocks is a generated GoMock package. package mocks + +import ( + gomock "github.com/golang/mock/gomock" + reflect "reflect" +) diff --git a/options/locale/locale_en-US.ini b/options/locale/locale_en-US.ini index 7809d934f2..094940c177 100755 --- a/options/locale/locale_en-US.ini +++ b/options/locale/locale_en-US.ini @@ -354,7 +354,7 @@ reset_password_mail_sent_prompt = A confirmation email has been sent to %s%s). has_unconfirmed_mail_resend = If you did not receive the activation email, or need to resend it, please click the "Resend your activation email" button below. 
@@ -1125,6 +1125,7 @@ images.name_placerholder = Please enter the image name
images.descr_placerholder = The description should not exceed 1000 characters
image.label_tooltips = Example Python 3.7, Tensorflow 2.0, cuda 10, pytorch 1.6
images.public_tooltips = After the image is set to public, it can be seen by other users.
+images.submit_tooltips = The code directory /code and dataset directory /dataset will not be submitted with the image; all other directories will be packaged into the image.
images.name_format_err=The format of image tag is wrong.
images.name_rule50 = Please enter letters, numbers, _ and - up to 50 characters and starts with a letter.
images.name_rule100 = Please enter letters, numbers, _ and - up to 100 characters and cannot end with a dash (-).
@@ -3466,6 +3467,7 @@ branch_not_exists = The branch does not exist. Please refresh and select again.
dataset_number_over_limit = The dataset count exceed the limit
result_cleared=The files of the task have been cleared, can not restart or retrain any more, please create a new task instead
model_not_exist=The model in the task does not exist or has been deleted
+too_many_notebook=A user can have up to 5 debug tasks, please try again after deleting some debug tasks.

[common_error]
system_error = System error.Please try again later
@@ -3482,3 +3484,11 @@ builidng_fail = Failed to build AI Model, please try again later
deletion_notice_repo = There is a deploying or running service related to this repository, please stop the service before deletion.
deletion_notice_trainjob = There is a deploying or running service related to this task, please stop the service before deletion.
stop_service_failed = Failed to stop deploy service
+
+
+[llm_chat]
+chat_expired = Chat session expired, please create a new chat.
+max_free_exceed = You have reached the maximum number of free chats.
+query_empty = Empty prompt is not allowed.
+query_too_long = Your prompt is too long.
+server_error = Server busy, please try again later.
diff --git a/options/locale/locale_zh-CN.ini b/options/locale/locale_zh-CN.ini
index 4fd5e33c0d..3b3c3f0184 100755
--- a/options/locale/locale_zh-CN.ini
+++ b/options/locale/locale_zh-CN.ini
@@ -357,7 +357,7 @@ reset_password_mail_sent_prompt=确认电子邮件已被发送到 %s。
active_your_account=激活您的帐户
account_activated=帐户已激活
prohibit_login=禁止登录
-prohibit_login_desc=您的帐户被禁止登录,请与网站管理员联系。
+prohibit_login_desc=您的帐户被禁止登录,请与网站管理员联系:secretariat@openi.org.cn
resent_limit_prompt=您请求发送激活邮件过于频繁,请等待 3 分钟后再试!
has_unconfirmed_mail=%s 您好,系统检测到您有一封发送至 %s 但未被确认的邮件。 has_unconfirmed_mail_resend=如果您未收到激活邮件,或需要重新发送,请单击下方的 "重新发送确认邮件 " 按钮。 @@ -1125,6 +1125,7 @@ images.name_placerholder = 请输入镜像Tag images.descr_placerholder = 描述字数不超过1000个字符 image.label_tooltips = 如Python 3.7, Tensorflow 2.0, cuda 10, pytorch 1.6 images.public_tooltips = 镜像设置为公开后,可被其他用户看到。 +images.submit_tooltips = 代码目录/code,数据集目录/dataset不会随镜像提交,其他目录都会打包到镜像中。 images.name_format_err=镜像Tag格式错误。 images.name_rule50 = 请输入字母、数字、_和-,最长50个字符,且以字母开头。 images.name_rule100 = 请输入字母、数字、_和-,最长100个字符,且不能以中划线(-)结尾。 @@ -3489,6 +3490,7 @@ branch_not_exists = 代码分支不存在,请刷新后重试 dataset_number_over_limit = 选择的数据集文件数量超出限制 result_cleared=源任务的文件已被清理,无法再次调试或复用训练结果,请新建任务。 model_not_exist=选择的预训练模型不存在或者已被删除 +too_many_notebook=每个用户最多只能创建5个调试任务,请删除历史任务再新建。 [common_error] @@ -3506,3 +3508,10 @@ builidng_fail = AI应用创建失败 deletion_notice_repo = 此项目有正在部署或正在体验的服务,请先停止服务,然后再删除。 deletion_notice_trainjob = 此任务有正在部署或正在体验的服务,请先停止服务,然后再删除。 stop_service_failed = 停止部署服务失败 + +[llm_chat] +chat_expired = 对话已过期,请重新创建对话 +max_free_exceed = 您的对话次数已达上限 +query_empty = 您发送的指令不能为空 +query_too_long = 您发送的指令长度超过限制 +server_error = 服务器繁忙,请稍后再试 diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index a34d4bfc31..c820a3b738 100755 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -88,6 +88,7 @@ import ( "code.gitea.io/gitea/modules/setting" api "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/routers/api/v1/admin" + "code.gitea.io/gitea/routers/api/v1/llm_chat" "code.gitea.io/gitea/routers/api/v1/misc" "code.gitea.io/gitea/routers/api/v1/notify" "code.gitea.io/gitea/routers/api/v1/org" @@ -651,8 +652,8 @@ func RegisterRoutes(m *macaron.Macaron) { m.Group("/:username/:reponame", func() { m.Group("/ai_task", func() { m.Post("/create", reqWeChatStandard(), reqRepoWriter(models.UnitTypeCloudBrain), bind(entity.CreateReq{}), ai_task.CreateAITask) - m.Post("/stop", reqWeChatStandard(), reqRepoWriter(models.UnitTypeCloudBrain), reqAITaskInRepo(), reqAdminOrOwnerAITaskCreator(), ai_task.StopAITask) - m.Post("/del", reqWeChatStandard(), reqRepoWriter(models.UnitTypeCloudBrain), reqAITaskInRepo(), reqAdminOrOwnerAITaskCreator(), ai_task.DelAITask) + m.Post("/stop", reqRepoWriter(models.UnitTypeCloudBrain), reqAITaskInRepo(), reqAdminOrOwnerAITaskCreator(), ai_task.StopAITask) + m.Post("/del", reqRepoWriter(models.UnitTypeCloudBrain), reqAITaskInRepo(), reqAdminOrOwnerAITaskCreator(), ai_task.DelAITask) m.Post("/restart", reqWeChatStandard(), reqRepoWriter(models.UnitTypeCloudBrain), reqAITaskInRepo(), reqAdminOrAITaskCreator(), ai_task.RestartAITask) m.Get("/debug_url", reqWeChatStandard(), reqRepoWriter(models.UnitTypeCloudBrain), reqAITaskInRepo(), ai_task.GetNotebookUrl) m.Get("/creation/required", reqWeChatStandard(), reqRepoWriter(models.UnitTypeCloudBrain), ai_task.GetCreationRequiredInfo) @@ -723,6 +724,32 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/spec", finetune.GetSpec) }, reqToken()) + // llm_chat + m.Group("/llm", func() { + m.Get("/stats", reqAdmin(), llm_chat.GetChatStats) + m.Group("/chat", func() { + m.Get("/counts", llm_chat.GetFreeTries) + m.Post("/visit", llm_chat.NewVisit) + m.Post("/agree", llm_chat.SaveAgreement) + m.Post("/legaltext", bind(api.LegalTextParams{}), llm_chat.LegalText) + m.Post("/chat", bind(api.LLMChatMessage{}), llm_chat.LLMChat) + m.Post("/knowledge_base_chat", bind(api.KBChatMessage{}), llm_chat.KBChat) + }) + m.Group("/knowledge_base", func() { + m.Get("/list", llm_chat.ListKnowledgeBase) + m.Post("/create", 
bind(api.CreateKnowledgeBaseParams{}), llm_chat.CreateKnowledgeBase) + m.Post("/delete", llm_chat.DeleteKnowledgeBase) + m.Get("/list_files", llm_chat.ListFiles) + m.Post("/search_docs", bind(api.SearchDocParams{}), llm_chat.SearchDoc) + m.Post("/delete_doc", bind(api.DeleteDocParams{}), llm_chat.DeleteDoc) + m.Post("/update_doc", llm_chat.UpdateDoc) + m.Post("/recreate_vector_store", llm_chat.RecreateVectorStore) + m.Get("/upload_doc_url", llm_chat.UploadDocUrl) + m.Post("/upload_doc", binding.MultipartForm(api.LLMChatUploadForm{}), llm_chat.UploadDoc) + m.Get("/download_doc_url", llm_chat.DownloadDoc) + }) + }, reqToken(), reqWeChatStandard()) + m.Group("/reward_point", func() { m.Get("/is_admin", user.IsRewardPointAdmin) m.Group("/list", func() { @@ -849,10 +876,10 @@ func RegisterRoutes(m *macaron.Macaron) { // Users m.Group("/users", func() { - m.Get("/search", user.Search) + m.Get("/search", reqToken(), user.Search) m.Group("/:username", func() { - m.Get("", user.GetInfo) + m.Get("", reqToken(), user.GetInfo) m.Get("/heatmap", mustEnableUserHeatmap, user.GetUserHeatmapData) m.Get("/repos", user.ListUserRepos) @@ -966,7 +993,7 @@ func RegisterRoutes(m *macaron.Macaron) { }) m.Group("/repos", func() { - m.Get("/search", repo.Search) + m.Get("/search", reqToken(), repo.Search) m.Get("/issues/search", repo.SearchIssues) diff --git a/routers/api/v1/llm_chat/llm_chat.go b/routers/api/v1/llm_chat/llm_chat.go new file mode 100644 index 0000000000..da3b7e4367 --- /dev/null +++ b/routers/api/v1/llm_chat/llm_chat.go @@ -0,0 +1,327 @@ +package llm_chat + +import ( + "code.gitea.io/gitea/models" + baiduAPI "code.gitea.io/gitea/modules/baiduai" + "code.gitea.io/gitea/modules/context" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + api "code.gitea.io/gitea/modules/structs" + llmService "code.gitea.io/gitea/services/llm_chat" + "net/http" + "strings" + "time" +) + +const ( + Chatglm = "chatglm2-6b" + Llama = "llama2-7b-chat-hf" + FlagExpired = "" +) + +func chatPreCheck(ctx *context.APIContext, isPrompt bool) *models.LlmChatVisit { + modelName := ctx.Query("model_name") + if modelName == "" { + ctx.Error(http.StatusBadRequest, "model_name can't be empty", "model_name can't be empty") + return nil + } + userID := ctx.User.ID + currentTime := time.Now() + hasChat, _ := models.QueryRunningChat(userID, modelName, currentTime.Unix()) + + if hasChat == nil { + if isPrompt { + ctx.Resp.Header().Set("Content-Type", "application/octet-stream; charset=utf-8") + ctx.Resp.Write([]byte(FlagExpired)) + ctx.Resp.Flush() + return nil + } else { + ctx.JSON(http.StatusOK, map[string]string{ + "code": "-1", + "msg": ctx.Tr("llm_chat.chat_expired"), + }) + log.Error("userID %d : no running chat session for model %s.", ctx.User.ID, modelName) + return nil + } + } + + counts := models.QueryChatCount(ctx.User.ID, modelName) + maxTires := setting.LLM_CHAT_API.MAX_FREE_TRIES + if counts >= maxTires { + ctx.JSON(http.StatusOK, map[string]string{ + "code": "-1", + "msg": ctx.Tr("llm_chat.max_free_exceed"), + }) + log.Error("userID %d : max free times exceed %d.", ctx.User.ID, maxTires) + return nil + } + + return hasChat +} + +func promptPreCheck(ctx *context.APIContext, prompt string, modelName string) bool { + queryLen := len(strings.TrimSpace(prompt)) + log.Info("query length check: %d tokens\n", queryLen) + if queryLen == 0 { + ctx.JSON(http.StatusOK, map[string]string{ + "code": "-1", + "msg": ctx.Tr("llm_chat.query_empty"), + }) + log.Error("userID %d : query can't be empty.", ctx.User.ID) 
+ return false + } + + lenFlag := false + if modelName == Chatglm { + lenFlag = queryLen > setting.LLM_CHAT_API.CHATGLM2_MAX_LENGTH + } else if modelName == Llama { + lenFlag = queryLen > setting.LLM_CHAT_API.LLAMA2_MAX_LENGTH + } + if lenFlag { + ctx.JSON(http.StatusOK, map[string]string{ + "code": "-1", + "msg": ctx.Tr("llm_chat.query_too_long"), + }) + log.Error("userID %d : query length too long.", ctx.User.ID) + return false + } + + return true +} + +func LLMChat(ctx *context.APIContext, data api.LLMChatMessage) { + log.Info("LLM chat by api.") + hasChat := chatPreCheck(ctx, true) + promptFlag := promptPreCheck(ctx, data.Query, data.ModelName) + if !promptFlag || hasChat == nil { + log.Error("userID %d : chat prompt pre-check failed.", ctx.User.ID) + return + } + if data.Stream { + llmService.StreamLLMChatService(ctx.Context, data, hasChat) + } else { + ctx.JSON(http.StatusInternalServerError, "currently not supported") + //llmService.LLMChatService(ctx.Context, data, hasChat) + } +} + +func KBChat(ctx *context.APIContext, data api.KBChatMessage) { + log.Info("LLM KnowledgeBase chat by api.") + hasChat := chatPreCheck(ctx, true) + promptFlag := promptPreCheck(ctx, data.Query, data.ModelName) + if !promptFlag || hasChat == nil { + log.Error("userID %d : chat prompt pre-check failed.", ctx.User.ID) + return + } + if data.Stream { + llmService.StreamKBChatService(ctx.Context, data, hasChat) + } else { + ctx.JSON(http.StatusInternalServerError, "currently not supported") + //llmService.KBChatService(ctx.Context, data, hasChat) + } +} + +func ListKnowledgeBase(ctx *context.APIContext) { + log.Info("LLM list KnowledgeBase by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.ListKnowledgeBaseService(ctx.Context) +} + +func CreateKnowledgeBase(ctx *context.APIContext, data api.CreateKnowledgeBaseParams) { + log.Info("LLM create KnowledgeBase by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.CreateKnowledgeBaseService(ctx.Context, data) +} + +func DeleteKnowledgeBase(ctx *context.APIContext) { + log.Info("LLM delete KnowledgeBase by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.DeleteKnowledgeBaseService(ctx.Context) +} + +func ListFiles(ctx *context.APIContext) { + log.Info("LLM list files by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.ListFilesService(ctx.Context) +} + +func SearchDoc(ctx *context.APIContext, data api.SearchDocParams) { + log.Info("LLM search doc by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.SearchDocService(ctx.Context, data) +} + +func DeleteDoc(ctx *context.APIContext, data api.DeleteDocParams) { + log.Info("LLM delete doc by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.DeleteDocService(ctx.Context, data) + +} + +func UpdateDoc(ctx *context.APIContext) { + log.Info("LLM update doc by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.UpdateDocService(ctx.Context) +} + +func RecreateVectorStore(ctx *context.APIContext) { + log.Info("LLM recreate vector store by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.RecreateVectorStoreService(ctx.Context) +} + +func UploadDocUrl(ctx *context.APIContext) { + log.Info("LLM upload doc by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + 
llmService.UploadDocUrlService(ctx.Context) +} + +func UploadDoc(ctx *context.APIContext, form api.LLMChatUploadForm) { + log.Info("LLM upload doc by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.UploadDocService(ctx.Context, form) +} + +func DownloadDoc(ctx *context.APIContext) { + log.Info("LLM download doc by api.") + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + llmService.DownloadDocService(ctx.Context) +} + +func GetFreeTries(ctx *context.APIContext) { + log.Info("LLM get free tries by api.") + llmService.GetFreeTriesService(ctx.Context) +} + +func LegalText(ctx *context.APIContext, data api.LegalTextParams) { + log.Info("LLM get chat counts by api.") + res, err := baiduAPI.CheckLegalText(data.Text) + if err != nil { + log.Error("CheckLegalText failed: %s", err) + ctx.Error(http.StatusInternalServerError, "CheckLegalText failed", err.Error()) + return + } + ctx.JSON(http.StatusOK, res) +} + +func NewVisit(ctx *context.APIContext) { + log.Info("LLM new visit by api.") + currentTime := time.Now() + modelName := ctx.Query("model_name") + + hasChat, _ := models.QueryRunningChat(ctx.User.ID, modelName, currentTime.Unix()) + if hasChat != nil { + ctx.JSON(http.StatusOK, map[string]string{ + "chat_id": hasChat.ChatId, + "model_name": hasChat.ModelName, + "expired_at": hasChat.ExpiredTime, + }) + log.Info("hasChat %s, expired at %s\n", hasChat.ChatId, hasChat.ExpiredTime) + return + } + + chatID := ctx.User.Name + currentTime.Format("20060102150405") + duration := time.Duration(setting.LLM_CHAT_API.CHAT_EXPIRED_MINUTES) + endTime := currentTime.Add(time.Minute * duration) + entTimeStr := endTime.Format("2006-01-02 15:04:05") + llmChatVisit := &models.LlmChatVisit{ + UserId: ctx.User.ID, + ChatId: chatID, + ModelName: modelName, + ExpiredUnix: endTime.Unix(), + ExpiredTime: entTimeStr, + Agreement: 0, + } + models.SaveVisit(llmChatVisit) + log.Info("new chat %s, expired at %s\n", chatID, entTimeStr) + ctx.JSON(http.StatusOK, map[string]string{ + "chat_id": llmChatVisit.ChatId, + "model_name": llmChatVisit.ModelName, + "expired_at": llmChatVisit.ExpiredTime, + }) +} + +func SaveAgreement(ctx *context.APIContext) { + hasChat := chatPreCheck(ctx, false) + if hasChat == nil { + return + } + if hasChat.Agreement == 1 { + ctx.JSON(http.StatusOK, map[string]string{ + "code": "-1", + "msg": "already saved Agreement", + }) + return + } + + hasChat.Agreement = 1 + models.UpdateChat(hasChat) + ctx.JSON(http.StatusOK, map[string]string{ + "code": "1", + "msg": "successfully saved Agreement status", + }) +} + +func GetChatStats(ctx *context.APIContext) { + log.Info("LLM chat stats by api.") + + resChat, err := models.QueryChatStatistics() + if err != nil { + log.Error("QueryChatStatistics failed: %s", err) + ctx.JSON(http.StatusInternalServerError, "QueryChatStatistics failed") + return + } + + resVisit, err := models.QueryChatVisitStatistics() + if err != nil { + log.Error("QueryChatVisitStatistics failed: %s", err) + ctx.JSON(http.StatusInternalServerError, "QueryChatVisitStatistics failed") + return + } + + res := make(map[string]interface{}) + for _, chat := range resChat { + for _, visit := range resVisit { + if chat["model_name"] == visit["model_name"] { + chat["visit"] = visit["visit"] + chat["visit_user"] = visit["visit_user"] + } + } + res[chat["model_name"].(string)] = chat + } + + ctx.JSON(http.StatusOK, res) +} diff --git a/routers/api/v1/repo/attachments.go b/routers/api/v1/repo/attachments.go index 
795a62b14e..413b367685 100644 --- a/routers/api/v1/repo/attachments.go +++ b/routers/api/v1/repo/attachments.go @@ -2,6 +2,7 @@ package repo import ( "net/http" + "strings" "sync" "code.gitea.io/gitea/modules/log" @@ -12,6 +13,7 @@ import ( ) var mutex *sync.Mutex = new(sync.Mutex) +var modelMutex *sync.Mutex = new(sync.Mutex) func GetSuccessChunks(ctx *context.APIContext) { if errStr := checkDatasetPermission(ctx); errStr != "" { @@ -146,7 +148,27 @@ func NewModelMultipart(ctx *context.APIContext) { return } - routeRepo.NewModelMultipart(ctx.Context) + if err := routeRepo.CheckFlowForModelSDK(); err != nil { + ctx.JSON(200, map[string]string{ + "result_code": "-1", + "msg": err.Error(), + }) + return + } + modelMutex.Lock() + defer modelMutex.Unlock() + fileName := ctx.Query("file_name") + re, err := routeRepo.NewModelMultipartForApi(ctx.Context, true) + if err != nil { + ctx.JSON(200, map[string]string{ + "result_code": "-1", + "msg": err.Error(), + }) + } else { + routeRepo.AddModelFileNameToCache(modeluuid, fileName, ctx.User.ID) + re["result_code"] = "0" + ctx.JSON(200, re) + } } func checkModelPermission(ctx *context.APIContext, model *models.AiModelManage) string { @@ -178,5 +200,14 @@ func GetModelMultipartUploadUrl(ctx *context.APIContext) { func CompleteModelMultipart(ctx *context.APIContext) { log.Info("CompleteModelMultipart by api.") + modeluuid := ctx.Query("modeluuid") + //fileName := ctx.Query("file_name") + uuid := ctx.Query("uuid") + fileChunk, err := models.GetModelFileChunkByUUID(uuid) + if err == nil { + log.Info("fileChunk.ObjectName=" + fileChunk.ObjectName) + objectNames := strings.Split(fileChunk.ObjectName, "/") + routeRepo.RemoveModelFileFromCache(modeluuid, objectNames[len(objectNames)-1], ctx.User.ID) + } routeRepo.CompleteModelMultipart(ctx.Context) } diff --git a/routers/modelapp/wenxin.go b/routers/modelapp/wenxin.go index 7f84bfc9c0..cea1f78ba8 100644 --- a/routers/modelapp/wenxin.go +++ b/routers/modelapp/wenxin.go @@ -1,17 +1,21 @@ package modelapp import ( - "fmt" - "code.gitea.io/gitea/models" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/modelappservice" + "code.gitea.io/gitea/modules/setting" + "fmt" uuid "github.com/satori/go.uuid" + "net/http" + "time" ) var modelMainTpl base.TplName = "model/index" var modelWenXinTpl base.TplName = "model/wenxin/index" +var modelLLMChatTpl base.TplName = "model/llmchat/index" +var modelLLMChatCreateTpl base.TplName = "model/llmchat/create/index" const WAIT_TIME int = 7 @@ -23,6 +27,51 @@ func WenXinPage(ctx *context.Context) { ctx.HTML(200, modelWenXinTpl) } +func LLMChatCreate(ctx *context.Context) { + currentTime := time.Now() + modelName := ctx.Query("model_name") + hasChat, _ := models.QueryRunningChat(ctx.User.ID, modelName, currentTime.Unix()) + if hasChat != nil { + ctx.Data["chatID"] = hasChat.ChatId + } else { + ctx.Data["chatID"] = "-1" + } + ctx.HTML(200, modelLLMChatCreateTpl) + return +} + +func LLMChatPage(ctx *context.Context) { + chatID := ctx.Params("chatID") + modelName := ctx.Query("model_name") + + currentTime := time.Now().Unix() + hasChat, _ := models.QueryByChatId(chatID) + if hasChat != nil { + if ctx.User.ID != hasChat.UserId || hasChat.ModelName != modelName { + ctx.Data["Title"] = "Page Not Found" + ctx.HTML(http.StatusNotFound, base.TplName("status/404")) + return + } + if currentTime > hasChat.ExpiredUnix { + ctx.Redirect(setting.AppSubURL + "/extension/llm_chat/create?model_name=" + modelName) //HTML(200, modelLLMChatCreateTpl) + 
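+			// An expired chat is sent back to the create page, where LLMChatCreate
+			// (or NewVisit on the API side) issues a fresh chat ID and expiry window.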
return + } + ctx.Data["expiredUnix"] = hasChat.ExpiredUnix + ctx.Data["landingUnix"] = currentTime + } + + firstVisit := models.QueryFirstVisit(ctx.User.ID) + ctx.Data["firstVisit"] = firstVisit == int64(0) + + counts := models.QueryChatCount(ctx.User.ID, modelName) + maxTries := setting.LLM_CHAT_API.MAX_FREE_TRIES + ctx.Data["can_chat"] = counts < maxTries + ctx.Data["counts"] = counts + ctx.Data["max_tries"] = maxTries + + ctx.HTML(200, modelLLMChatTpl) +} + func WenXinPaintNew(ctx *context.Context) { textDesc := ctx.Query("textDesc") uuid := uuid.NewV4() diff --git a/routers/repo/attachment_model.go b/routers/repo/attachment_model.go index fcc4a8fdf2..33c2e93e46 100644 --- a/routers/repo/attachment_model.go +++ b/routers/repo/attachment_model.go @@ -1,6 +1,7 @@ package repo import ( + "errors" "fmt" "path" "strconv" @@ -142,37 +143,48 @@ func getObjectName(filename string, modeluuid string) string { } func NewModelMultipart(ctx *context.Context) { - if !setting.Attachment.Enabled { - ctx.Error(404, "attachment is not enabled") + re, err := NewModelMultipartForApi(ctx, false) + if err != nil { + ctx.ServerError("NewMultipart failed", err) return } + ctx.JSON(200, re) +} + +func NewModelMultipartForApi(ctx *context.Context, isFlowControl bool) (map[string]string, error) { + if !setting.Attachment.Enabled { + return nil, errors.New("attachment is not enabled") + } fileName := ctx.Query("file_name") modeluuid := ctx.Query("modeluuid") err := upload.VerifyFileType(ctx.Query("fileType"), strings.Split(setting.Attachment.AllowedTypes, ",")) if err != nil { - ctx.Error(400, err.Error()) - return + return nil, err + } + if isFlowControl { + err = CheckFlowForModel(ctx) + if err != nil { + log.Info("check error," + err.Error()) + return nil, err + } } - typeCloudBrain := ctx.QueryInt("type") err = checkTypeCloudBrain(typeCloudBrain) if err != nil { - ctx.ServerError("checkTypeCloudBrain failed", err) - return + return nil, err } if setting.Attachment.StoreType == storage.MinioStorageType { totalChunkCounts := ctx.QueryInt("totalChunkCounts") if totalChunkCounts > minio_ext.MaxPartsCount { - ctx.Error(400, fmt.Sprintf("chunk counts(%d) is too much", totalChunkCounts)) - return + return nil, errors.New(fmt.Sprintf("chunk counts(%d) is too much", totalChunkCounts)) + } fileSize := ctx.QueryInt64("size") if fileSize > minio_ext.MaxMultipartPutObjectSize { - ctx.Error(400, fmt.Sprintf("file size(%d) is too big", fileSize)) - return + return nil, errors.New(fmt.Sprintf("file size(%d) is too big", fileSize)) } uuid := gouuid.NewV4().String() @@ -182,16 +194,14 @@ func NewModelMultipart(ctx *context.Context) { objectName = strings.TrimPrefix(path.Join(Model_prefix, path.Join(modeluuid[0:1], modeluuid[1:2], modeluuid, fileName)), "/") uploadID, err = storage.NewMultiPartUpload(objectName) if err != nil { - ctx.ServerError("NewMultipart", err) - return + return nil, err } } else { objectName = strings.TrimPrefix(path.Join(Model_prefix, path.Join(modeluuid[0:1], modeluuid[1:2], modeluuid, fileName)), "/") uploadID, err = storage.NewObsMultiPartUpload(objectName) if err != nil { - ctx.ServerError("NewObsMultiPartUpload", err) - return + return nil, err } } @@ -208,17 +218,15 @@ func NewModelMultipart(ctx *context.Context) { }) if err != nil { - ctx.Error(500, fmt.Sprintf("InsertFileChunk: %v", err)) - return - } + return nil, err - ctx.JSON(200, map[string]string{ + } + return map[string]string{ "uuid": uuid, "uploadID": uploadID, - }) + }, nil } else { - ctx.Error(404, "storage type is not enabled") - 
return + return nil, errors.New("storage type is not enabled") } } diff --git a/routers/repo/flow_control.go b/routers/repo/flow_control.go index d73273e88a..f56d051cb0 100644 --- a/routers/repo/flow_control.go +++ b/routers/repo/flow_control.go @@ -16,10 +16,12 @@ import ( ) const ( - REDIS_FLOW_ATTACHMENT_KEY = "flow_attachment_key" + REDIS_FLOW_ATTACHMENT_KEY = "flow_attachment_key" + REDIS_FLOW_MODEL_ATTACHMENT_KEY = "flow_model_attachment_key" ) var mutex *sync.RWMutex = new(sync.RWMutex) +var modelMutex *sync.Mutex = new(sync.Mutex) func CheckFlowForDataset(ctx *context.Context) error { if ctx.User == nil { @@ -86,6 +88,31 @@ func AddFileNameToCache(datasetId int64, fileName string, userId int64) { setSDKUploadFileCache(REDIS_FLOW_ATTACHMENT_KEY, cacheMap) } +func AddModelFileNameToCache(modelId string, fileName string, userId int64) { + modelMutex.Lock() + defer modelMutex.Unlock() + cacheMap := getSDKUploadFileMap(REDIS_FLOW_MODEL_ATTACHMENT_KEY) + expireTimeKeys := make([]string, 0) + currentTime := time.Now().Unix() + for tmpKey, tmpValue := range cacheMap { + time, err := strconv.ParseInt(tmpValue, 10, 64) + if err == nil { + if currentTime-time > 24*3600 { + expireTimeKeys = append(expireTimeKeys, tmpKey) + continue + } + } + } + for _, delKey := range expireTimeKeys { + delete(cacheMap, delKey) + } + key := modelId + "_" + fileName + "_" + fmt.Sprint(userId) + value := fmt.Sprint(time.Now().Unix()) + cacheMap[key] = value + log.Info("set key=" + key + " value=" + value + " to cache.") + setSDKUploadFileCache(REDIS_FLOW_MODEL_ATTACHMENT_KEY, cacheMap) +} + func RemoveFileFromCache(datasetId int64, fileName string, userId int64) { mutex.Lock() defer mutex.Unlock() @@ -96,6 +123,16 @@ func RemoveFileFromCache(datasetId int64, fileName string, userId int64) { setSDKUploadFileCache(REDIS_FLOW_ATTACHMENT_KEY, cacheMap) } +func RemoveModelFileFromCache(modelId string, fileName string, userId int64) { + modelMutex.Lock() + defer modelMutex.Unlock() + key := modelId + "_" + fileName + "_" + fmt.Sprint(userId) + cacheMap := getSDKUploadFileMap(REDIS_FLOW_MODEL_ATTACHMENT_KEY) + delete(cacheMap, key) + log.Info("remove key=" + key + " from cache.") + setSDKUploadFileCache(REDIS_FLOW_MODEL_ATTACHMENT_KEY, cacheMap) +} + func getSDKUploadFileMap(msgKey string) map[string]string { valueStr, err := redis_client.Get(msgKey) msgMap := make(map[string]string, 0) @@ -144,3 +181,64 @@ func CheckFlowForDatasetSDK() error { } return nil } + +func CheckFlowForModelSDK() error { + cacheMap := getSDKUploadFileMap(REDIS_FLOW_MODEL_ATTACHMENT_KEY) + currentTime := time.Now().Unix() + count := 0 + for _, tmpValue := range cacheMap { + time, err := strconv.ParseInt(tmpValue, 10, 64) + if err == nil { + if currentTime-time > 24*3600 { + continue + } + } + count += 1 + } + log.Info("total find " + fmt.Sprint(count) + " uploading files.") + if count >= setting.FLOW_CONTROL.ALL_ATTACHEMENT_NUM_SDK { + log.Info("The number of model files uploaded using the SDK simultaneously cannot exceed " + fmt.Sprint(setting.FLOW_CONTROL.ALL_ATTACHEMENT_NUM_SDK)) + return errors.New("The number of model files uploaded using the SDK simultaneously cannot exceed " + fmt.Sprint(setting.FLOW_CONTROL.ALL_ATTACHEMENT_NUM_SDK)) + } + return nil +} + +func CheckFlowForModel(ctx *context.Context) error { + if ctx.User == nil { + return errors.New("User not login.") + } + log.Info("start to check flow for upload model file.") + fileName := ctx.Query("file_name") + currentTimeNow := time.Now() + currentLongTime := 
currentTimeNow.Unix() + last24Hour := currentTimeNow.AddDate(0, 0, -1).Unix() + filechunks, err := models.GetModelFileChunksByUserId(ctx.User.ID, last24Hour, true) + if err == nil { + if len(filechunks) >= setting.FLOW_CONTROL.ATTACHEMENT_NUM_A_USER_LAST24HOUR { + log.Info("A single user cannot upload more than " + fmt.Sprint(setting.FLOW_CONTROL.ATTACHEMENT_NUM_A_USER_LAST24HOUR) + " files within the last 24 hours. so " + fileName + " is rejected. user id=" + fmt.Sprint(ctx.User.ID)) + return errors.New("A single user cannot upload more than " + fmt.Sprint(setting.FLOW_CONTROL.ATTACHEMENT_NUM_A_USER_LAST24HOUR) + " files within the last 24 hours.") + } + var totalSize int64 + totalSize += ctx.QueryInt64("size") + concurrentUpload := 0 + for _, file := range filechunks { + totalSize += file.Size + if (currentLongTime - int64(file.CreatedUnix)) < 10*60 { + log.Info("the file " + file.Md5 + " in 10min upload." + file.CreatedUnix.Format("2006-01-02 15:04:05")) + concurrentUpload += 1 + } else { + log.Info("the file " + file.Md5 + " not in 10min upload." + file.CreatedUnix.Format("2006-01-02 15:04:05")) + } + } + log.Info("The concurrentUpload is " + fmt.Sprint(concurrentUpload) + " to checked " + fileName + ". user id=" + fmt.Sprint(ctx.User.ID)) + if concurrentUpload >= setting.FLOW_CONTROL.ATTACHEMENT_NUM_A_USER_LAST10M { + log.Info("A single user cannot upload more than " + fmt.Sprint(setting.FLOW_CONTROL.ATTACHEMENT_NUM_A_USER_LAST10M) + " files within the past 10 minutes. so " + fileName + " is rejected. user id=" + fmt.Sprint(ctx.User.ID)) + return errors.New("A single user cannot upload more than " + fmt.Sprint(setting.FLOW_CONTROL.ATTACHEMENT_NUM_A_USER_LAST10M) + " files within the past 10 minutes.") + } + if totalSize >= setting.FLOW_CONTROL.ATTACHEMENT_SIZE_A_USER*1024*1024*1024 { + log.Info("The total file size uploaded by a single user within the past 24 hours cannot exceed " + fmt.Sprint(setting.FLOW_CONTROL.ATTACHEMENT_SIZE_A_USER) + "G. so " + fileName + " is rejected. 
user id=" + fmt.Sprint(ctx.User.ID)) + return errors.New("The total file size uploaded by a single user within the past 24 hours cannot exceed " + fmt.Sprint(setting.FLOW_CONTROL.ATTACHEMENT_SIZE_A_USER) + "G.") + } + } + return nil +} diff --git a/routers/repo/grampus.go b/routers/repo/grampus.go index d5bb2c116a..ab9961ffba 100755 --- a/routers/repo/grampus.go +++ b/routers/repo/grampus.go @@ -2017,7 +2017,7 @@ func GrampusCommitImage(ctx *context.Context, form auth.CommitImageGrampusForm) if err != nil { log.Error("CommitImage(%s) failed:%v", ctx.Cloudbrain.JobName, err.Error(), ctx.Data["msgID"]) - if models.IsErrImageTagExist(err) { + if models.IsErrImageTagExist(err) || strings.Contains(err.Error(), "Image already exists") || strings.Contains(err.Error(), "image exists") { ctx.JSON(200, models.BaseErrorMessage(ctx.Tr("repo.image_exist"))) } else if models.IsErrorImageCommitting(err) { diff --git a/routers/response/response_list.go b/routers/response/response_list.go index 03b558bb93..e37673b269 100644 --- a/routers/response/response_list.go +++ b/routers/response/response_list.go @@ -36,3 +36,4 @@ var LOAD_CODE_FAILED = &BizError{Code: 2019, DefaultMsg: "Fail to load code, ple var BRANCH_NOT_EXISTS = &BizError{Code: 2020, DefaultMsg: "The branch does not exist", TrCode: "ai_task.branch_not_exists"} var MODEL_NUM_OVER_LIMIT = &BizError{Code: 2021, DefaultMsg: "The number of models exceeds the limit of 30", TrCode: "repo.debug.manage.model_num_over_limit"} var DATASET_NUMBER_OVER_LIMIT = &BizError{Code: 2022, DefaultMsg: "The dataset count exceed the limit", TrCode: "ai_task.dataset_number_over_limit"} +var NOTEBOOK_EXCEED_MAX_NUM = &BizError{Code: 2023, DefaultMsg: "You can have up to 5 Debug Tasks, please try again after delete some tasks. ", TrCode: "ai_task.too_many_notebook"} diff --git a/routers/routes/routes.go b/routers/routes/routes.go index a038f19b13..1c21f7c8fa 100755 --- a/routers/routes/routes.go +++ b/routers/routes/routes.go @@ -364,7 +364,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Post("/user/login/kanban", user.SignInPostAPI) m.Get("/home/term", routers.HomeTerm) m.Get("/home/annual_privacy", routers.HomeAnnual) - m.Get("/home/wenxin_privacy", routers.HomeWenxin) + m.Get("/home/model_privacy", routers.HomeWenxin) m.Get("/home/notice", routers.HomeNoticeTmpl) m.Get("/home/privacy", routers.HomePrivacy) @@ -381,6 +381,10 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/tuomin/upload", modelapp.ProcessImageUI) m.Post("/tuomin/upload", reqSignIn, modelapp.ProcessImage) m.Get("/wenxin", modelapp.WenXinPage) + m.Group("/llm_chat", func() { + m.Get("/create", reqSignIn, reqWechatBind, modelapp.LLMChatCreate) + m.Get("/:chatID", reqSignIn, reqWechatBind, modelapp.LLMChatPage) + }) m.Get("/wenxin/paint_new", reqSignIn, modelapp.WenXinPaintNew) m.Get("/wenxin/query_paint_result", reqSignIn, modelapp.QueryWenXinPaintResult) m.Get("/wenxin/query_paint_image", reqSignIn, modelapp.QueryWenXinPaintById) @@ -391,9 +395,7 @@ func RegisterRoutes(m *macaron.Macaron) { m.Get("/create", reqSignIn, reqWechatBind, modelapp.PanguFinetuneCreateUI) m.Get("/inference", reqSignIn, modelapp.PanguInferenceUI) }) - }) - }) m.Group("/explore", func() { diff --git a/services/ai_task_service/cluster/c2net.go b/services/ai_task_service/cluster/c2net.go index 522ec187fb..581e4535f0 100644 --- a/services/ai_task_service/cluster/c2net.go +++ b/services/ai_task_service/cluster/c2net.go @@ -271,7 +271,8 @@ func (c C2NetClusterAdapter) DeleteNoteBook(opts entity.JobIdAndVersionId) error _, err := 
grampus.DeleteJob(opts.JobID, string(models.JobTypeDebug)) if err != nil { log.Error("DeleteNoteBook(%s) failed:%v", opts, err) - return err + log.Info("error=" + err.Error()) + return nil } return nil } @@ -598,7 +599,8 @@ func (c C2NetClusterAdapter) DeleteTrainJob(opts entity.JobIdAndVersionId) error _, err := grampus.DeleteJob(opts.JobID) if err != nil { log.Error("Delete train job(%s) failed:%v", opts, err) - return err + log.Info("error=" + err.Error()) + return nil } return nil } diff --git a/services/ai_task_service/cluster/cloudbrain_two.go b/services/ai_task_service/cluster/cloudbrain_two.go index 677e59f541..6991504d12 100644 --- a/services/ai_task_service/cluster/cloudbrain_two.go +++ b/services/ai_task_service/cluster/cloudbrain_two.go @@ -186,7 +186,8 @@ func (c CloudbrainTwoClusterAdapter) DeleteNoteBook(opts entity.JobIdAndVersionI } if err != nil { log.Error("DeleteNoteBook err.jobID=%s err=%v", opts, err) - return err + log.Info("error=" + err.Error()) + return nil } return nil } @@ -552,7 +553,12 @@ func getCloudbrainTwoModelUrl(datasets []entity.ContainerData) string { func (c CloudbrainTwoClusterAdapter) DeleteTrainJob(opts entity.JobIdAndVersionId) error { _, err := modelarts.DelTrainJobVersion(opts.JobID, strconv.FormatInt(opts.VersionID, 10)) - return err + if err != nil { + log.Error("DeleteTrainJob err.jobID=%s err=%v", opts, err) + log.Info("error=" + err.Error()) + return nil + } + return nil } func (c CloudbrainTwoClusterAdapter) StopTrainJob(opts entity.JobIdAndVersionId) error { diff --git a/services/ai_task_service/task/cloudbrain_one_notebook_task.go b/services/ai_task_service/task/cloudbrain_one_notebook_task.go index 2f5f3658aa..701da0c47f 100644 --- a/services/ai_task_service/task/cloudbrain_one_notebook_task.go +++ b/services/ai_task_service/task/cloudbrain_one_notebook_task.go @@ -83,6 +83,7 @@ func (t CloudbrainOneNotebookTaskTemplate) Create(ctx *context.CreationContext) c := &CreateOperator{} err := c.Next(t.CheckParamFormat). Next(t.CheckMultiRequest). + Next(t.CheckNotebookCount). Next(t.CheckDisplayJobName). Next(t.LoadSpec). Next(t.CheckPointBalance). diff --git a/services/ai_task_service/task/cloudbrain_two_notebook_task.go b/services/ai_task_service/task/cloudbrain_two_notebook_task.go index a6eac5b04e..10c5ae1ea3 100644 --- a/services/ai_task_service/task/cloudbrain_two_notebook_task.go +++ b/services/ai_task_service/task/cloudbrain_two_notebook_task.go @@ -68,6 +68,7 @@ func (t CloudbrainTwoNotebookTaskTemplate) Create(ctx *context.CreationContext) err := c.Next(t.CheckParamFormat). Next(t.CheckMultiRequest). Next(t.CheckDisplayJobName). + Next(t.CheckNotebookCount). Next(t.LoadSpec). Next(t.CheckPointBalance). Next(t.CheckDatasets). diff --git a/services/ai_task_service/task/grampus_notebook_task.go b/services/ai_task_service/task/grampus_notebook_task.go index 2663be8d2c..3648735d92 100644 --- a/services/ai_task_service/task/grampus_notebook_task.go +++ b/services/ai_task_service/task/grampus_notebook_task.go @@ -149,6 +149,7 @@ func (t GrampusNoteBookTaskTemplate) Create(ctx *context.CreationContext) (*enti err := c.Next(t.CheckParamFormat). Next(t.CheckMultiRequest). Next(t.CheckDisplayJobName). + Next(t.CheckNotebookCount). Next(t.LoadSpec). Next(t.CheckPointBalance). Next(t.CheckDatasets). 
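All three notebook task templates above splice t.CheckNotebookCount into their Next(...) creation chains, so the per-user notebook quota is enforced before any spec loading or cluster call. The sketch below illustrates the short-circuiting operator pattern those chains appear to follow; CreationContext, BizError, Step, and Operator here are simplified stand-ins, not the repo's actual types.

package main

import "fmt"

// Simplified stand-ins for the repo's context.CreationContext and response.BizError.
type CreationContext struct{ NotebookCount, MaxPerUser int }
type BizError struct{ Msg string }

// A creation step either passes (nil) or aborts the chain with a BizError.
type Step func(*CreationContext) *BizError

// Operator remembers the first failure, so later Next calls become no-ops.
type Operator struct {
	ctx *CreationContext
	err *BizError
}

func (o *Operator) Next(s Step) *Operator {
	if o.err == nil {
		o.err = s(o.ctx)
	}
	return o
}

// checkNotebookCount mirrors the quota gate added by this patch.
func checkNotebookCount(ctx *CreationContext) *BizError {
	if ctx.NotebookCount >= ctx.MaxPerUser {
		return &BizError{Msg: "too many notebooks"}
	}
	return nil
}

func main() {
	op := &Operator{ctx: &CreationContext{NotebookCount: 5, MaxPerUser: 5}}
	if err := op.Next(checkNotebookCount).err; err != nil {
		fmt.Println("creation rejected:", err.Msg)
	}
}

With the count already at the limit, the chain stops at the quota step; in the real handler this is where NOTEBOOK_EXCEED_MAX_NUM (code 2023) is surfaced to the caller.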
diff --git a/services/ai_task_service/task/opt_handler.go b/services/ai_task_service/task/opt_handler.go
index d8df7937f7..d2d8338bf7 100644
--- a/services/ai_task_service/task/opt_handler.go
+++ b/services/ai_task_service/task/opt_handler.go
@@ -42,6 +42,7 @@ type CreationHandler interface {
 	CallRestartAPI(ctx *context.CreationContext) *response.BizError
 	NotifyCreation(ctx *context.CreationContext) *response.BizError
 	HandleErr4Async(ctx *context.CreationContext) *response.BizError
+	CheckNotebookCount(ctx *context.CreationContext) *response.BizError
 }
 
 //DefaultCreationHandler is the default implementation of CreationHandler; shared logic can be implemented in this struct
@@ -106,6 +107,7 @@ func (DefaultCreationHandler) BuildRequest4Restart(ctx *context.CreationContext)
 		IsFileNoteBookRequest: task.BootFile != "",
 		IsRestartRequest:      true,
 		DatasetNames:          task.DatasetName,
+		HasInternet:           models.SpecInternetQuery(task.HasInternet),
 	}
 	log.Info("BuildRequest4Restart success.displayJobName=%s jobType=%s cluster=%s", ctx.Request.DisplayJobName, ctx.Request.JobType, ctx.Request.Cluster)
 	return nil
@@ -421,6 +423,7 @@ func (DefaultCreationHandler) InsertCloudbrainRecord4Async(ctx *context.Creation
 		UpdatedUnix: timeutil.TimeStampNow(),
 		GpuQueue:    ctx.Spec.QueueCode,
 		AppName:     req.AppName,
+		HasInternet: int(req.HasInternet),
 	}
 
 	err := models.CreateCloudbrain(c)
@@ -594,6 +597,7 @@ func (DefaultCreationHandler) CreateCloudbrainRecord4Restart(ctx *context.Creati
 		SubTaskName: models.SubTaskName,
 		ModelId:     req.PretrainModelId,
 		GpuQueue:    ctx.Spec.QueueCode,
+		HasInternet: int(req.HasInternet),
 	}
 
 	err := models.RestartCloudbrain(ctx.SourceCloudbrain, c)
@@ -685,3 +689,17 @@ func (DefaultCreationHandler) HandleErr4Async(ctx *context.CreationContext) *res
 func (g DefaultCreationHandler) NotifyCreation(ctx *context.CreationContext) *response.BizError {
 	return nil
 }
+
+func (DefaultCreationHandler) CheckNotebookCount(ctx *context.CreationContext) *response.BizError {
+
+	if setting.NotebookStrategy.ClearEnabled && ctx.Request.JobType == models.JobTypeDebug {
+		count, err := models.GetNotebooksCountByUser(ctx.User.ID)
+		if err != nil {
+			log.Warn("can not get user notebook count:%v", err)
+		}
+		if count >= int64(setting.NotebookStrategy.MaxNumberPerUser) {
+			return response.NOTEBOOK_EXCEED_MAX_NUM
+		}
+	}
+	return nil
+}
diff --git a/services/ai_task_service/task/task_service.go b/services/ai_task_service/task/task_service.go
index 39ffcbc567..a1ef16fe46 100644
--- a/services/ai_task_service/task/task_service.go
+++ b/services/ai_task_service/task/task_service.go
@@ -138,6 +138,7 @@ func buildAITaskInfo(task *models.Cloudbrain, creator *models.User, config *enti
 		EngineName: task.EngineName,
 		UserId:     task.UserID,
 		AppName:    task.AppName,
+		HasInternet: task.HasInternet,
 	}, nil
 }
 
@@ -805,3 +806,47 @@ func HandleNewAITaskDelete(cloudbrainId int64) (isHandled bool, err error) {
 	}
 	return true, nil
 }
+
+func ClearNotebook() {
+	defer func() {
+		if err := recover(); err != nil {
+			log.Error("panic occurred:%v", err)
+		}
+	}()
+
+	if !setting.NotebookStrategy.ClearEnabled {
+		return
+	}
+
+	userCountInfo, err := models.GetNotebookCountGreaterThanN(setting.NotebookStrategy.MaxNumberPerUser)
+	if err != nil {
+		log.Error("can not get Notebook user count info:%v", err)
+		return
+	}
+	deleteCount := 0
+	for _, userCount := range userCountInfo {
+		ids, err := models.GetNotebooksByUser(userCount.UserID, setting.NotebookStrategy.MaxNumberPerUser)
+		if err != nil {
+			log.Error("can not get Notebook by user id:%v", err)
+			continue
+		}
+		for _, id := range ids {
+			t, _ := GetAITaskTemplateByCloudbrainId(id)
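+			// GetNotebooksByUser is called with an offset of MaxNumberPerUser, so the
+			// user's newest notebooks are skipped and only the older surplus is
+			// deleted, capped at ClearBatchSize removals per run.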
+			if t == nil {
+				log.Error("can not get task template")
+				continue
+			}
+			err := t.Delete(id)
+			if err != nil {
+				log.Error("Delete error.%v", err)
+				continue
+			}
+			log.Info("Clear Notebook id is " + strconv.FormatInt(id, 10))
+			deleteCount += 1
+			if deleteCount >= setting.NotebookStrategy.ClearBatchSize {
+				return
+			}
+		}
+	}
+
+}
diff --git a/services/llm_chat/llm_chat.go b/services/llm_chat/llm_chat.go
new file mode 100644
index 0000000000..43ea3c8f86
--- /dev/null
+++ b/services/llm_chat/llm_chat.go
@@ -0,0 +1,564 @@
+package llm_chat
+
+import (
+	"code.gitea.io/gitea/models"
+	baiduAPI "code.gitea.io/gitea/modules/baiduai"
+	"code.gitea.io/gitea/modules/context"
+	llmChatAPI "code.gitea.io/gitea/modules/llm_chat"
+	"code.gitea.io/gitea/modules/log"
+	"code.gitea.io/gitea/modules/setting"
+	api "code.gitea.io/gitea/modules/structs"
+	"encoding/json"
+	uuid "github.com/satori/go.uuid"
+	"net/http"
+	"strconv"
+	"strings"
+	"time"
+)
+
+const (
+	VectorStoreType   = "faiss"
+	EmbeddingModel    = "m3e-base"
+	TopK              = 5
+	ScoreThreshold    = 0.5
+	DeleteContent     = true
+	NotRefreshVsCache = false
+	FlagTextInvalid   = ""
+	FlagAccountBanned = ""
+	FlagTextDoc       = ""
+	ValidationTool    = "baidu_api"
+)
+
+func getKnowledgeBaseName(ctx *context.Context) string {
+	kbName := ctx.Query("knowledge_base_name")
+	userID := strconv.FormatInt(ctx.User.ID, 10)
+	if kbName != setting.LLM_CHAT_API.COMMON_KB {
+		return userID + "_" + kbName
+	}
+	return kbName
+}
+
+func LLMChatService(ctx *context.Context, data api.LLMChatMessage, chat *models.LlmChatVisit) {
+	log.Info("received by api %+v", data)
+	res, err := llmChatAPI.SendLLMChat(data)
+	if err != nil {
+		log.Error("LLMChatService failed: %s", err)
+		ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
+		return
+	}
+	ctx.JSON(http.StatusOK, res)
+}
+
+func KBChatService(ctx *context.Context, data api.KBChatMessage, chat *models.LlmChatVisit) {
+	if data.TopK == 0 || data.ScoreThreshold == 0 {
+		data.TopK = TopK
+		data.ScoreThreshold = ScoreThreshold
+	}
+	log.Info("received by api %+v", data)
+	res, err := llmChatAPI.SendKBChat(data)
+	if err != nil {
+		log.Error("KnowledgeBaseChatService failed: %s", err)
+		ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
+		return
+	}
+	log.Info("received by resty %+v", res)
+	ctx.JSON(http.StatusOK, res)
+}
+
+func isInvalidQuery(ctx *context.Context, chat *models.LlmChat, queryType string) bool {
+	var query string
+	if queryType == "prompt" {
+		query = chat.Prompt
+	} else {
+		query = chat.Answer
+	}
+
+	if query == "" {
+		return false
+	}
+
+	chat.InvalidCount = 0
+	chat.InvalidTool = ValidationTool
+	res, err := baiduAPI.CheckLegalText(query)
+	if err != nil {
+		log.Error("isInvalidQuery() failed: %s", err)
+		return false
+	}
+	if res.ConclusionType != 1 {
+		chat.InvalidCount = 1
+		chat.InvalidType = queryType
+		jsonRes, _ := json.Marshal(res)
+		chat.InvalidDetail = string(jsonRes)
+		err := models.SaveChat(chat)
+		if err != nil {
+			log.Error("isInvalidQuery() SaveChat failed: %s", err)
+			ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
+			return true
+		}
+
+		invalidTotal := models.QueryInvalidPromptCount(ctx.User.ID)
+		if invalidTotal >= setting.LLM_CHAT_API.LEGAL_MAX_COUNT {
+			log.Info("isInvalidQuery() invalid total reached max: %d\n", invalidTotal)
+			ctx.User.ProhibitLogin = true
+			models.UpdateUserCols(ctx.User, "prohibit_login")
+			ctx.Resp.Write([]byte(FlagAccountBanned))
+			ctx.Resp.Flush()
+		} else {
+			ctx.Resp.Write([]byte(FlagTextInvalid))
+			ctx.Resp.Flush()
+		}
+		return true
+	}
+	return
false +} + +func StreamLLMChatService(ctx *context.Context, data api.LLMChatMessage, chat *models.LlmChatVisit) { + uuid := uuid.NewV4() + id := uuid.String() + llmChat := &models.LlmChat{ + ID: id, + UserId: ctx.User.ID, + ChatId: chat.ChatId, + Prompt: data.Query, + ModelName: data.ModelName, + Endpoint: llmChatAPI.GetEndpoint(data.ModelName), + ChatType: "llm", + ChatStatus: 1, + Count: 1, + } + + var answer string + ctx.Resp.Header().Set("Content-Type", "application/octet-stream; charset=utf-8") + ctx.Resp.Header().Set("X-Accel-Buffering", "no") + + //call baiduai api to check legality of query + if setting.LLM_CHAT_API.LEGAL_CHECK { + invalidPrompt := isInvalidQuery(ctx, llmChat, "prompt") + if invalidPrompt { + log.Info("StreamLLMChatService() invalid prompt: %s\n", llmChat.Prompt) + return + } + } + + resultChan := make(chan string) + errChan := make(chan error) + done := make(chan struct{}) + go llmChatAPI.StreamLLMChat(data, resultChan, errChan, done) + + for { + select { + case data := <-resultChan: + answer += data + ctx.Resp.Write([]byte(data)) + ctx.Resp.Flush() + case err := <-errChan: + response := ctx.Tr("llm_chat.server_error") + for _, v := range response { + ctx.Resp.Write([]byte(string(v))) + ctx.Resp.Flush() + time.Sleep(50 * time.Millisecond) + } + log.Error("StreamLLMChatService() failed: %s", err) + log.Info("StreamLLMChatService() chat server api error, save to db") + llmChat.ChatStatus = 0 + err = models.SaveChat(llmChat) + if err != nil { + log.Error("StreamLLMChatService() SaveChat failed: %s", err) + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error())) + } + close(resultChan) + close(errChan) + close(done) + return + case <-done: + llmChat.Answer = answer + if llmChat.Answer == "" { + llmChat.ChatStatus = 0 + } + if setting.LLM_CHAT_API.LEGAL_CHECK { + invalidAnswer := isInvalidQuery(ctx, llmChat, "answer") + if invalidAnswer { + log.Info("StreamLLMChatService() invalid answer: %s\n", llmChat.Answer) + close(resultChan) + close(errChan) + return + } + } + log.Info("StreamLLMChatService() nothing invalid, save to db") + err := models.SaveChat(llmChat) + if err != nil { + log.Error("StreamLLMChatService() SaveChat failed: %s", err) + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error())) + } + close(resultChan) + close(errChan) + return + } + } +} + +func StreamKBChatService(ctx *context.Context, data api.KBChatMessage, chat *models.LlmChatVisit) { + userID := strconv.FormatInt(ctx.User.ID, 10) + if data.KnowledgeBaseName != setting.LLM_CHAT_API.COMMON_KB { + data.KnowledgeBaseName = userID + "_" + data.KnowledgeBaseName + } + + uuid := uuid.NewV4() + id := uuid.String() + llmChat := &models.LlmChat{ + ID: id, + UserId: ctx.User.ID, + ChatId: chat.ChatId, + Prompt: data.Query, + ModelName: data.ModelName, + Endpoint: llmChatAPI.GetEndpoint(data.ModelName), + KnowledgeBaseName: data.KnowledgeBaseName, + VectorStoreType: VectorStoreType, + EmbeddingModel: EmbeddingModel, + ChatType: "kb", + ChatStatus: 1, + Count: 1, + } + + var answer string + var docs string + ctx.Resp.Header().Set("Content-Type", "application/octet-stream; charset=utf-8") + ctx.Resp.Header().Set("X-Accel-Buffering", "no") + + //call baiduai api to check legality of query + if setting.LLM_CHAT_API.LEGAL_CHECK { + invalidPrompt := isInvalidQuery(ctx, llmChat, "prompt") + if invalidPrompt { + log.Info("StreamKBChatService() invalid prompt: %s\n", llmChat.Prompt) + return + } + } + + resultChan := make(chan string) + errChan := make(chan error) + done := make(chan struct{}) + 
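+	// Producer/consumer contract, as in StreamLLMChatService above: the goroutine
+	// emits chunks on resultChan, reports a failure on errChan, and signals
+	// completion on done; the select loop below flushes each chunk to the client
+	// as soon as it arrives.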
go llmChatAPI.StreamKBChat(data, resultChan, errChan, done) + + for { + select { + case data := <-resultChan: + if strings.Contains(data, "answer") { + var result api.KBChatAnswer + json.Unmarshal([]byte(data), &result) + //ctx.JSON(http.StatusOK, result) + ctx.Resp.Write([]byte(result.Answer)) + ctx.Resp.Flush() + answer += result.Answer + } + if strings.Contains(data, "docs") { + docs += data + } + case err := <-errChan: + response := ctx.Tr("llm_chat.server_error") + for _, v := range response { + ctx.Resp.Write([]byte(string(v))) + ctx.Resp.Flush() + time.Sleep(50 * time.Millisecond) + } + log.Error("StreamKBChatService() failed: %s", err) + log.Info("StreamKBChatService() chat server api error, save to db") + llmChat.ChatStatus = 0 + err = models.SaveChat(llmChat) + if err != nil { + log.Error("SaveChat failed: %s", err) + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error())) + } + close(resultChan) + close(errChan) + close(done) + return + case <-done: + if docs != "" { + ctx.Resp.Write([]byte(FlagTextDoc + docs)) + ctx.Resp.Flush() + } + llmChat.Answer = answer + if llmChat.Answer == "" { + llmChat.ChatStatus = 0 + } + //call baiduai api to check legality of query + if setting.LLM_CHAT_API.LEGAL_CHECK { + invalidAnswer := isInvalidQuery(ctx, llmChat, "answer") + if invalidAnswer { + log.Info("StreamKBChatService() invalid answer: %s\n", llmChat.Answer) + close(resultChan) + close(errChan) + return + } + } + log.Info("StreamKBChatService() nothing invalid, save to db") + err := models.SaveChat(llmChat) + if err != nil { + log.Error("StreamKBChatService() SaveChat failed: %s", err) + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error())) + } + close(resultChan) + close(errChan) + return + } + } +} + +func ListKnowledgeBaseService(ctx *context.Context) { + userID := strconv.FormatInt(ctx.User.ID, 10) + res, err := llmChatAPI.ListKnowledgeBase() + if err != nil { + log.Error("LLMChatService failed: %s", err) + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error())) + return + } + log.Info("received by resty %+v", res) + log.Info("######## received by resty %+v\n", res) + var realKB []string + for i := len(res.Data) - 1; i >= 0; i-- { + v := res.Data[i] + if strings.Contains(v, userID) { + substr := strings.Replace(v, userID+"_", "", -1) + if strings.TrimSpace(substr) != "" { + realKB = append(realKB, substr) + } + } + } + + if setting.LLM_CHAT_API.COMMON_KB != "" { + realKB = append(realKB, setting.LLM_CHAT_API.COMMON_KB) + } + + realData := llmChatAPI.LLMBasicMsgWithData{ + Code: res.Code, + Msg: res.Msg, + Data: realKB, + } + log.Info("######## sent %+v\n", realData) + ctx.JSON(http.StatusOK, realData) +} + +func CreateKnowledgeBaseService(ctx *context.Context, data api.CreateKnowledgeBaseParams) { + userID := strconv.FormatInt(ctx.User.ID, 10) + realKB := userID + "_" + data.KnowledgeBaseName + params := api.CreateKnowledgeBaseParams{ + KnowledgeBaseName: realKB, + VectorStoreType: VectorStoreType, + EmbedModel: EmbeddingModel, + } + log.Info("received by api %+v\n", params) + res, err := llmChatAPI.CreateKnowledgeBase(params) + if err != nil { + log.Error("KnowledgeBaseChatService failed: %s", err) + ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error())) + return + } + log.Info("##### received by resty %+v\n", res) + if strings.Contains(res.Msg, userID) { + substr := strings.Replace(res.Msg, userID+"_", "", -1) + if strings.TrimSpace(substr) != "" { + res.Msg = substr + } + } + log.Info("##### sent to client %+v\n", res) + 
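+	// Knowledge bases are namespaced per user as "<userID>_<name>"; the prefix is
+	// stripped from res.Msg above so the client only ever sees its own base name.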
+	ctx.JSON(http.StatusOK, res)
+}
+
+func DeleteKnowledgeBaseService(ctx *context.Context) {
+	userID := strconv.FormatInt(ctx.User.ID, 10)
+	data := getKnowledgeBaseName(ctx)
+	log.Info("### received by api %+v\n", data)
+	if data == setting.LLM_CHAT_API.COMMON_KB {
+		ctx.Error(http.StatusForbidden, "You can't operate %s", data)
+		return
+	}
+	res, err := llmChatAPI.DeleteKnowledgeBase(data)
+	if err != nil {
+		log.Error("DeleteKnowledgeBaseService failed: %s", err)
+		ctx.JSON(http.StatusOK, res)
+		return
+	}
+	log.Info("######## received by resty %+v\n", res)
+	if strings.Contains(res.Msg, userID) {
+		substr := strings.Replace(res.Msg, userID+"_", "", -1)
+		if strings.TrimSpace(substr) != "" {
+			res.Msg = substr
+		}
+	}
+	log.Info("######## sent to client %+v\n", res)
+	ctx.JSON(http.StatusOK, res)
+}
+
+func ListFilesService(ctx *context.Context) {
+	data := getKnowledgeBaseName(ctx)
+	log.Info("received by api %+v", data)
+	res, err := llmChatAPI.KBListFiles(data)
+	if err != nil {
+		log.Error("ListFiles failed: %s", err)
+		ctx.JSON(http.StatusOK, res)
+		return
+	}
+	log.Info("received by resty %+v", res)
+	ctx.JSON(http.StatusOK, res)
+}
+
+func SearchDocService(ctx *context.Context, data api.SearchDocParams) {
+	if data.TopK == 0 || data.ScoreThreshold == 0 {
+		data.TopK = TopK
+		data.ScoreThreshold = ScoreThreshold
+	}
+	realKB := getKnowledgeBaseName(ctx)
+	data.KnowledgeBaseName = realKB
+	log.Info("received by api %+v", data)
+	res, err := llmChatAPI.KBSearchDoc(data)
+	if err != nil {
+		log.Error("SearchDocService failed: %s", err)
+		ctx.JSON(http.StatusOK, res)
+		return
+	}
+	log.Info("received by resty %+v", res)
+	ctx.JSON(http.StatusOK, res)
+}
+
+func DeleteDocService(ctx *context.Context, data api.DeleteDocParams) {
+	data.DeleteContent = DeleteContent
+	data.NotRefreshVsCache = NotRefreshVsCache
+	userID := strconv.FormatInt(ctx.User.ID, 10)
+	if data.KnowledgeBaseName == setting.LLM_CHAT_API.COMMON_KB {
+		ctx.Error(http.StatusForbidden, "You can't operate %s", data.KnowledgeBaseName)
+		return
+	}
+	data.KnowledgeBaseName = userID + "_" + data.KnowledgeBaseName
+	log.Info("received by api %+v", data)
+	res, err := llmChatAPI.KBDeleteDoc(data)
+	if err != nil {
+		log.Error("DeleteDocService failed: %s", err)
+		ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
+		return
+	}
+	ctx.JSON(http.StatusOK, res)
+}
+
+func UpdateDocService(ctx *context.Context) {
+	data := api.UpdateDocParams{
+		KnowledgeBaseName: getKnowledgeBaseName(ctx),
+		FileNames:         ctx.Query("file_name"),
+		NotRefreshVsCache: NotRefreshVsCache,
+	}
+	log.Info("received by api %+v", data)
+	if data.KnowledgeBaseName == setting.LLM_CHAT_API.COMMON_KB {
+		ctx.Error(http.StatusForbidden, "You can't operate %s", data.KnowledgeBaseName)
+		return
+	}
+	res, err := llmChatAPI.KBUpdateDoc(data)
+	if err != nil {
+		log.Error("UpdateDocService failed: %s", err)
+		ctx.JSON(http.StatusOK, models.BaseErrorMessageApi(err.Error()))
+		return
+	}
+	ctx.JSON(http.StatusOK, res)
+}
+
+func RecreateVectorStoreService(ctx *context.Context) {
+	data := api.RecreateVectorStoreParams{
+		KnowledgeBaseName: getKnowledgeBaseName(ctx),
+		AllowEmptyKb:      true,
+		VsType:            VectorStoreType,
+		EmbedModel:        EmbeddingModel,
+	}
+	log.Info("received by api %+v", data)
+	if data.KnowledgeBaseName == setting.LLM_CHAT_API.COMMON_KB {
+		ctx.Error(http.StatusForbidden, "You can't operate %s", data.KnowledgeBaseName)
+		return
+	}
+
+	resultChan := make(chan string)
+	errChan := make(chan error)
+	done := make(chan struct{})
+
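+	// Same streaming contract as the chat endpoints: KBRecreateVectorStore is
+	// expected to emit progress lines on resultChan and close out via done once
+	// the vector store has been rebuilt.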
go llmChatAPI.KBRecreateVectorStore(data, resultChan, errChan, done) + ctx.Resp.Header().Set("Content-Type", "application/octet-stream; charset=utf-8") + ctx.Resp.Header().Set("X-Accel-Buffering", "no") + + for { + select { + case data := <-resultChan: + _, err := ctx.Resp.Write([]byte(data)) + if err != nil { + log.Error("Error writing response: %s", err) + ctx.JSON(http.StatusInternalServerError, models.BaseErrorMessageApi(err.Error())) + return + } + log.Info("%s\n", []byte(data)) + ctx.Resp.Flush() // Flush the response to send it immediately + case err := <-errChan: + log.Error("Error writing response: %s", err) + ctx.JSON(http.StatusInternalServerError, models.BaseErrorMessageApi(err.Error())) + return + case <-done: + close(resultChan) + close(errChan) + return + } + } +} + +func UploadDocUrlService(ctx *context.Context) { + data := getKnowledgeBaseName(ctx) + if data == setting.LLM_CHAT_API.COMMON_KB { + ctx.Error(http.StatusForbidden, "You can't operate %s", data) + return + } + url, _ := llmChatAPI.GetUploadDocUrl() + log.Info("received by api %+v", url) + + ctx.JSON(http.StatusOK, url) +} + +func UploadDocService(ctx *context.Context, form api.LLMChatUploadForm) { + log.Info("######### received request %+v\n", ctx.Req.Request) + log.Info("######### form api.LLMChatUploadForm %+v\n", form) + modelName := ctx.Query("model_name") + userID := strconv.FormatInt(ctx.User.ID, 10) + if form.KnowledgeBaseName != setting.LLM_CHAT_API.COMMON_KB { + form.KnowledgeBaseName = userID + "_" + form.KnowledgeBaseName + } + res, err := llmChatAPI.UploadDocs(modelName, form) + log.Info("######### received by resty %+v\n", res) + + if err != nil { + log.Error("UploadDocService failed: %s", err) + ctx.JSON(http.StatusOK, err.Error()) + return + } + + ctx.JSON(http.StatusOK, res) +} + +func DownloadDocService(ctx *context.Context) { + data := getKnowledgeBaseName(ctx) + if data == setting.LLM_CHAT_API.COMMON_KB { + ctx.Error(http.StatusForbidden, "You can't operate %s", data) + return + } + fileName := ctx.Query("file_name") + log.Info("received by api knowledgeBaseName:%s, fileName: %s", data, fileName) + url, _ := llmChatAPI.GetDownloadDocUrl(data, fileName) + log.Info("received by api %+v", url) + http.Redirect(ctx.Resp, ctx.Req.Request, url, http.StatusMovedPermanently) +} + +func GetFreeTriesService(ctx *context.Context) { + modelName := ctx.Query("model_name") + maxTries := setting.LLM_CHAT_API.MAX_FREE_TRIES + counts := models.QueryChatCount(ctx.User.ID, modelName) + //firstVisit := models.QueryFirstVisit(ctx.User.ID) + + data := api.LLMChatCountsResults{ + MaxTries: maxTries, + Counts: counts, + CanChat: counts < maxTries, + //FirstVisit: firstVisit == 0, + } + log.Info("user %+v, GetFreeTriesService() data= %+v", ctx, data) + ctx.JSON(http.StatusOK, data) +} diff --git a/templates/admin/cloudbrain/imagecommit.tmpl b/templates/admin/cloudbrain/imagecommit.tmpl index 98737b7d6f..e0cd9c5b7d 100644 --- a/templates/admin/cloudbrain/imagecommit.tmpl +++ b/templates/admin/cloudbrain/imagecommit.tmpl @@ -119,7 +119,11 @@ --> -
+				{{.i18n.Tr "repo.images.submit_tooltips"}}
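The flow-control additions in routers/repo/flow_control.go reduce to three per-user sliding-window limits: file count over the last 24 hours, concurrent uploads over the last 10 minutes, and total bytes over the last 24 hours. The condensed sketch below restates CheckFlowForModel's window logic over an in-memory slice, with illustrative constants standing in for setting.FLOW_CONTROL; the chunks argument is assumed to be pre-filtered to the caller's last 24 hours, as models.GetModelFileChunksByUserId does above.

package main

import (
	"errors"
	"fmt"
	"time"
)

// Illustrative limits; the real values come from setting.FLOW_CONTROL.
const (
	maxFilesPer24h   = 100                             // ATTACHEMENT_NUM_A_USER_LAST24HOUR
	maxFilesPer10min = 10                              // ATTACHEMENT_NUM_A_USER_LAST10M
	maxBytesPer24h   = int64(200) * 1024 * 1024 * 1024 // ATTACHEMENT_SIZE_A_USER in GiB
)

type chunk struct {
	size        int64
	createdUnix int64
}

// checkFlow mirrors the three gates of CheckFlowForModel over an in-memory slice.
func checkFlow(chunks []chunk, newSize int64, now time.Time) error {
	if len(chunks) >= maxFilesPer24h {
		return errors.New("too many uploads in the last 24 hours")
	}
	total := newSize
	concurrent := 0
	for _, c := range chunks {
		total += c.size
		if now.Unix()-c.createdUnix < 10*60 {
			concurrent++ // still counts toward the 10-minute concurrency window
		}
	}
	if concurrent >= maxFilesPer10min {
		return errors.New("too many uploads in the last 10 minutes")
	}
	if total >= maxBytesPer24h {
		return errors.New("24-hour upload volume exceeded")
	}
	return nil
}

func main() {
	now := time.Now()
	recent := []chunk{{size: 1 << 30, createdUnix: now.Unix() - 120}}
	fmt.Println(checkFlow(recent, 1<<30, now)) // prints <nil>: well under every limit
}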