#4068 统一搜索支持模型搜索,合入主分支。

Merged
ychao_1983 merged 26 commits from zouap into V20230425 1 year ago
  1. +45
    -0
      models/ai_model_manage.go
  2. +1
    -0
      options/locale/locale_en-US.ini
  3. +1
    -0
      options/locale/locale_zh-CN.ini
  4. +145
    -23
      public/home/search.js
  5. +1
    -1
      routers/repo/ai_model_manage.go
  6. +230
    -0
      routers/search.go
  7. +4
    -0
      templates/explore/search_new.tmpl
  8. +0
    -1
      templates/repo/modelarts/trainjob/show.tmpl

+ 45
- 0
models/ai_model_manage.go View File

@@ -774,3 +774,48 @@ func QueryModelFileByModelId(modelId string) []*AiModelFile {
}
return result
}

func QueryModelForSearch(opts *AiModelQueryOptions) ([]*AiModelManage, int64, error) {
sess := x.NewSession()
defer sess.Close()
var where string
where += "ai_model_manage.user_id=" + fmt.Sprint(opts.UserID)
where += " and ai_model_manage.is_private=true"
if opts.Namelike != "" {
where += " and ( ai_model_manage.name ILIKE '%" + opts.Namelike + "%'"
where += " or ai_model_manage.description ILIKE '%" + opts.Namelike + "%'"
where += " or ai_model_manage.label ILIKE '%" + opts.Namelike + "%'"
where += " or ai_model_file.name ILIKE '%" + opts.Namelike + "%')"
}

var count int64
var err error
count, err = sess.Join("LEFT", "ai_model_file", "ai_model_manage.id = ai_model_file.model_id").Select("count(distinct(ai_model_manage.id))").Where(where).Count(new(AiModelManage))
if err != nil {
log.Info("error=" + err.Error())
return nil, 0, fmt.Errorf("Count: %v", err)
}
if opts.Page >= 0 && opts.PageSize > 0 {
var start int
if opts.Page == 0 {
start = 0
} else {
start = (opts.Page - 1) * opts.PageSize
}
sess.Limit(opts.PageSize, start)
}
sess.Join("LEFT", "ai_model_file", "ai_model_manage.id = ai_model_file.model_id")
orderby := "ai_model_manage.created_unix desc"
if opts.SortType != "" {
orderby = opts.SortType
}
sess.OrderBy(orderby)
aiModelManages := make([]*AiModelManage, 0, setting.UI.IssuePagingNum)
if err := sess.Select("distinct(ai_model_manage.*)").Table("ai_model_manage").Where(where).
Find(&aiModelManages); err != nil {
log.Info("error=" + err.Error())
return nil, 0, fmt.Errorf("Find: %v", err)
}

return aiModelManages, count, nil
}

+ 1
- 0
options/locale/locale_en-US.ini View File

@@ -274,6 +274,7 @@ c2net_center=Center
search=Search
search_repo=Repository
search_dataset=DataSet
search_model=Model
search_issue=Issue
search_pr=Pull Request
search_user=User


+ 1
- 0
options/locale/locale_zh-CN.ini View File

@@ -276,6 +276,7 @@ c2net_center=中心
search=搜索
search_repo=项目
search_dataset=数据集
search_model=模型
search_issue=任务
search_pr=合并请求
search_user=用户


+ 145
- 23
public/home/search.js View File

@@ -27,6 +27,7 @@ var itemType = {
4: "org",
5: "dataset",
6: "pr",
7: "model",
};

