diff --git a/models/ai_model_manage.go b/models/ai_model_manage.go index 7eb21684b9..fa0ec33e91 100644 --- a/models/ai_model_manage.go +++ b/models/ai_model_manage.go @@ -774,3 +774,48 @@ func QueryModelFileByModelId(modelId string) []*AiModelFile { } return result } + +func QueryModelForSearch(opts *AiModelQueryOptions) ([]*AiModelManage, int64, error) { + sess := x.NewSession() + defer sess.Close() + var where string + where += "ai_model_manage.user_id=" + fmt.Sprint(opts.UserID) + where += " and ai_model_manage.is_private=true" + if opts.Namelike != "" { + where += " and ( ai_model_manage.name ILIKE '%" + opts.Namelike + "%'" + where += " or ai_model_manage.description ILIKE '%" + opts.Namelike + "%'" + where += " or ai_model_manage.label ILIKE '%" + opts.Namelike + "%'" + where += " or ai_model_file.name ILIKE '%" + opts.Namelike + "%')" + } + + var count int64 + var err error + count, err = sess.Join("LEFT", "ai_model_file", "ai_model_manage.id = ai_model_file.model_id").Select("count(distinct(ai_model_manage.id))").Where(where).Count(new(AiModelManage)) + if err != nil { + log.Info("error=" + err.Error()) + return nil, 0, fmt.Errorf("Count: %v", err) + } + if opts.Page >= 0 && opts.PageSize > 0 { + var start int + if opts.Page == 0 { + start = 0 + } else { + start = (opts.Page - 1) * opts.PageSize + } + sess.Limit(opts.PageSize, start) + } + sess.Join("LEFT", "ai_model_file", "ai_model_manage.id = ai_model_file.model_id") + orderby := "ai_model_manage.created_unix desc" + if opts.SortType != "" { + orderby = opts.SortType + } + sess.OrderBy(orderby) + aiModelManages := make([]*AiModelManage, 0, setting.UI.IssuePagingNum) + if err := sess.Select("distinct(ai_model_manage.*)").Table("ai_model_manage").Where(where). + Find(&aiModelManages); err != nil { + log.Info("error=" + err.Error()) + return nil, 0, fmt.Errorf("Find: %v", err) + } + + return aiModelManages, count, nil +} diff --git a/options/locale/locale_en-US.ini b/options/locale/locale_en-US.ini index f793db2ccf..3f8cdf4e83 100755 --- a/options/locale/locale_en-US.ini +++ b/options/locale/locale_en-US.ini @@ -274,6 +274,7 @@ c2net_center=Center search=Search search_repo=Repository search_dataset=DataSet +search_model=Model search_issue=Issue search_pr=Pull Request search_user=User diff --git a/options/locale/locale_zh-CN.ini b/options/locale/locale_zh-CN.ini index a0f2f62439..35c7e24dbf 100755 --- a/options/locale/locale_zh-CN.ini +++ b/options/locale/locale_zh-CN.ini @@ -276,6 +276,7 @@ c2net_center=中心 search=搜索 search_repo=项目 search_dataset=数据集 +search_model=模型 search_issue=任务 search_pr=合并请求 search_user=用户 diff --git a/public/home/search.js b/public/home/search.js index 86b2ad06eb..57542c88c2 100644 --- a/public/home/search.js +++ b/public/home/search.js @@ -27,6 +27,7 @@ var itemType = { 4: "org", 5: "dataset", 6: "pr", + 7: "model", }; var sortBy = { @@ -51,6 +52,10 @@ var sortBy = { 51: "download_times", 60: "default", 61: "updated_unix.keyword", + 70: "default", + 71: "reference_count", + 72: "download_count", + 73: "created_unix.keyword", }; var sortAscending = { @@ -75,6 +80,10 @@ var sortAscending = { 51: "false", 60: "false", 61: "false", + 70: "false", + 71: "false", + 72: "false", + 73: "false", }; var currentPage = 1; @@ -149,6 +158,7 @@ function emptySearch() { $("#pr_total").text(""); $("#issue_total").text(""); $("#dataset_total").text(""); + $("#model_total").text(""); $("#user_total").text(""); $("#org_total").text(""); setActivate(null); @@ -159,6 +169,7 @@ function initDiv(isSearchLabel = false) { document.getElementById("search_div").style.display = "none"; document.getElementById("search_label_div").style.display = "block"; document.getElementById("dataset_item").style.display = "none"; + document.getElementById("model_item").style.display = "none"; document.getElementById("issue_item").style.display = "none"; document.getElementById("pr_item").style.display = "none"; document.getElementById("user_item").style.display = "none"; @@ -168,6 +179,7 @@ function initDiv(isSearchLabel = false) { document.getElementById("search_div").style.display = "block"; document.getElementById("search_label_div").style.display = "none"; document.getElementById("dataset_item").style.display = "block"; + document.getElementById("model_item").style.display = "block"; document.getElementById("issue_item").style.display = "block"; document.getElementById("pr_item").style.display = "block"; document.getElementById("user_item").style.display = "block"; @@ -210,6 +222,9 @@ function doSpcifySearch(tableName, keyword, sortBy = "", ascending = "false") { if (currentSearchTableName != "dataset") { doSearch("dataset", currentSearchKeyword, 1, pageSize, true, "", false); } + if (currentSearchTableName != "model") { + doSearch("model", currentSearchKeyword, 1, pageSize, true, "", false); + } if (currentSearchTableName != "pr") { doSearch("pr", currentSearchKeyword, 1, pageSize, true, "", false); } @@ -278,7 +293,7 @@ function doSearch( success: function (json) { displayResult(tableName, page, json, onlyReturnNum, keyword); }, - error: function (response) {}, + error: function (response) { }, }); } @@ -293,6 +308,8 @@ function displayResult(tableName, page, jsonResult, onlyReturnNum, keyword) { displayOrgResult(page, jsonResult, onlyReturnNum, keyword); } else if (tableName == "dataset") { displayDataSetResult(page, jsonResult, onlyReturnNum, keyword); + } else if (tableName == "model") { + displayModelResult(page, jsonResult, onlyReturnNum, keyword); } else if (tableName == "pr") { displayPrResult(page, jsonResult, onlyReturnNum, keyword); } @@ -535,6 +552,104 @@ function displayDataSetResult(page, jsonResult, onlyReturnNum, keyword) { } } +function displayModelResult(page, jsonResult, onlyReturnNum, keyword) { + var data = jsonResult.Result; + var total = jsonResult.Total; + $("#model_total").text(total); + if (!onlyReturnNum) { + setActivate("model_item"); + //$('#keyword_desc').text(keyword); + //$('#obj_desc').text(getLabel(isZh,"search_model")); + //$('#child_total').text(total); + $("#find_title").html( + getLabel(isZh, "find_title") + .replace("{keyword}", keyword) + .replace("{tablename}", getLabel(isZh, "search_model")) + .replace("{total}", total) + ); + + var sortHtml = ""; + sortHtml += + '' + + getLabel(isZh, "search_matched") + + ""; + sortHtml += + '' + + getLabel(isZh, "search_matched_reference") + + ""; + sortHtml += + '' + + getLabel(isZh, "search_matched_download") + + ""; + sortHtml += + '' + + getLabel(isZh, "search_lasted_create") + + ""; + document.getElementById("sort_type").innerHTML = sortHtml; + + var html = ""; + var currentTime = new Date().getTime(); + const engineMap = { + '0': 'PyTorch', + '1': 'TensorFlow', + '2': 'MindSpore', + '4': 'PaddlePaddle', + '5': 'OneFlow', + '6': 'MXNet', + '3': 'Other', + }; + for (var i = 0; i < data.length; i++) { + var recordMap = data[i]; + var createDate = new Date(recordMap['created_unix'] * 1000); + var createYear = createDate.getFullYear().toString(); + var createMonth = (createDate.getMonth() + 1).toString(); + var createDay = createDate.getDate().toString(); + html += `
+
+
${engineMap[recordMap['engine']] || 'Other'}
+ +
+

${recordMap['label'] ? recordMap['label'].replace(/font\scolor=/g, 'font_color=').trim().split(/\s+/).map(item => { + return '' + + item.replace(/font_color=/g, 'font color=') + ''; + }).join('') : ''}

+

${recordMap['description']}

+

${recordMap['file_name'] || ''}

+

+ + + ${createYear}-${createMonth.length < 2 ? '0' + createMonth : createMonth}-${createDay.length < 2 ? '0' + createDay : createDay} + + + ${recordMap['reference_count']} + + + ${recordMap['download_count']} + + ${getLabel(isZh, "search_lasted_update")} ${recordMap["updated_html"]} +

+
+
+
+ `; + } + document.getElementById("child_search_item").innerHTML = html; + } +} + function displayOrgResult(page, jsonResult, onlyReturnNum, keyword) { var data = jsonResult.Result; var total = jsonResult.Total; @@ -872,6 +987,7 @@ function setActivate(name) { $("#user_item").removeClass("active"); $("#issue_item").removeClass("active"); $("#dataset_item").removeClass("active"); + $("#model_item").removeClass("active"); $("#org_item").removeClass("active"); $("#pr_item").removeClass("active"); if (name == null) { @@ -885,27 +1001,27 @@ function LetterAvatar(name, size, color) { name = name || ""; size = size || 60; var colours = [ - "#1abc9c", - "#2ecc71", - "#3498db", - "#9b59b6", - "#34495e", - "#16a085", - "#27ae60", - "#2980b9", - "#8e44ad", - "#2c3e50", - "#f1c40f", - "#e67e22", - "#e74c3c", - "#00bcd4", - "#95a5a6", - "#f39c12", - "#d35400", - "#c0392b", - "#bdc3c7", - "#7f8c8d", - ], + "#1abc9c", + "#2ecc71", + "#3498db", + "#9b59b6", + "#34495e", + "#16a085", + "#27ae60", + "#2980b9", + "#8e44ad", + "#2c3e50", + "#f1c40f", + "#e67e22", + "#e74c3c", + "#00bcd4", + "#95a5a6", + "#f39c12", + "#d35400", + "#c0392b", + "#bdc3c7", + "#7f8c8d", + ], nameSplit = String(name).split(" "), initials, charIndex, @@ -1624,6 +1740,7 @@ var zhCN = { search: "搜索", search_repo: "项目", search_dataset: "数据集", + search_model: "模型", search_issue: "任务", search_pr: "合并请求", search_user: "用户", @@ -1631,10 +1748,11 @@ var zhCN = { search_finded: "找到", search_matched: "最佳匹配", search_matched_download: "下载次数", + search_matched_reference: "引用次数", search_lasted_update: "最后更新于", search_letter_asc: "字母顺序排序", search_letter_desc: "字母逆序排序", - search_lasted_create: "最近创建", + search_lasted_create: "最新创建", search_early_create: "最早创建", search_add_by: "加入于", search_lasted: "最近更新", @@ -1654,12 +1772,14 @@ var zhCN = { find_title: '“{keyword}”相关{tablename}约为{total}个', search_empty: "请输入任意关键字开始搜索。", + create_time: "创建时间" }; var esUN = { search: "Search", search_repo: "Repository", search_dataset: "DataSet", + search_model: "Model", search_issue: "Issue", search_pr: "Pull Request", search_user: "User", @@ -1667,6 +1787,7 @@ var esUN = { search_finded: "Find", search_matched: "Best Match", search_matched_download: "Most downloads", + search_matched_reference: "Most reference", search_lasted_update: "Updated ", search_letter_asc: "Alphabetically", search_letter_desc: "Reverse alphabetically", @@ -1691,6 +1812,7 @@ var esUN = { ' {total} "{keyword}" related {tablename}', search_empty: "Please enter any keyword to start the search.", + create_time: "Create Time" }; initDiv(false); document.onreadystatechange = function () { diff --git a/routers/repo/ai_model_manage.go b/routers/repo/ai_model_manage.go index 036fc437e8..af5271c5db 100644 --- a/routers/repo/ai_model_manage.go +++ b/routers/repo/ai_model_manage.go @@ -223,7 +223,7 @@ func SaveNewNameModel(ctx *context.Context) { return } SaveModel(ctx) - ctx.Status(200) + //ctx.Status(200) log.Info("save model end.") } diff --git a/routers/search.go b/routers/search.go index 8453d5c185..1451b250d9 100644 --- a/routers/search.go +++ b/routers/search.go @@ -87,6 +87,9 @@ func SearchApi(ctx *context.Context) { searchIssueOrPr(ctx, "issue-es-index"+setting.INDEXPOSTFIX, Key, Page, PageSize, OnlyReturnNum, "t") //searchPR(ctx, "issue-es-index", Key, Page, PageSize, OnlyReturnNum) return + } else if TableName == "model" { + searchModel(ctx, "model-es-index"+setting.INDEXPOSTFIX, Key, Page, PageSize, OnlyReturnNum) + return } } @@ -313,6 +316,8 @@ func searchRepo(ctx *context.Context, TableName string, Key string, Page int, Pa res, err := client.Search(TableName).Query(boolQ).SortBy(getSort(SortBy, ascending, "num_stars", false)...).From(from).Size(Size).Highlight(queryHighlight("alias", "description", "topics")).Do(ctx.Req.Context()) if err == nil { + searchJson, _ := json.Marshal(res) + log.Info("searchJson=" + string(searchJson)) esresult := makeRepoResult(res, Key, OnlyReturnNum, language) setForkRepoOrder(esresult, SortBy) resultObj.Total = resultObj.PrivateTotal + esresult.Total @@ -1221,3 +1226,228 @@ func makeIssueResult(sRes *elastic.SearchResult, Key string, OnlyReturnNum bool, return returnObj } + +func searchModel(ctx *context.Context, TableName string, Key string, Page int, PageSize int, OnlyReturnNum bool) { + /* + 模型,model-es-index + 搜索: + name , 名称 + description 描述 + label 标签 + file_name 数据集文件名称 + 排序: + download_count + reference_count + created_unix + */ + log.Info("query searchModel start") + SortBy := ctx.Query("SortBy") + ascending := ctx.QueryBool("Ascending") + PrivateTotal := ctx.QueryInt("PrivateTotal") + WebTotal := ctx.QueryInt("WebTotal") + language := ctx.Query("language") + if language == "" { + language = "zh-CN" + } + from := (Page - 1) * PageSize + if from == 0 { + WebTotal = 0 + } + resultObj := &SearchRes{} + log.Info("WebTotal=" + fmt.Sprint(WebTotal)) + log.Info("PrivateTotal=" + fmt.Sprint(PrivateTotal)) + resultObj.Result = make([]map[string]interface{}, 0) + + if ctx.User != nil && (from < PrivateTotal || from == 0) { + + log.Info("actor is null?:" + fmt.Sprint(ctx.User == nil)) + sortBy := "ai_model_manage.reference_count desc,ai_model_manage.download_count desc,ai_model_manage.created_unix desc" + if SortBy != "" && SortBy != "default" { + sortBy = "ai_model_manage." + SortBy + if ascending { + sortBy += " asc" + } else { + sortBy += " desc" + } + } + //Page, PageSize, Key, ctx.User.ID + privateModels, count, err := models.QueryModelForSearch(&models.AiModelQueryOptions{ + ListOptions: models.ListOptions{ + Page: Page, + PageSize: PageSize, + }, + UserID: ctx.User.ID, + Namelike: Key, + SortType: sortBy, + }) + if err != nil { + ctx.JSON(200, "") + return + } + resultObj.PrivateTotal = count + modelSize := len(privateModels) + if modelSize > 0 { + log.Info("Query private model number is:" + fmt.Sprint(modelSize) + " count=" + fmt.Sprint(count)) + makePrivateModel(privateModels, resultObj, Key, language) + } else { + log.Info("not found private model, keyword=" + Key) + } + if modelSize >= PageSize { + if WebTotal > 0 { //next page, not first query. + resultObj.Total = int64(WebTotal) + ctx.JSON(200, resultObj) + return + } + } + } else { + resultObj.PrivateTotal = int64(PrivateTotal) + } + + from = from - PrivateTotal + if from < 0 { + from = 0 + } + Size := PageSize - len(resultObj.Result) + + boolQ := elastic.NewBoolQuery() + if Key != "" { + fileNameQuery := elastic.NewMatchQuery("file_name", Key).Boost(3).QueryName("f_first") + nameQuery := elastic.NewMatchQuery("name", Key).Boost(2).QueryName("f_second") + descQuery := elastic.NewMatchQuery("description", Key).Boost(1.5).QueryName("f_three") + labelQuery := elastic.NewMatchQuery("label", Key).Boost(1).QueryName("f_fourth") + boolQ.Should(fileNameQuery, nameQuery, descQuery, labelQuery) + res, err := client.Search(TableName).Query(boolQ).SortBy(getSort(SortBy, ascending, "updated_unix.keyword", false)...).From(from).Size(Size).Highlight(queryHighlight("file_name", "name", "description", "label")).Do(ctx.Req.Context()) + if err == nil { + searchJson, _ := json.Marshal(res) + log.Info("searchJson=" + string(searchJson)) + esresult := makeModelResult(res, Key, OnlyReturnNum, language) + resultObj.Total = resultObj.PrivateTotal + esresult.Total + log.Info("query model es count=" + fmt.Sprint(esresult.Total) + " total=" + fmt.Sprint(resultObj.Total)) + resultObj.Result = append(resultObj.Result, esresult.Result...) + ctx.JSON(200, resultObj) + } else { + log.Info("query es error," + err.Error()) + } + } else { + log.Info("query all models.") + //搜索的属性要指定{"timestamp":{"unmapped_type":"date"}} + res, err := client.Search(TableName).SortBy(getSort(SortBy, ascending, "updated_unix.keyword", false)...).From(from).Size(Size).Do(ctx.Req.Context()) + if err == nil { + searchJson, _ := json.Marshal(res) + log.Info("searchJson=" + string(searchJson)) + esresult := makeModelResult(res, "", OnlyReturnNum, language) + resultObj.Total = resultObj.PrivateTotal + esresult.Total + log.Info("query model es count=" + fmt.Sprint(esresult.Total) + " total=" + fmt.Sprint(resultObj.Total)) + resultObj.Result = append(resultObj.Result, esresult.Result...) + ctx.JSON(200, resultObj) + } else { + log.Info("query es error," + err.Error()) + ctx.JSON(200, "") + } + } + +} + +func makePrivateModel(privateModels []*models.AiModelManage, res *SearchRes, Key string, language string) { + for _, model := range privateModels { + record := make(map[string]interface{}) + + record["id"] = model.ID + userId := model.UserId + + user, errUser := models.GetUserByID(userId) + if errUser == nil { + record["owerName"] = user.GetDisplayName() + record["avatar"] = user.RelAvatarLink() + } + + repo, errRepo := models.GetRepositoryByID(model.RepoId) + if errRepo == nil { + log.Info("repo_url=" + repo.FullName()) + record["repoUrl"] = repo.FullName() + record["avatar"] = repo.RelAvatarLink() + } else { + log.Info("repo err=" + errRepo.Error()) + } + resultfile := models.QueryModelFileByModelId(model.ID) + file_name := "" + if resultfile != nil && len(resultfile) > 0 { + for _, file := range resultfile { + file_name += file.Name + "," + } + file_name = file_name[0 : len(file_name)-1] + } + record["file_name"] = truncLongText(makeHighLight(Key, file_name), true) + record["name"] = makeHighLight(Key, model.Name) + record["real_name"] = model.Name + record["is_private"] = model.IsPrivate + record["description"] = truncLongText(makeHighLight(Key, model.Description), true) + + record["label"] = makeHighLight(Key, model.Label) + record["download_count"] = model.DownloadCount + record["reference_count"] = model.ReferenceCount + record["engine"] = model.Engine + record["created_unix"] = model.CreatedUnix + record["updated_unix"] = model.UpdatedUnix + record["updated_html"] = timeutil.TimeSinceUnix(model.UpdatedUnix, language) + + res.Result = append(res.Result, record) + } +} + +func makeModelResult(sRes *elastic.SearchResult, Key string, OnlyReturnNum bool, language string) *SearchRes { + total := sRes.Hits.TotalHits.Value + result := make([]map[string]interface{}, 0) + if !OnlyReturnNum { + for i, hit := range sRes.Hits.Hits { + log.Info("this is model query " + fmt.Sprint(i) + " result.") + recordSource := make(map[string]interface{}) + source, err := hit.Source.MarshalJSON() + + if err == nil { + err = json.Unmarshal(source, &recordSource) + if err == nil { + record := make(map[string]interface{}) + record["id"] = hit.Id + userIdStr := recordSource["user_id"].(string) + userId, cerr := strconv.ParseInt(userIdStr, 10, 64) + if cerr == nil { + user, errUser := models.GetUserByID(userId) + if errUser == nil { + record["owerName"] = user.GetDisplayName() + record["avatar"] = user.RelAvatarLink() + } + } + setRepoInfo(recordSource, record) + record["name"] = getLabelValue("name", recordSource, hit.Highlight) + record["real_name"] = recordSource["name"] + record["label"] = getLabelValue("label", recordSource, hit.Highlight) + if recordSource["description"] != nil { + desc := getLabelValue("description", recordSource, hit.Highlight) + record["description"] = dealLongText(desc, Key, hit.MatchedQueries) + } else { + record["description"] = "" + } + record["is_private"] = recordSource["is_private"] + record["engine"] = recordSource["engine"] + record["file_name"] = getDatasetFileName(getLabelValue("file_name", recordSource, hit.Highlight)) + record["reference_count"] = recordSource["reference_count"] + record["download_count"] = recordSource["download_count"] + record["created_unix"] = recordSource["created_unix"] + setUpdateHtml(record, recordSource["updated_unix"].(string), language) + result = append(result, record) + } else { + log.Info("deal model source error," + err.Error()) + } + } else { + log.Info("deal model source error," + err.Error()) + } + } + } + returnObj := &SearchRes{ + Total: total, + Result: result, + } + + return returnObj +} diff --git a/templates/explore/search_new.tmpl b/templates/explore/search_new.tmpl index e3a85414ff..aaceff902d 100644 --- a/templates/explore/search_new.tmpl +++ b/templates/explore/search_new.tmpl @@ -32,6 +32,10 @@ {{.i18n.Tr "home.search_dataset"}} + + {{.i18n.Tr "home.search_model"}} + + {{.i18n.Tr "home.search_issue"}} diff --git a/templates/repo/modelarts/trainjob/show.tmpl b/templates/repo/modelarts/trainjob/show.tmpl index 8031368635..cb3f0f5b16 100755 --- a/templates/repo/modelarts/trainjob/show.tmpl +++ b/templates/repo/modelarts/trainjob/show.tmpl @@ -305,7 +305,6 @@ point_hr: {{$.i18n.Tr "cloudbrain.point_hr"}}, memory: {{$.i18n.Tr "cloudbrain.memory"}}, shared_memory: {{$.i18n.Tr "cloudbrain.shared_memory"}}, - no_use_resource:{{$.i18n.Tr "cloudbrain.no_use_resource"}}, }); $('td.ti-text-form-content.spec{{$k}} div').text(specStr); })();