@@ -1,11 +1,10 @@
package container_builder
import (
"code.gitea.io/gitea/services/ai_model"
"path"
"strings"
"code.gitea.io/gitea/services/ai_model"
"code.gitea.io/gitea/routers/response"
"code.gitea.io/gitea/entity"
@@ -29,17 +28,39 @@ func (b *PretrainModelBuilder) SetOpts(opts *entity.ContainerBuildOpts) {
}
func (b *PretrainModelBuilder) Build(ctx *context.CreationContext) ([]entity.ContainerData, *response.BizError) {
form := ctx.Request
var preTrainModelEntity []entity.ContainerData
if ctx.Request.Cluster == entity.C2Net && (ctx.Request.JobType == models.JobTypeDebug || ctx.Request.JobType == models.JobTypeTrain) && ctx.Request.ComputeSource.Name == models.GPU {
//挂载一个文件夹保证容器内pretrainmodel目录提交镜像时不被打包
uploader := storage_helper.SelectStorageHelperFromStorageType(entity.OBS)
objectKey := path.Join(uploader.GetJobDefaultObjectKeyPrefix(form.JobName), "pretrain_model_mount")
uploader.MKDIR(objectKey, "pretrain model folder")
preTrainModelEntity = append(preTrainModelEntity, entity.ContainerData{
Name: "pretrain_model_mount",
Bucket: uploader.GetBucket(),
EndPoint: uploader.GetEndpoint(),
ObjectKey: objectKey + "/",
ReadOnly: false,
ContainerPath: b.Opts.ContainerPath,
RealPath: uploader.GetRealPath(objectKey),
S3DownloadUrl: uploader.GetS3DownloadUrl(objectKey),
IsDir: true,
IsOverwrite: true,
IsNeedUnzip: false,
})
}
if b.Opts.Disable {
return nil, nil
return preTrainModelEntity , nil
}
form := ctx.Request
storageTypes := b.Opts.AcceptStorageType
if storageTypes == nil || len(storageTypes) == 0 {
return nil, response.SYSTEM_ERROR
}
//未选择预训练模型,跳过此步
if form.PretrainModelId == "" {
return nil, nil
return preTrainModelEntity , nil
}
//查出模型数据
uuids := strings.Split(form.PretrainModelId, ";")
@@ -48,7 +69,7 @@ func (b *PretrainModelBuilder) Build(ctx *context.CreationContext) ([]entity.Con
log.Error("Can not find model", err)
return nil, response.MODEL_NOT_EXISTS
}
var preTrainModelEntity []entity.ContainerData
for _, m := range modelInfoMaps {
ai_model.InitModelMeta(m.ID)
data, err := b.buildModelData(m, form.JobName)
@@ -57,6 +78,7 @@ func (b *PretrainModelBuilder) Build(ctx *context.CreationContext) ([]entity.Con
}
preTrainModelEntity = append(preTrainModelEntity, data)
}
return preTrainModelEntity, nil
}
@@ -101,7 +123,7 @@ func (b *PretrainModelBuilder) buildModelData(m *models.AiModelManage, jobName s
Bucket: uploader.GetBucket(),
EndPoint: uploader.GetEndpoint(),
ObjectKey: preTrainModelPath,
ReadOnly: b.Opts.ReadOnly ,
ReadOnly: false ,
ContainerPath: path.Join(b.Opts.ContainerPath, m.Name),
RealPath: uploader.GetRealPath(preTrainModelPath),
S3DownloadUrl: uploader.GetS3DownloadUrl(preTrainModelPath),