var sortBy = {
@@ -51,6 +52,10 @@ var sortBy = {
51: "download_times",
60: "default",
61: "updated_unix.keyword",
70: "default",
71: "reference_count",
72: "download_count",
73: "created_unix.keyword",
};

var sortAscending = {
@@ -75,6 +80,10 @@ var sortAscending = {
51: "false",
60: "false",
61: "false",
70: "false",
71: "false",
72: "false",
73: "false",
};

var currentPage = 1;
@@ -149,6 +158,7 @@ function emptySearch() {
$("#pr_total").text("");
$("#issue_total").text("");
$("#dataset_total").text("");
$("#model_total").text("");
$("#user_total").text("");
$("#org_total").text("");
setActivate(null);
@@ -159,6 +169,7 @@ function initDiv(isSearchLabel = false) {
document.getElementById("search_div").style.display = "none";
document.getElementById("search_label_div").style.display = "block";
document.getElementById("dataset_item").style.display = "none";
document.getElementById("model_item").style.display = "none";
document.getElementById("issue_item").style.display = "none";
document.getElementById("pr_item").style.display = "none";
document.getElementById("user_item").style.display = "none";
@@ -168,6 +179,7 @@ function initDiv(isSearchLabel = false) {
document.getElementById("search_div").style.display = "block";
document.getElementById("search_label_div").style.display = "none";
document.getElementById("dataset_item").style.display = "block";
document.getElementById("model_item").style.display = "block";
document.getElementById("issue_item").style.display = "block";
document.getElementById("pr_item").style.display = "block";
document.getElementById("user_item").style.display = "block";
@@ -210,6 +222,9 @@ function doSpcifySearch(tableName, keyword, sortBy = "", ascending = "false") {
if (currentSearchTableName != "dataset") {
doSearch("dataset", currentSearchKeyword, 1, pageSize, true, "", false);
}
if (currentSearchTableName != "model") {
doSearch("model", currentSearchKeyword, 1, pageSize, true, "", false);
}
if (currentSearchTableName != "pr") {
doSearch("pr", currentSearchKeyword, 1, pageSize, true, "", false);
}
@@ -278,7 +293,7 @@ function doSearch(
success: function (json) {
displayResult(tableName, page, json, onlyReturnNum, keyword);
},
error: function (response) {},
error: function (response) { },
});
}

@@ -293,6 +308,8 @@ function displayResult(tableName, page, jsonResult, onlyReturnNum, keyword) {
displayOrgResult(page, jsonResult, onlyReturnNum, keyword);
} else if (tableName == "dataset") {
displayDataSetResult(page, jsonResult, onlyReturnNum, keyword);
} else if (tableName == "model") {
displayModelResult(page, jsonResult, onlyReturnNum, keyword);
} else if (tableName == "pr") {
displayPrResult(page, jsonResult, onlyReturnNum, keyword);
}
@@ -535,6 +552,104 @@ function displayDataSetResult(page, jsonResult, onlyReturnNum, keyword) {
}
}

function displayModelResult(page, jsonResult, onlyReturnNum, keyword) {
var data = jsonResult.Result;
var total = jsonResult.Total;
$("#model_total").text(total);
if (!onlyReturnNum) {
setActivate("model_item");
//$('#keyword_desc').text(keyword);
//$('#obj_desc').text(getLabel(isZh,"search_model"));
//$('#child_total').text(total);
$("#find_title").html(
getLabel(isZh, "find_title")
.replace("{keyword}", keyword)
.replace("{tablename}", getLabel(isZh, "search_model"))
.replace("{total}", total)
);

var sortHtml = "";
sortHtml +=
'<a class="' +
getActiveItem(70) +
'item" href="javascript:searchItem(7,70);" tabindex="-1" role="menuitem" id="menuitem_1">' +
getLabel(isZh, "search_matched") +
"</a>";
sortHtml +=
'<a class="' +
getActiveItem(71) +
'item" href="javascript:searchItem(7,71);" tabindex="-1" role="menuitem" id="menuitem_1">' +
getLabel(isZh, "search_matched_reference") +
"</a>";
sortHtml +=
'<a class="' +
getActiveItem(72) +
'item" href="javascript:searchItem(7,72);" tabindex="-1" role="menuitem" id="menuitem_1">' +
getLabel(isZh, "search_matched_download") +
"</a>";
sortHtml +=
'<a class="' +
getActiveItem(73) +
'item" href="javascript:searchItem(7,73);" tabindex="-1" role="menuitem" id="menuitem_1">' +
getLabel(isZh, "search_lasted_create") +
"</a>";
document.getElementById("sort_type").innerHTML = sortHtml;

var html = "";
var currentTime = new Date().getTime();
const engineMap = {
'0': 'PyTorch',
'1': 'TensorFlow',
'2': 'MindSpore',
'4': 'PaddlePaddle',
'5': 'OneFlow',
'6': 'MXNet',
'3': 'Other',
};
for (var i = 0; i < data.length; i++) {
var recordMap = data[i];
var createDate = new Date(recordMap['created_unix'] * 1000);
var createYear = createDate.getFullYear().toString();
var createMonth = (createDate.getMonth() + 1).toString();
var createDay = createDate.getDate().toString();
html += `<div class="item">
<div class="content">
<div class="ui right metas" style="color:#767676">${engineMap[recordMap['engine']] || 'Other'}</div>
<div class="ui header">
<a class="name" href="/${recordMap["repoUrl"]}/modelmanage/model_readme_tmpl?name=${recordMap["real_name"]}" target="_blank">
${recordMap["name"]}
</a>
</div>
<div class="description">
<p class="labels has-emoji">${recordMap['label'] ? recordMap['label'].replace(/font\scolor=/g, 'font_color=').trim().split(/\s+/).map(item => {
return '<span style="color:rgba(16, 16, 16, 0.8);border-radius:4px;font-size:12px;background:rgba(232, 232, 232, 0.6);padding:2px 6px;margin-right:8px">'
+ item.replace(/font_color=/g, 'font color=') + '</span>';
}).join('') : ''}</p>
<p class="descr has-emoji">${recordMap['description']}</p>
<p class="filename has-emoji">${recordMap['file_name'] || ''}</p>
<p class="time" style="display:flex;align-items:center">
<a style="margin-right:8px;display:inline-block;height:22px;" href="/${recordMap['owerName']}"><img src="/user/avatar/${recordMap['owerName']}/-1"
style="display:inline-block;width:22px;height:22px;border-radius:100%;"></a>
<span style="margin-right:8px" title="${getLabel(isZh, "create_time")}">
${createYear}-${createMonth.length < 2 ? '0' + createMonth : createMonth}-${createDay.length < 2 ? '0' + createDay : createDay}
</span>
<span title="${getLabel(isZh, "search_matched_reference")}" style="display:flex;align-items:center">
<i class="ri-link" style="margin-right:4px"></i><span style="margin-right:4px">${recordMap['reference_count']}</span>
</span>
<span title="${getLabel(isZh, "search_matched_download")}" style="display:flex;align-items:center">
<i class="ri-download-line" style="margin-right:4px"></i><span style="margin-right:8px">${recordMap['download_count']}</span>
</span>
${getLabel(isZh, "search_lasted_update")}&nbsp;${recordMap["updated_html"]}
</p>
</div>
</div>
</div>
`;
}
document.getElementById("child_search_item").innerHTML = html;
}
}

function displayOrgResult(page, jsonResult, onlyReturnNum, keyword) {
var data = jsonResult.Result;
var total = jsonResult.Total;
@@ -872,6 +987,7 @@ function setActivate(name) {
$("#user_item").removeClass("active");
$("#issue_item").removeClass("active");
$("#dataset_item").removeClass("active");
$("#model_item").removeClass("active");
$("#org_item").removeClass("active");
$("#pr_item").removeClass("active");
if (name == null) {
@@ -885,27 +1001,27 @@ function LetterAvatar(name, size, color) {
name = name || "";
size = size || 60;
var colours = [
"#1abc9c",
"#2ecc71",
"#3498db",
"#9b59b6",
"#34495e",
"#16a085",
"#27ae60",
"#2980b9",
"#8e44ad",
"#2c3e50",
"#f1c40f",
"#e67e22",
"#e74c3c",
"#00bcd4",
"#95a5a6",
"#f39c12",
"#d35400",
"#c0392b",
"#bdc3c7",
"#7f8c8d",
],
"#1abc9c",
"#2ecc71",
"#3498db",
"#9b59b6",
"#34495e",
"#16a085",
"#27ae60",
"#2980b9",
"#8e44ad",
"#2c3e50",
"#f1c40f",
"#e67e22",
"#e74c3c",
"#00bcd4",
"#95a5a6",
"#f39c12",
"#d35400",
"#c0392b",
"#bdc3c7",
"#7f8c8d",
],
nameSplit = String(name).split(" "),
initials,
charIndex,
@@ -1624,6 +1740,7 @@ var zhCN = {
search: "搜索",
search_repo: "项目",
search_dataset: "数据集",
search_model: "模型",
search_issue: "任务",
search_pr: "合并请求",
search_user: "用户",
@@ -1631,10 +1748,11 @@ var zhCN = {
search_finded: "找到",
search_matched: "最佳匹配",
search_matched_download: "下载次数",
search_matched_reference: "引用次数",
search_lasted_update: "最后更新于",
search_letter_asc: "字母顺序排序",
search_letter_desc: "字母逆序排序",
search_lasted_create: "最创建",
search_lasted_create: "最创建",
search_early_create: "最早创建",
search_add_by: "加入于",
search_lasted: "最近更新",
@@ -1654,12 +1772,14 @@ var zhCN = {
find_title:
'“<strong class="highlight" id="keyword_desc">{keyword}</strong>”相关{tablename}约为{total}个',
search_empty: "<strong>请输入任意关键字开始搜索。</strong>",
create_time: "创建时间"
};

var esUN = {
search: "Search",
search_repo: "Repository",
search_dataset: "DataSet",
search_model: "Model",
search_issue: "Issue",
search_pr: "Pull Request",
search_user: "User",
@@ -1667,6 +1787,7 @@ var esUN = {
search_finded: "Find",
search_matched: "Best Match",
search_matched_download: "Most downloads",
search_matched_reference: "Most reference",
search_lasted_update: "Updated ",
search_letter_asc: "Alphabetically",
search_letter_desc: "Reverse alphabetically",
@@ -1691,6 +1812,7 @@ var esUN = {
' {total} "<strong class="highlight" id="keyword_desc">{keyword}</strong>" related {tablename}',
search_empty:
"<strong>Please enter any keyword to start the search.</strong>",
create_time: "Create Time"
};
initDiv(false);
document.onreadystatechange = function () {


+ 1
- 1
routers/repo/ai_model_manage.go View File

@@ -223,7 +223,7 @@ func SaveNewNameModel(ctx *context.Context) {
return
}
SaveModel(ctx)
ctx.Status(200)
//ctx.Status(200)
log.Info("save model end.")
}



+ 230
- 0
routers/search.go View File

@@ -87,6 +87,9 @@ func SearchApi(ctx *context.Context) {
searchIssueOrPr(ctx, "issue-es-index"+setting.INDEXPOSTFIX, Key, Page, PageSize, OnlyReturnNum, "t")
//searchPR(ctx, "issue-es-index", Key, Page, PageSize, OnlyReturnNum)
return
} else if TableName == "model" {
searchModel(ctx, "model-es-index"+setting.INDEXPOSTFIX, Key, Page, PageSize, OnlyReturnNum)
return
}
}

@@ -313,6 +316,8 @@ func searchRepo(ctx *context.Context, TableName string, Key string, Page int, Pa

res, err := client.Search(TableName).Query(boolQ).SortBy(getSort(SortBy, ascending, "num_stars", false)...).From(from).Size(Size).Highlight(queryHighlight("alias", "description", "topics")).Do(ctx.Req.Context())
if err == nil {
searchJson, _ := json.Marshal(res)
log.Info("searchJson=" + string(searchJson))
esresult := makeRepoResult(res, Key, OnlyReturnNum, language)
setForkRepoOrder(esresult, SortBy)
resultObj.Total = resultObj.PrivateTotal + esresult.Total
@@ -1221,3 +1226,228 @@ func makeIssueResult(sRes *elastic.SearchResult, Key string, OnlyReturnNum bool,

return returnObj
}

func searchModel(ctx *context.Context, TableName string, Key string, Page int, PageSize int, OnlyReturnNum bool) {
/*
模型,model-es-index
搜索:
name , 名称
description 描述
label 标签
file_name 数据集文件名称
排序:
download_count
reference_count
created_unix
*/
log.Info("query searchModel start")
SortBy := ctx.Query("SortBy")
ascending := ctx.QueryBool("Ascending")
PrivateTotal := ctx.QueryInt("PrivateTotal")
WebTotal := ctx.QueryInt("WebTotal")
language := ctx.Query("language")
if language == "" {
language = "zh-CN"
}
from := (Page - 1) * PageSize
if from == 0 {
WebTotal = 0
}
resultObj := &SearchRes{}
log.Info("WebTotal=" + fmt.Sprint(WebTotal))
log.Info("PrivateTotal=" + fmt.Sprint(PrivateTotal))
resultObj.Result = make([]map[string]interface{}, 0)

if ctx.User != nil && (from < PrivateTotal || from == 0) {

log.Info("actor is null?:" + fmt.Sprint(ctx.User == nil))
sortBy := "ai_model_manage.reference_count desc,ai_model_manage.download_count desc,ai_model_manage.created_unix desc"
if SortBy != "" && SortBy != "default" {
sortBy = "ai_model_manage." + SortBy
if ascending {
sortBy += " asc"
} else {
sortBy += " desc"
}
}
//Page, PageSize, Key, ctx.User.ID
privateModels, count, err := models.QueryModelForSearch(&models.AiModelQueryOptions{
ListOptions: models.ListOptions{
Page: Page,
PageSize: PageSize,
},
UserID: ctx.User.ID,
Namelike: Key,
SortType: sortBy,
})
if err != nil {
ctx.JSON(200, "")
return
}
resultObj.PrivateTotal = count
modelSize := len(privateModels)
if modelSize > 0 {
log.Info("Query private model number is:" + fmt.Sprint(modelSize) + " count=" + fmt.Sprint(count))
makePrivateModel(privateModels, resultObj, Key, language)
} else {
log.Info("not found private model, keyword=" + Key)
}
if modelSize >= PageSize {
if WebTotal > 0 { //next page, not first query.
resultObj.Total = int64(WebTotal)
ctx.JSON(200, resultObj)
return
}
}
} else {
resultObj.PrivateTotal = int64(PrivateTotal)
}

from = from - PrivateTotal
if from < 0 {
from = 0
}
Size := PageSize - len(resultObj.Result)

boolQ := elastic.NewBoolQuery()
if Key != "" {
fileNameQuery := elastic.NewMatchQuery("file_name", Key).Boost(3).QueryName("f_first")
nameQuery := elastic.NewMatchQuery("name", Key).Boost(2).QueryName("f_second")
descQuery := elastic.NewMatchQuery("description", Key).Boost(1.5).QueryName("f_three")
labelQuery := elastic.NewMatchQuery("label", Key).Boost(1).QueryName("f_fourth")
boolQ.Should(fileNameQuery, nameQuery, descQuery, labelQuery)
res, err := client.Search(TableName).Query(boolQ).SortBy(getSort(SortBy, ascending, "updated_unix.keyword", false)...).From(from).Size(Size).Highlight(queryHighlight("file_name", "name", "description", "label")).Do(ctx.Req.Context())
if err == nil {
searchJson, _ := json.Marshal(res)
log.Info("searchJson=" + string(searchJson))
esresult := makeModelResult(res, Key, OnlyReturnNum, language)
resultObj.Total = resultObj.PrivateTotal + esresult.Total
log.Info("query model es count=" + fmt.Sprint(esresult.Total) + " total=" + fmt.Sprint(resultObj.Total))
resultObj.Result = append(resultObj.Result, esresult.Result...)
ctx.JSON(200, resultObj)
} else {
log.Info("query es error," + err.Error())
}
} else {
log.Info("query all models.")
//搜索的属性要指定{"timestamp":{"unmapped_type":"date"}}
res, err := client.Search(TableName).SortBy(getSort(SortBy, ascending, "updated_unix.keyword", false)...).From(from).Size(Size).Do(ctx.Req.Context())
if err == nil {
searchJson, _ := json.Marshal(res)
log.Info("searchJson=" + string(searchJson))
esresult := makeModelResult(res, "", OnlyReturnNum, language)
resultObj.Total = resultObj.PrivateTotal + esresult.Total
log.Info("query model es count=" + fmt.Sprint(esresult.Total) + " total=" + fmt.Sprint(resultObj.Total))
resultObj.Result = append(resultObj.Result, esresult.Result...)
ctx.JSON(200, resultObj)
} else {
log.Info("query es error," + err.Error())
ctx.JSON(200, "")
}
}

}

func makePrivateModel(privateModels []*models.AiModelManage, res *SearchRes, Key string, language string) {
for _, model := range privateModels {
record := make(map[string]interface{})

record["id"] = model.ID
userId := model.UserId

user, errUser := models.GetUserByID(userId)
if errUser == nil {
record["owerName"] = user.GetDisplayName()
record["avatar"] = user.RelAvatarLink()
}

repo, errRepo := models.GetRepositoryByID(model.RepoId)
if errRepo == nil {
log.Info("repo_url=" + repo.FullName())
record["repoUrl"] = repo.FullName()
record["avatar"] = repo.RelAvatarLink()
} else {
log.Info("repo err=" + errRepo.Error())
}
resultfile := models.QueryModelFileByModelId(model.ID)
file_name := ""
if resultfile != nil && len(resultfile) > 0 {
for _, file := range resultfile {
file_name += file.Name + ","
}
file_name = file_name[0 : len(file_name)-1]
}
record["file_name"] = truncLongText(makeHighLight(Key, file_name), true)
record["name"] = makeHighLight(Key, model.Name)
record["real_name"] = model.Name
record["is_private"] = model.IsPrivate
record["description"] = truncLongText(makeHighLight(Key, model.Description), true)

record["label"] = makeHighLight(Key, model.Label)
record["download_count"] = model.DownloadCount
record["reference_count"] = model.ReferenceCount
record["engine"] = model.Engine
record["created_unix"] = model.CreatedUnix
record["updated_unix"] = model.UpdatedUnix
record["updated_html"] = timeutil.TimeSinceUnix(model.UpdatedUnix, language)

res.Result = append(res.Result, record)
}
}

func makeModelResult(sRes *elastic.SearchResult, Key string, OnlyReturnNum bool, language string) *SearchRes {
total := sRes.Hits.TotalHits.Value
result := make([]map[string]interface{}, 0)
if !OnlyReturnNum {
for i, hit := range sRes.Hits.Hits {
log.Info("this is model query " + fmt.Sprint(i) + " result.")
recordSource := make(map[string]interface{})
source, err := hit.Source.MarshalJSON()

if err == nil {
err = json.Unmarshal(source, &recordSource)
if err == nil {
record := make(map[string]interface{})
record["id"] = hit.Id
userIdStr := recordSource["user_id"].(string)
userId, cerr := strconv.ParseInt(userIdStr, 10, 64)
if cerr == nil {
user, errUser := models.GetUserByID(userId)
if errUser == nil {
record["owerName"] = user.GetDisplayName()
record["avatar"] = user.RelAvatarLink()
}
}
setRepoInfo(recordSource, record)
record["name"] = getLabelValue("name", recordSource, hit.Highlight)
record["real_name"] = recordSource["name"]
record["label"] = getLabelValue("label", recordSource, hit.Highlight)
if recordSource["description"] != nil {
desc := getLabelValue("description", recordSource, hit.Highlight)
record["description"] = dealLongText(desc, Key, hit.MatchedQueries)
} else {
record["description"] = ""
}
record["is_private"] = recordSource["is_private"]
record["engine"] = recordSource["engine"]
record["file_name"] = getDatasetFileName(getLabelValue("file_name", recordSource, hit.Highlight))
record["reference_count"] = recordSource["reference_count"]
record["download_count"] = recordSource["download_count"]
record["created_unix"] = recordSource["created_unix"]
setUpdateHtml(record, recordSource["updated_unix"].(string), language)
result = append(result, record)
} else {
log.Info("deal model source error," + err.Error())
}
} else {
log.Info("deal model source error," + err.Error())
}
}
}
returnObj := &SearchRes{
Total: total,
Result: result,
}

return returnObj
}

+ 4
- 0
templates/explore/search_new.tmpl View File

@@ -32,6 +32,10 @@
{{.i18n.Tr "home.search_dataset"}}
<span class="ui circular mini label" id="dataset_total"></span>
</a>
<a id="model_item" class="item" href="javascript:searchItem(7,70);">
{{.i18n.Tr "home.search_model"}}
<span class="ui circular mini label" id="model_total"></span>
</a>
<a id="issue_item" class="item" href="javascript:searchItem(2,20);">
{{.i18n.Tr "home.search_issue"}}
<span class="ui circular mini label" id="issue_total"></span>


+ 0
- 1
templates/repo/modelarts/trainjob/show.tmpl View File

@@ -305,7 +305,6 @@
point_hr: {{$.i18n.Tr "cloudbrain.point_hr"}},
memory: {{$.i18n.Tr "cloudbrain.memory"}},
shared_memory: {{$.i18n.Tr "cloudbrain.shared_memory"}},
no_use_resource:{{$.i18n.Tr "cloudbrain.no_use_resource"}},
});
$('td.ti-text-form-content.spec{{$k}} div').text(specStr);
})();


Loading…
Cancel
Save