diff --git a/models/ai_model_manage.go b/models/ai_model_manage.go
index 7eb21684b9..fa0ec33e91 100644
--- a/models/ai_model_manage.go
+++ b/models/ai_model_manage.go
@@ -774,3 +774,48 @@ func QueryModelFileByModelId(modelId string) []*AiModelFile {
}
return result
}
+
+func QueryModelForSearch(opts *AiModelQueryOptions) ([]*AiModelManage, int64, error) {
+ sess := x.NewSession()
+ defer sess.Close()
+ var where string
+ where += "ai_model_manage.user_id=" + fmt.Sprint(opts.UserID)
+ where += " and ai_model_manage.is_private=true"
+ if opts.Namelike != "" {
+ where += " and ( ai_model_manage.name ILIKE '%" + opts.Namelike + "%'"
+ where += " or ai_model_manage.description ILIKE '%" + opts.Namelike + "%'"
+ where += " or ai_model_manage.label ILIKE '%" + opts.Namelike + "%'"
+ where += " or ai_model_file.name ILIKE '%" + opts.Namelike + "%')"
+ }
+
+ var count int64
+ var err error
+ count, err = sess.Join("LEFT", "ai_model_file", "ai_model_manage.id = ai_model_file.model_id").Select("count(distinct(ai_model_manage.id))").Where(where).Count(new(AiModelManage))
+ if err != nil {
+ log.Info("error=" + err.Error())
+ return nil, 0, fmt.Errorf("Count: %v", err)
+ }
+ if opts.Page >= 0 && opts.PageSize > 0 {
+ var start int
+ if opts.Page == 0 {
+ start = 0
+ } else {
+ start = (opts.Page - 1) * opts.PageSize
+ }
+ sess.Limit(opts.PageSize, start)
+ }
+ sess.Join("LEFT", "ai_model_file", "ai_model_manage.id = ai_model_file.model_id")
+ orderby := "ai_model_manage.created_unix desc"
+ if opts.SortType != "" {
+ orderby = opts.SortType
+ }
+ sess.OrderBy(orderby)
+ aiModelManages := make([]*AiModelManage, 0, setting.UI.IssuePagingNum)
+ if err := sess.Select("distinct(ai_model_manage.*)").Table("ai_model_manage").Where(where).
+ Find(&aiModelManages); err != nil {
+ log.Info("error=" + err.Error())
+ return nil, 0, fmt.Errorf("Find: %v", err)
+ }
+
+ return aiModelManages, count, nil
+}
diff --git a/options/locale/locale_en-US.ini b/options/locale/locale_en-US.ini
index f793db2ccf..3f8cdf4e83 100755
--- a/options/locale/locale_en-US.ini
+++ b/options/locale/locale_en-US.ini
@@ -274,6 +274,7 @@ c2net_center=Center
search=Search
search_repo=Repository
search_dataset=DataSet
+search_model=Model
search_issue=Issue
search_pr=Pull Request
search_user=User
diff --git a/options/locale/locale_zh-CN.ini b/options/locale/locale_zh-CN.ini
index a0f2f62439..35c7e24dbf 100755
--- a/options/locale/locale_zh-CN.ini
+++ b/options/locale/locale_zh-CN.ini
@@ -276,6 +276,7 @@ c2net_center=中心
search=搜索
search_repo=项目
search_dataset=数据集
+search_model=模型
search_issue=任务
search_pr=合并请求
search_user=用户
diff --git a/public/home/search.js b/public/home/search.js
index 86b2ad06eb..57542c88c2 100644
--- a/public/home/search.js
+++ b/public/home/search.js
@@ -27,6 +27,7 @@ var itemType = {
4: "org",
5: "dataset",
6: "pr",
+ 7: "model",
};
var sortBy = {
@@ -51,6 +52,10 @@ var sortBy = {
51: "download_times",
60: "default",
61: "updated_unix.keyword",
+ 70: "default",
+ 71: "reference_count",
+ 72: "download_count",
+ 73: "created_unix.keyword",
};
var sortAscending = {
@@ -75,6 +80,10 @@ var sortAscending = {
51: "false",
60: "false",
61: "false",
+ 70: "false",
+ 71: "false",
+ 72: "false",
+ 73: "false",
};
var currentPage = 1;
@@ -149,6 +158,7 @@ function emptySearch() {
$("#pr_total").text("");
$("#issue_total").text("");
$("#dataset_total").text("");
+ $("#model_total").text("");
$("#user_total").text("");
$("#org_total").text("");
setActivate(null);
@@ -159,6 +169,7 @@ function initDiv(isSearchLabel = false) {
document.getElementById("search_div").style.display = "none";
document.getElementById("search_label_div").style.display = "block";
document.getElementById("dataset_item").style.display = "none";
+ document.getElementById("model_item").style.display = "none";
document.getElementById("issue_item").style.display = "none";
document.getElementById("pr_item").style.display = "none";
document.getElementById("user_item").style.display = "none";
@@ -168,6 +179,7 @@ function initDiv(isSearchLabel = false) {
document.getElementById("search_div").style.display = "block";
document.getElementById("search_label_div").style.display = "none";
document.getElementById("dataset_item").style.display = "block";
+ document.getElementById("model_item").style.display = "block";
document.getElementById("issue_item").style.display = "block";
document.getElementById("pr_item").style.display = "block";
document.getElementById("user_item").style.display = "block";
@@ -210,6 +222,9 @@ function doSpcifySearch(tableName, keyword, sortBy = "", ascending = "false") {
if (currentSearchTableName != "dataset") {
doSearch("dataset", currentSearchKeyword, 1, pageSize, true, "", false);
}
+ if (currentSearchTableName != "model") {
+ doSearch("model", currentSearchKeyword, 1, pageSize, true, "", false);
+ }
if (currentSearchTableName != "pr") {
doSearch("pr", currentSearchKeyword, 1, pageSize, true, "", false);
}
@@ -278,7 +293,7 @@ function doSearch(
success: function (json) {
displayResult(tableName, page, json, onlyReturnNum, keyword);
},
- error: function (response) {},
+ error: function (response) { },
});
}
@@ -293,6 +308,8 @@ function displayResult(tableName, page, jsonResult, onlyReturnNum, keyword) {
displayOrgResult(page, jsonResult, onlyReturnNum, keyword);
} else if (tableName == "dataset") {
displayDataSetResult(page, jsonResult, onlyReturnNum, keyword);
+ } else if (tableName == "model") {
+ displayModelResult(page, jsonResult, onlyReturnNum, keyword);
} else if (tableName == "pr") {
displayPrResult(page, jsonResult, onlyReturnNum, keyword);
}
@@ -535,6 +552,104 @@ function displayDataSetResult(page, jsonResult, onlyReturnNum, keyword) {
}
}
+function displayModelResult(page, jsonResult, onlyReturnNum, keyword) {
+ var data = jsonResult.Result;
+ var total = jsonResult.Total;
+ $("#model_total").text(total);
+ if (!onlyReturnNum) {
+ setActivate("model_item");
+ //$('#keyword_desc').text(keyword);
+ //$('#obj_desc').text(getLabel(isZh,"search_model"));
+ //$('#child_total').text(total);
+ $("#find_title").html(
+ getLabel(isZh, "find_title")
+ .replace("{keyword}", keyword)
+ .replace("{tablename}", getLabel(isZh, "search_model"))
+ .replace("{total}", total)
+ );
+
+ var sortHtml = "";
+ sortHtml +=
+ '";
+ sortHtml +=
+ '";
+ sortHtml +=
+ '";
+ sortHtml +=
+ '";
+ document.getElementById("sort_type").innerHTML = sortHtml;
+
+ var html = "";
+ var currentTime = new Date().getTime();
+ const engineMap = {
+ '0': 'PyTorch',
+ '1': 'TensorFlow',
+ '2': 'MindSpore',
+ '4': 'PaddlePaddle',
+ '5': 'OneFlow',
+ '6': 'MXNet',
+ '3': 'Other',
+ };
+ for (var i = 0; i < data.length; i++) {
+ var recordMap = data[i];
+ var createDate = new Date(recordMap['created_unix'] * 1000);
+ var createYear = createDate.getFullYear().toString();
+ var createMonth = (createDate.getMonth() + 1).toString();
+ var createDay = createDate.getDate().toString();
+ html += `
+
+
${engineMap[recordMap['engine']] || 'Other'}
+
+
+
${recordMap['label'] ? recordMap['label'].replace(/font\scolor=/g, 'font_color=').trim().split(/\s+/).map(item => {
+ return ''
+ + item.replace(/font_color=/g, 'font color=') + '';
+ }).join('') : ''}
+
${recordMap['description']}
+
${recordMap['file_name'] || ''}
+
+
+
+ ${createYear}-${createMonth.length < 2 ? '0' + createMonth : createMonth}-${createDay.length < 2 ? '0' + createDay : createDay}
+
+
+ ${recordMap['reference_count']}
+
+
+ ${recordMap['download_count']}
+
+ ${getLabel(isZh, "search_lasted_update")} ${recordMap["updated_html"]}
+
+
+
+
+ `;
+ }
+ document.getElementById("child_search_item").innerHTML = html;
+ }
+}
+
function displayOrgResult(page, jsonResult, onlyReturnNum, keyword) {
var data = jsonResult.Result;
var total = jsonResult.Total;
@@ -872,6 +987,7 @@ function setActivate(name) {
$("#user_item").removeClass("active");
$("#issue_item").removeClass("active");
$("#dataset_item").removeClass("active");
+ $("#model_item").removeClass("active");
$("#org_item").removeClass("active");
$("#pr_item").removeClass("active");
if (name == null) {
@@ -885,27 +1001,27 @@ function LetterAvatar(name, size, color) {
name = name || "";
size = size || 60;
var colours = [
- "#1abc9c",
- "#2ecc71",
- "#3498db",
- "#9b59b6",
- "#34495e",
- "#16a085",
- "#27ae60",
- "#2980b9",
- "#8e44ad",
- "#2c3e50",
- "#f1c40f",
- "#e67e22",
- "#e74c3c",
- "#00bcd4",
- "#95a5a6",
- "#f39c12",
- "#d35400",
- "#c0392b",
- "#bdc3c7",
- "#7f8c8d",
- ],
+ "#1abc9c",
+ "#2ecc71",
+ "#3498db",
+ "#9b59b6",
+ "#34495e",
+ "#16a085",
+ "#27ae60",
+ "#2980b9",
+ "#8e44ad",
+ "#2c3e50",
+ "#f1c40f",
+ "#e67e22",
+ "#e74c3c",
+ "#00bcd4",
+ "#95a5a6",
+ "#f39c12",
+ "#d35400",
+ "#c0392b",
+ "#bdc3c7",
+ "#7f8c8d",
+ ],
nameSplit = String(name).split(" "),
initials,
charIndex,
@@ -1624,6 +1740,7 @@ var zhCN = {
search: "搜索",
search_repo: "项目",
search_dataset: "数据集",
+ search_model: "模型",
search_issue: "任务",
search_pr: "合并请求",
search_user: "用户",
@@ -1631,10 +1748,11 @@ var zhCN = {
search_finded: "找到",
search_matched: "最佳匹配",
search_matched_download: "下载次数",
+ search_matched_reference: "引用次数",
search_lasted_update: "最后更新于",
search_letter_asc: "字母顺序排序",
search_letter_desc: "字母逆序排序",
- search_lasted_create: "最近创建",
+ search_lasted_create: "最新创建",
search_early_create: "最早创建",
search_add_by: "加入于",
search_lasted: "最近更新",
@@ -1654,12 +1772,14 @@ var zhCN = {
find_title:
'“{keyword}”相关{tablename}约为{total}个',
search_empty: "请输入任意关键字开始搜索。",
+ create_time: "创建时间"
};
var esUN = {
search: "Search",
search_repo: "Repository",
search_dataset: "DataSet",
+ search_model: "Model",
search_issue: "Issue",
search_pr: "Pull Request",
search_user: "User",
@@ -1667,6 +1787,7 @@ var esUN = {
search_finded: "Find",
search_matched: "Best Match",
search_matched_download: "Most downloads",
+ search_matched_reference: "Most reference",
search_lasted_update: "Updated ",
search_letter_asc: "Alphabetically",
search_letter_desc: "Reverse alphabetically",
@@ -1691,6 +1812,7 @@ var esUN = {
' {total} "{keyword}" related {tablename}',
search_empty:
"Please enter any keyword to start the search.",
+ create_time: "Create Time"
};
initDiv(false);
document.onreadystatechange = function () {
diff --git a/routers/repo/ai_model_manage.go b/routers/repo/ai_model_manage.go
index 036fc437e8..af5271c5db 100644
--- a/routers/repo/ai_model_manage.go
+++ b/routers/repo/ai_model_manage.go
@@ -223,7 +223,7 @@ func SaveNewNameModel(ctx *context.Context) {
return
}
SaveModel(ctx)
- ctx.Status(200)
+ //ctx.Status(200)
log.Info("save model end.")
}
diff --git a/routers/search.go b/routers/search.go
index 8453d5c185..1451b250d9 100644
--- a/routers/search.go
+++ b/routers/search.go
@@ -87,6 +87,9 @@ func SearchApi(ctx *context.Context) {
searchIssueOrPr(ctx, "issue-es-index"+setting.INDEXPOSTFIX, Key, Page, PageSize, OnlyReturnNum, "t")
//searchPR(ctx, "issue-es-index", Key, Page, PageSize, OnlyReturnNum)
return
+ } else if TableName == "model" {
+ searchModel(ctx, "model-es-index"+setting.INDEXPOSTFIX, Key, Page, PageSize, OnlyReturnNum)
+ return
}
}
@@ -313,6 +316,8 @@ func searchRepo(ctx *context.Context, TableName string, Key string, Page int, Pa
res, err := client.Search(TableName).Query(boolQ).SortBy(getSort(SortBy, ascending, "num_stars", false)...).From(from).Size(Size).Highlight(queryHighlight("alias", "description", "topics")).Do(ctx.Req.Context())
if err == nil {
+ searchJson, _ := json.Marshal(res)
+ log.Info("searchJson=" + string(searchJson))
esresult := makeRepoResult(res, Key, OnlyReturnNum, language)
setForkRepoOrder(esresult, SortBy)
resultObj.Total = resultObj.PrivateTotal + esresult.Total
@@ -1221,3 +1226,228 @@ func makeIssueResult(sRes *elastic.SearchResult, Key string, OnlyReturnNum bool,
return returnObj
}
+
+func searchModel(ctx *context.Context, TableName string, Key string, Page int, PageSize int, OnlyReturnNum bool) {
+ /*
+ 模型,model-es-index
+ 搜索:
+ name , 名称
+ description 描述
+ label 标签
+ file_name 数据集文件名称
+ 排序:
+ download_count
+ reference_count
+ created_unix
+ */
+ log.Info("query searchModel start")
+ SortBy := ctx.Query("SortBy")
+ ascending := ctx.QueryBool("Ascending")
+ PrivateTotal := ctx.QueryInt("PrivateTotal")
+ WebTotal := ctx.QueryInt("WebTotal")
+ language := ctx.Query("language")
+ if language == "" {
+ language = "zh-CN"
+ }
+ from := (Page - 1) * PageSize
+ if from == 0 {
+ WebTotal = 0
+ }
+ resultObj := &SearchRes{}
+ log.Info("WebTotal=" + fmt.Sprint(WebTotal))
+ log.Info("PrivateTotal=" + fmt.Sprint(PrivateTotal))
+ resultObj.Result = make([]map[string]interface{}, 0)
+
+ if ctx.User != nil && (from < PrivateTotal || from == 0) {
+
+ log.Info("actor is null?:" + fmt.Sprint(ctx.User == nil))
+ sortBy := "ai_model_manage.reference_count desc,ai_model_manage.download_count desc,ai_model_manage.created_unix desc"
+ if SortBy != "" && SortBy != "default" {
+ sortBy = "ai_model_manage." + SortBy
+ if ascending {
+ sortBy += " asc"
+ } else {
+ sortBy += " desc"
+ }
+ }
+ //Page, PageSize, Key, ctx.User.ID
+ privateModels, count, err := models.QueryModelForSearch(&models.AiModelQueryOptions{
+ ListOptions: models.ListOptions{
+ Page: Page,
+ PageSize: PageSize,
+ },
+ UserID: ctx.User.ID,
+ Namelike: Key,
+ SortType: sortBy,
+ })
+ if err != nil {
+ ctx.JSON(200, "")
+ return
+ }
+ resultObj.PrivateTotal = count
+ modelSize := len(privateModels)
+ if modelSize > 0 {
+ log.Info("Query private model number is:" + fmt.Sprint(modelSize) + " count=" + fmt.Sprint(count))
+ makePrivateModel(privateModels, resultObj, Key, language)
+ } else {
+ log.Info("not found private model, keyword=" + Key)
+ }
+ if modelSize >= PageSize {
+ if WebTotal > 0 { //next page, not first query.
+ resultObj.Total = int64(WebTotal)
+ ctx.JSON(200, resultObj)
+ return
+ }
+ }
+ } else {
+ resultObj.PrivateTotal = int64(PrivateTotal)
+ }
+
+ from = from - PrivateTotal
+ if from < 0 {
+ from = 0
+ }
+ Size := PageSize - len(resultObj.Result)
+
+ boolQ := elastic.NewBoolQuery()
+ if Key != "" {
+ fileNameQuery := elastic.NewMatchQuery("file_name", Key).Boost(3).QueryName("f_first")
+ nameQuery := elastic.NewMatchQuery("name", Key).Boost(2).QueryName("f_second")
+ descQuery := elastic.NewMatchQuery("description", Key).Boost(1.5).QueryName("f_three")
+ labelQuery := elastic.NewMatchQuery("label", Key).Boost(1).QueryName("f_fourth")
+ boolQ.Should(fileNameQuery, nameQuery, descQuery, labelQuery)
+ res, err := client.Search(TableName).Query(boolQ).SortBy(getSort(SortBy, ascending, "updated_unix.keyword", false)...).From(from).Size(Size).Highlight(queryHighlight("file_name", "name", "description", "label")).Do(ctx.Req.Context())
+ if err == nil {
+ searchJson, _ := json.Marshal(res)
+ log.Info("searchJson=" + string(searchJson))
+ esresult := makeModelResult(res, Key, OnlyReturnNum, language)
+ resultObj.Total = resultObj.PrivateTotal + esresult.Total
+ log.Info("query model es count=" + fmt.Sprint(esresult.Total) + " total=" + fmt.Sprint(resultObj.Total))
+ resultObj.Result = append(resultObj.Result, esresult.Result...)
+ ctx.JSON(200, resultObj)
+ } else {
+ log.Info("query es error," + err.Error())
+ }
+ } else {
+ log.Info("query all models.")
+ //搜索的属性要指定{"timestamp":{"unmapped_type":"date"}}
+ res, err := client.Search(TableName).SortBy(getSort(SortBy, ascending, "updated_unix.keyword", false)...).From(from).Size(Size).Do(ctx.Req.Context())
+ if err == nil {
+ searchJson, _ := json.Marshal(res)
+ log.Info("searchJson=" + string(searchJson))
+ esresult := makeModelResult(res, "", OnlyReturnNum, language)
+ resultObj.Total = resultObj.PrivateTotal + esresult.Total
+ log.Info("query model es count=" + fmt.Sprint(esresult.Total) + " total=" + fmt.Sprint(resultObj.Total))
+ resultObj.Result = append(resultObj.Result, esresult.Result...)
+ ctx.JSON(200, resultObj)
+ } else {
+ log.Info("query es error," + err.Error())
+ ctx.JSON(200, "")
+ }
+ }
+
+}
+
+func makePrivateModel(privateModels []*models.AiModelManage, res *SearchRes, Key string, language string) {
+ for _, model := range privateModels {
+ record := make(map[string]interface{})
+
+ record["id"] = model.ID
+ userId := model.UserId
+
+ user, errUser := models.GetUserByID(userId)
+ if errUser == nil {
+ record["owerName"] = user.GetDisplayName()
+ record["avatar"] = user.RelAvatarLink()
+ }
+
+ repo, errRepo := models.GetRepositoryByID(model.RepoId)
+ if errRepo == nil {
+ log.Info("repo_url=" + repo.FullName())
+ record["repoUrl"] = repo.FullName()
+ record["avatar"] = repo.RelAvatarLink()
+ } else {
+ log.Info("repo err=" + errRepo.Error())
+ }
+ resultfile := models.QueryModelFileByModelId(model.ID)
+ file_name := ""
+ if resultfile != nil && len(resultfile) > 0 {
+ for _, file := range resultfile {
+ file_name += file.Name + ","
+ }
+ file_name = file_name[0 : len(file_name)-1]
+ }
+ record["file_name"] = truncLongText(makeHighLight(Key, file_name), true)
+ record["name"] = makeHighLight(Key, model.Name)
+ record["real_name"] = model.Name
+ record["is_private"] = model.IsPrivate
+ record["description"] = truncLongText(makeHighLight(Key, model.Description), true)
+
+ record["label"] = makeHighLight(Key, model.Label)
+ record["download_count"] = model.DownloadCount
+ record["reference_count"] = model.ReferenceCount
+ record["engine"] = model.Engine
+ record["created_unix"] = model.CreatedUnix
+ record["updated_unix"] = model.UpdatedUnix
+ record["updated_html"] = timeutil.TimeSinceUnix(model.UpdatedUnix, language)
+
+ res.Result = append(res.Result, record)
+ }
+}
+
+func makeModelResult(sRes *elastic.SearchResult, Key string, OnlyReturnNum bool, language string) *SearchRes {
+ total := sRes.Hits.TotalHits.Value
+ result := make([]map[string]interface{}, 0)
+ if !OnlyReturnNum {
+ for i, hit := range sRes.Hits.Hits {
+ log.Info("this is model query " + fmt.Sprint(i) + " result.")
+ recordSource := make(map[string]interface{})
+ source, err := hit.Source.MarshalJSON()
+
+ if err == nil {
+ err = json.Unmarshal(source, &recordSource)
+ if err == nil {
+ record := make(map[string]interface{})
+ record["id"] = hit.Id
+ userIdStr := recordSource["user_id"].(string)
+ userId, cerr := strconv.ParseInt(userIdStr, 10, 64)
+ if cerr == nil {
+ user, errUser := models.GetUserByID(userId)
+ if errUser == nil {
+ record["owerName"] = user.GetDisplayName()
+ record["avatar"] = user.RelAvatarLink()
+ }
+ }
+ setRepoInfo(recordSource, record)
+ record["name"] = getLabelValue("name", recordSource, hit.Highlight)
+ record["real_name"] = recordSource["name"]
+ record["label"] = getLabelValue("label", recordSource, hit.Highlight)
+ if recordSource["description"] != nil {
+ desc := getLabelValue("description", recordSource, hit.Highlight)
+ record["description"] = dealLongText(desc, Key, hit.MatchedQueries)
+ } else {
+ record["description"] = ""
+ }
+ record["is_private"] = recordSource["is_private"]
+ record["engine"] = recordSource["engine"]
+ record["file_name"] = getDatasetFileName(getLabelValue("file_name", recordSource, hit.Highlight))
+ record["reference_count"] = recordSource["reference_count"]
+ record["download_count"] = recordSource["download_count"]
+ record["created_unix"] = recordSource["created_unix"]
+ setUpdateHtml(record, recordSource["updated_unix"].(string), language)
+ result = append(result, record)
+ } else {
+ log.Info("deal model source error," + err.Error())
+ }
+ } else {
+ log.Info("deal model source error," + err.Error())
+ }
+ }
+ }
+ returnObj := &SearchRes{
+ Total: total,
+ Result: result,
+ }
+
+ return returnObj
+}
diff --git a/templates/explore/search_new.tmpl b/templates/explore/search_new.tmpl
index e3a85414ff..aaceff902d 100644
--- a/templates/explore/search_new.tmpl
+++ b/templates/explore/search_new.tmpl
@@ -32,6 +32,10 @@
{{.i18n.Tr "home.search_dataset"}}
+
+ {{.i18n.Tr "home.search_model"}}
+
+
{{.i18n.Tr "home.search_issue"}}
diff --git a/templates/repo/modelarts/trainjob/show.tmpl b/templates/repo/modelarts/trainjob/show.tmpl
index 8031368635..cb3f0f5b16 100755
--- a/templates/repo/modelarts/trainjob/show.tmpl
+++ b/templates/repo/modelarts/trainjob/show.tmpl
@@ -305,7 +305,6 @@
point_hr: {{$.i18n.Tr "cloudbrain.point_hr"}},
memory: {{$.i18n.Tr "cloudbrain.memory"}},
shared_memory: {{$.i18n.Tr "cloudbrain.shared_memory"}},
- no_use_resource:{{$.i18n.Tr "cloudbrain.no_use_resource"}},
});
$('td.ti-text-form-content.spec{{$k}} div').text(specStr);
})();