Commit 4610982d authored by linpeiqin's avatar linpeiqin

状态优化

parent f17543bb
package com.yice.webadmin.app.config;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.reactive.function.client.WebClient;
......
......@@ -16,7 +16,6 @@ import com.yice.webadmin.app.dto.ModelCompressDto;
import com.yice.webadmin.app.dto.ModelTaskDto;
import com.yice.webadmin.app.model.ModelCompress;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelCompressService;
import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
......@@ -52,30 +51,19 @@ public class ModelCompressController {
* 新增模型压缩数据,及其关联的从表数据。
*
* @param modelCompressDto 新增主表对象。
* @param modelTaskDto 一对一模型任务从表Dto。
* @return 应答结果对象,包含新增对象主键Id。
*/
@ApiOperationSupport(ignoreParameters = {"modelCompressDto.taskId", "modelCompressDto.searchString"})
@OperationLog(type = SysOperationLogType.ADD)
@PostMapping("/add")
public ResponseResult<Long> add(
@MyRequestBody ModelCompressDto modelCompressDto,
@MyRequestBody ModelTaskDto modelTaskDto) {
Long modelVersionId = modelCompressDto.getSourceVersionId();
ModelVersion modelVersion = this.modelVersionService.getById(modelVersionId);
modelTaskDto.setVersionId(modelVersion.getVersionId());
modelTaskDto.setTaskType(2);
modelTaskDto.setModelId(modelVersion.getModelId());
modelTaskDto.setModelVersion(modelVersion.getModelVersion());
modelTaskDto.setVersionName(modelVersion.getVersionName());
ResponseResult<Tuple2<ModelCompress, JSONObject>> verifyResult =
this.doBusinessDataVerifyAndConvert(modelCompressDto, false, modelTaskDto);
if (!verifyResult.isSuccess()) {
return ResponseResult.errorFrom(verifyResult);
@MyRequestBody ModelCompressDto modelCompressDto) {
String errorMessage = MyCommonUtil.getModelValidationError(modelCompressDto, false);
if (errorMessage != null) {
return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, errorMessage);
}
Tuple2<ModelCompress, JSONObject> bizData = verifyResult.getData();
ModelCompress modelCompress = bizData.getFirst();
modelCompress = modelCompressService.saveNewWithRelation(modelCompress, bizData.getSecond());
ModelCompress modelCompress = MyModelUtil.copyTo(modelCompressDto, ModelCompress.class);
modelCompress = modelCompressService.saveNew(modelCompress);
return ResponseResult.success(modelCompress.getTaskId());
}
......
......@@ -75,23 +75,12 @@ public class ModelEstimateController {
@PostMapping("/add")
public ResponseResult<Long> add(
@MyRequestBody ModelEstimateDto modelEstimateDto) {
modelEstimateDto.setTaskStatus(0);
Long modelVersionId = modelEstimateDto.getModelVersionId();
ModelVersion modelVersion = this.modelVersionService.getById(modelVersionId);
ModelTaskDto modelTaskDto = new ModelTaskDto();
modelTaskDto.setVersionId(modelVersion.getVersionId());
modelTaskDto.setTaskType(1);
modelTaskDto.setModelId(modelVersion.getModelId());
modelTaskDto.setModelVersion(modelVersion.getModelVersion());
modelTaskDto.setVersionName(modelVersion.getVersionName());
ResponseResult<Tuple2<ModelEstimate, JSONObject>> verifyResult =
this.doBusinessDataVerifyAndConvert(modelEstimateDto, false, modelTaskDto);
if (!verifyResult.isSuccess()) {
return ResponseResult.errorFrom(verifyResult);
String errorMessage = MyCommonUtil.getModelValidationError(modelEstimateDto, false);
if (errorMessage != null) {
return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, errorMessage);
}
Tuple2<ModelEstimate, JSONObject> bizData = verifyResult.getData();
ModelEstimate modelEstimate = bizData.getFirst();
modelEstimate = modelEstimateService.saveNewWithRelation(modelEstimate, bizData.getSecond());
ModelEstimate modelEstimate = MyModelUtil.copyTo(modelEstimateDto, ModelEstimate.class);
modelEstimate = modelEstimateService.saveNew(modelEstimate);
return ResponseResult.success(modelEstimate.getTaskId());
}
......@@ -155,7 +144,7 @@ public class ModelEstimateController {
}
modelEstimate.setTaskStatus(taskStatus);
modelEstimate.setUpdateTime(new Date());
if (!modelEstimateService.updateById(modelEstimate)) {
if (!modelEstimateService.updateStatusById(modelEstimate)) {
errorMessage = "任务状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
......
package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody;
......@@ -15,9 +14,11 @@ import com.yice.webadmin.app.dto.ModelDeployDto;
import com.yice.webadmin.app.dto.ModelManageDto;
import com.yice.webadmin.app.dto.ModelTaskDto;
import com.yice.webadmin.app.dto.ModelVersionDto;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.model.ModelDeploy;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.TuningRunService;
import com.yice.webadmin.app.vo.ModelManageVo;
import io.swagger.annotations.Api;
......@@ -63,7 +64,7 @@ public class ModelManageController {
}
ModelManage modelManage = MyModelUtil.copyTo(modelManageDto, ModelManage.class);
ModelVersion modelVersion = MyModelUtil.copyTo(modelVersionDto, ModelVersion.class);
modelManage = this.tuningRunService.createToModel(modelManage,modelVersion);
modelManage = this.tuningRunService.createToModel(modelManage, modelVersion);
return ResponseResult.success(modelManage.getModelId());
}
......@@ -108,7 +109,6 @@ public class ModelManageController {
}
/**
* 删除模型管理数据。
*
......
......@@ -129,11 +129,11 @@ public class ModelVersionController {
public ResponseResult<String> change(@MyRequestBody Long versionId) throws IOException {
ModelVersion modelVersion = this.modelVersionService.getById(versionId);
JSONObject jsonObject = new JSONObject();
jsonObject.put("new_model_name",modelVersion.getVersionName());
jsonObject.put("new_model_path",modelVersion.getModelUrl());
jsonObject.put("controller_address",pythonConfig.getControllerAddress());
jsonObject.put("new_model_name", modelVersion.getVersionName());
jsonObject.put("new_model_path", modelVersion.getModelUrl());
jsonObject.put("controller_address", pythonConfig.getControllerAddress());
String url = this.pythonConfig.getChatAddress() + "llm_model/change";
String result = proxyPythonService.predictPost(url,jsonObject.toJSONString());
String result = proxyPythonService.predictPost(url, jsonObject.toJSONString());
JSONObject jo = JSON.parseObject(result);
Integer code = jo.getIntValue("code");
String msg = jo.getString("msg");
......@@ -144,6 +144,7 @@ public class ModelVersionController {
return ResponseResult.create(ErrorCodeEnum.SERVER_INTERNAL_ERROR, msg, data);
}
}
/**
* 停止指定的LLM模型(Model Worker)
*
......@@ -152,7 +153,7 @@ public class ModelVersionController {
@PostMapping("/stop")
public ResponseResult<String> stop() throws IOException {
JSONObject jsonObject = new JSONObject();
jsonObject.put("controller_address",pythonConfig.getControllerAddress());
jsonObject.put("controller_address", pythonConfig.getControllerAddress());
String url = this.pythonConfig.getChatAddress() + "llm_model/stop";
String result = proxyPythonService.predictPost(url, jsonObject.toJSONString());
JSONObject jo = JSON.parseObject(result);
......@@ -165,6 +166,7 @@ public class ModelVersionController {
return ResponseResult.create(ErrorCodeEnum.SERVER_INTERNAL_ERROR, msg, data);
}
}
/**
* 列出当前已加载的模型。
*
......@@ -173,10 +175,10 @@ public class ModelVersionController {
@PostMapping("/listModels")
public ResponseResult<JSONArray> listModels() throws IOException {
JSONObject jsonObject = new JSONObject();
jsonObject.put("placeholder","string");
jsonObject.put("controller_address",pythonConfig.getControllerAddress());
jsonObject.put("placeholder", "string");
jsonObject.put("controller_address", pythonConfig.getControllerAddress());
String url = this.pythonConfig.getChatAddress() + "llm_model/list_models";
String result = proxyPythonService.predictPost(url,jsonObject.toJSONString());
String result = proxyPythonService.predictPost(url, jsonObject.toJSONString());
JSONObject jo = JSON.parseObject(result);
Integer code = jo.getIntValue("code");
String msg = jo.getString("msg");
......
......@@ -49,6 +49,6 @@ public class ProxyPythonController {
@OperationLog(type = SysOperationLogType.OTHER)
@PostMapping("/predict")
public ResponseResult<String> predict(@RequestBody String requestBody) throws IOException {
return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(),requestBody));
return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(), requestBody));
}
}
......@@ -49,6 +49,6 @@ public class ReceivePythonController {
@OperationLog(type = SysOperationLogType.OTHER)
@PostMapping("/predict")
public ResponseResult<String> predict(@RequestBody String requestBody) throws IOException {
return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(),requestBody));
return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(), requestBody));
}
}
......@@ -55,12 +55,10 @@ public class TuningRunController {
@Autowired
private ModelVersionService modelVersionService;
@Autowired
private ModelManageService modelManageService;
private ModelTaskService modelTaskService;
@Autowired
private DatasetVersionService datasetVersionService;
@Autowired
private DatasetManageService datasetManageService;
@Autowired
private PythonConfig pythonConfig;
/**
......@@ -283,7 +281,7 @@ public class TuningRunController {
tuningRun.setRunStatus(runStatus);
tuningRun.setRunTime(runTime);
tuningRun.setUpdateTime(new Date());
if (!tuningRunService.updateById(tuningRun)) {
if (!tuningRunService.updateStatusById(tuningRun)) {
errorMessage = "运行状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
......
......@@ -4,7 +4,7 @@ import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.webadmin.app.model.KnowledgeManage;
import org.apache.ibatis.annotations.Param;
import java.util.*;
import java.util.List;
/**
* 知识库管理数据操作访问接口。
......@@ -25,7 +25,7 @@ public interface KnowledgeManageMapper extends BaseDaoMapper<KnowledgeManage> {
* 获取过滤后的对象列表。
*
* @param knowledgeManageFilter 主表过滤对象。
* @param orderBy 排序字符串,order by从句的参数。
* @param orderBy 排序字符串,order by从句的参数。
* @return 对象列表。
*/
List<KnowledgeManage> getKnowledgeManageList(
......
......@@ -13,15 +13,15 @@
<insert id="insertList">
INSERT INTO lmp_knowledge_manage
(knowledge_id,
create_user_id,
create_time,
update_user_id,
update_time,
knowledge_name,
knowledge_describe)
(knowledge_id,
create_user_id,
create_time,
update_user_id,
update_time,
knowledge_name,
knowledge_describe)
VALUES
<foreach collection="list" index="index" item="item" separator="," >
<foreach collection="list" index="index" item="item" separator=",">
(#{item.knowledgeId},
#{item.createUserId},
#{item.createTime},
......@@ -51,13 +51,14 @@
AND lmp_knowledge_manage.knowledge_name = #{knowledgeManageFilter.knowledgeName}
</if>
<if test="knowledgeManageFilter.searchString != null and knowledgeManageFilter.searchString != ''">
<bind name = "safeKnowledgeManageSearchString" value = "'%' + knowledgeManageFilter.searchString + '%'" />
<bind name="safeKnowledgeManageSearchString" value="'%' + knowledgeManageFilter.searchString + '%'"/>
AND IFNULL(lmp_knowledge_manage.knowledge_name,'') LIKE #{safeKnowledgeManageSearchString}
</if>
</if>
</sql>
<select id="getKnowledgeManageList" resultMap="BaseResultMap" parameterType="com.yice.webadmin.app.model.KnowledgeManage">
<select id="getKnowledgeManageList" resultMap="BaseResultMap"
parameterType="com.yice.webadmin.app.model.KnowledgeManage">
SELECT * FROM lmp_knowledge_manage
<where>
<include refid="filterRef"/>
......
......@@ -87,7 +87,7 @@
#{item.taskId},
#{item.runId},
#{item.basePromptTemplate}
)
)
</foreach>
</insert>
......
package com.yice.webadmin.app.dto;
import com.yice.common.core.validator.UpdateGroup;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import javax.validation.constraints.*;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
/**
* KnowledgeManageDto对象。
......
package com.yice.webadmin.app.model;
import com.baomidou.mybatisplus.annotation.*;
import com.yice.common.core.util.MyCommonUtil;
import com.yice.common.core.base.model.BaseModel;
import com.baomidou.mybatisplus.annotation.TableField;
import com.baomidou.mybatisplus.annotation.TableId;
import com.baomidou.mybatisplus.annotation.TableName;
import com.yice.common.core.base.mapper.BaseModelMapper;
import com.yice.common.core.base.model.BaseModel;
import com.yice.common.core.util.MyCommonUtil;
import com.yice.webadmin.app.vo.KnowledgeManageVo;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.mapstruct.*;
import org.mapstruct.Mapper;
import org.mapstruct.factory.Mappers;
/**
......@@ -50,5 +52,6 @@ public class KnowledgeManage extends BaseModel {
@Mapper
public interface KnowledgeManageModelMapper extends BaseModelMapper<KnowledgeManageVo, KnowledgeManage> {
}
public static final KnowledgeManageModelMapper INSTANCE = Mappers.getMapper(KnowledgeManageModelMapper.class);
}
......@@ -8,7 +8,6 @@ import com.yice.common.core.base.mapper.BaseModelMapper;
import com.yice.common.core.base.model.BaseModel;
import com.yice.common.core.util.MyCommonUtil;
import com.yice.webadmin.app.vo.ModelCompressVo;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.mapstruct.Mapper;
......
package com.yice.webadmin.app.service;
import com.yice.webadmin.app.model.*;
import com.yice.common.core.base.service.IBaseService;
import com.yice.webadmin.app.model.KnowledgeManage;
import java.util.*;
import java.util.List;
/**
* 知识库管理数据操作服务接口。
......@@ -60,7 +60,7 @@ public interface KnowledgeManageService extends IBaseService<KnowledgeManage, Lo
* 该查询会涉及到一对一从表的关联过滤,或一对多从表的嵌套关联过滤,因此性能不如单表过滤。
* 如果仅仅需要获取主表数据,请移步(getKnowledgeManageList),以便获取更好的查询性能。
*
* @param filter 主表过滤对象。
* @param filter 主表过滤对象。
* @param orderBy 排序参数。
* @return 查询结果集。
*/
......
......@@ -5,7 +5,6 @@ import com.yice.common.core.base.service.IBaseService;
import com.yice.webadmin.app.model.ModelCompress;
import com.yice.webadmin.app.model.ModelTask;
import java.net.URISyntaxException;
import java.util.List;
/**
......@@ -31,15 +30,6 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
*/
void saveNewBatch(List<ModelCompress> modelCompressList);
/**
* 保存新增主表对象及关联对象。
*
* @param modelCompress 新增主表对象。
* @param relationData 全部关联从表数据。
* @return 返回新增主表对象。
*/
ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData);
/**
* 更新数据对象。
*
......
......@@ -87,4 +87,6 @@ public interface ModelEstimateService extends IBaseService<ModelEstimate, Long>
* @return 查询结果集。
*/
List<ModelEstimate> getModelEstimateListWithRelation(ModelEstimate filter, ModelTask modelTaskFilter, String orderBy);
boolean updateStatusById(ModelEstimate modelEstimate);
}
......@@ -83,6 +83,7 @@ public interface ModelManageService extends IBaseService<ModelManage, Long> {
List<ModelManage> getModelManageListWithRelation(ModelManage filter, ModelVersion modelVersionFilter, ModelTask modelTaskFilter, ModelDeploy modelDeployFilter, String orderBy);
ModelManage saveAndCreateVersion(ModelManage modelManage, ModelVersion modelVersion);
ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion);
}
......@@ -8,9 +8,11 @@ import java.io.IOException;
* @author linking
* @date 2023-04-13
*/
public interface ProxyPythonService{
public interface ProxyPythonService {
public String predictPost(String url, String requestBody) throws IOException;
public String predictGet(String url, String requestBody) throws IOException;
public String predictPostForFile(String url, String requestBody) throws IOException;
}
......@@ -82,4 +82,6 @@ public interface TuningRunService extends IBaseService<TuningRun, Long> {
ModelManage createToModel(ModelManage modelManage, ModelVersion modelVersion);
ModelVersion createToModelVersion(ModelVersion modelVersion);
boolean updateStatusById(TuningRun tuningRun);
}
......@@ -10,12 +10,10 @@ import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.controller.DatasetManageController;
import com.yice.webadmin.app.dao.DatasetVersionMapper;
import com.yice.webadmin.app.model.DatasetDetail;
import com.yice.webadmin.app.model.DatasetManage;
import com.yice.webadmin.app.model.DatasetVersion;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.DatasetDetailService;
import com.yice.webadmin.app.service.DatasetManageService;
import com.yice.webadmin.app.service.DatasetVersionService;
......@@ -67,13 +65,13 @@ public class DatasetVersionServiceImpl extends BaseService<DatasetVersion, Long>
DatasetManage reDatasetManage = this.datasetManageService.getById(datasetVersion.getDatasetId());
DatasetVersion datasetVersionFilter = new DatasetVersion();
datasetVersionFilter.setDatasetId(datasetVersion.getDatasetId());
List<DatasetVersion> datasetVersionList = this.getDatasetVersionList(datasetVersionFilter,"dataset_version");
List<DatasetVersion> datasetVersionList = this.getDatasetVersionList(datasetVersionFilter, "dataset_version");
Integer version = 1;
if (datasetVersionList != null && datasetVersionList.size() != 0) {
version = datasetVersionList.get(datasetVersionList.size() - 1).getDatasetVersion() + 1;
}
datasetVersion.setDatasetVersion(version);
datasetVersion.setVersionName(reDatasetManage.getDatasetName()+"_V"+datasetVersion.getDatasetVersion());
datasetVersion.setVersionName(reDatasetManage.getDatasetName() + "_V" + datasetVersion.getDatasetVersion());
datasetVersion.setDatasetId(reDatasetManage.getDatasetId());
datasetVersion.setCleanStatus(0);
datasetVersion.setDataVolume(0L);
......
......@@ -8,17 +8,14 @@ import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.KnowledgeConfig;
import com.yice.webadmin.app.dao.KnowledgeManageMapper;
import com.yice.webadmin.app.model.KnowledgeManage;
import com.yice.webadmin.app.service.KnowledgeManageService;
import com.yice.webadmin.app.service.ProxyPythonService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.io.IOException;
import java.util.List;
/**
......
......@@ -12,13 +12,16 @@ import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelEstimateMapper;
import com.yice.webadmin.app.model.ModelEstimate;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelEstimateService;
import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.Date;
import java.util.List;
/**
......@@ -37,6 +40,8 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i
private ModelTaskService modelTaskService;
@Autowired
private IdGeneratorWrapper idGenerator;
@Autowired
private ModelVersionService modelVersionService;
/**
* 返回当前Service的主表Mapper对象。
......@@ -57,7 +62,18 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i
@Transactional(rollbackFor = Exception.class)
@Override
public ModelEstimate saveNew(ModelEstimate modelEstimate) {
modelEstimate.setTaskStatus(0);
modelEstimateMapper.insert(this.buildDefaultValue(modelEstimate));
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
ModelTask modelTask = new ModelTask();
modelTask.setVersionId(modelVersion.getVersionId());
modelTask.setTaskType(1);
modelTask.setTaskStatus(0);
modelTask.setModelId(modelVersion.getModelId());
modelTask.setModelVersion(modelVersion.getModelVersion());
modelTask.setVersionName(modelVersion.getVersionName());
modelTask.setTaskId(modelEstimate.getTaskId());
this.modelTaskService.saveNew(modelTask);
return modelEstimate;
}
......@@ -172,6 +188,16 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i
return resultList;
}
@Transactional(rollbackFor = Exception.class)
@Override
public boolean updateStatusById(ModelEstimate modelEstimate) {
ModelTask modelTask = this.modelTaskService.getById(modelEstimate.getTaskId());
modelTask.setTaskStatus(modelEstimate.getTaskStatus());
modelTask.setCompleteTime(new Date());
modelTaskService.updateById(modelTask);
return this.updateById(modelEstimate);
}
private ModelEstimate buildDefaultValue(ModelEstimate modelEstimate) {
if (modelEstimate.getTaskId() == null) {
modelEstimate.setTaskId(idGenerator.nextLongId());
......
......@@ -11,8 +11,14 @@ import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelManageMapper;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.*;
import com.yice.webadmin.app.model.ModelDeploy;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelDeployService;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
......@@ -42,8 +48,6 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
private ModelDeployService modelDeployService;
@Autowired
private IdGeneratorWrapper idGenerator;
@Autowired
private TuningRunService tuningRunService;
/**
* 返回当前Service的主表Mapper对象。
......@@ -186,6 +190,7 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
this.modelVersionService.saveNew(modelVersion);
return reModelManage;
}
@Override
public ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion) {
ModelManage reModelManage = this.saveNew(modelManage);
......
......@@ -6,14 +6,11 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.Page;
import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.ResponseResult;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelVersionMapper;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService;
......@@ -65,7 +62,7 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
String modelName = this.modelManageService.getById(modelVersion.getModelId()).getModelName();
ModelVersion modelVersionFilter = new ModelVersion();
modelVersionFilter.setModelId(modelVersion.getModelId());
List<ModelVersion> modelVersionList = this.getModelVersionList(modelVersionFilter,"model_Version");
List<ModelVersion> modelVersionList = this.getModelVersionList(modelVersionFilter, "model_Version");
Integer version = 1;
if (modelVersionList != null && modelVersionList.size() != 0) {
version = modelVersionList.get(modelVersionList.size() - 1).getModelVersion() + 1;
......
package com.yice.webadmin.app.service.impl;
import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.service.ProxyPythonService;
import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpEntity;
......@@ -12,7 +11,6 @@ import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import java.io.IOException;
......@@ -51,6 +49,7 @@ public class ProxyPythonServiceImpl implements ProxyPythonService {
}
return null;
}
public String predictPostForFile(String url, String requestBody) throws IOException {
CloseableHttpClient httpClient = HttpClients.createDefault();
try {
......@@ -76,6 +75,7 @@ public class ProxyPythonServiceImpl implements ProxyPythonService {
}
return null;
}
public String predictGet(String url, String requestBody) throws IOException {
CloseableHttpClient httpClient = HttpClients.createDefault();
try {
......
......@@ -186,7 +186,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, runPublishDto, modelVersionSource, modelTask);
messageWithSocket(tuningRun, modelVersionSource, pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(), modelVersionTarget, modelTask);
messageWithSocket(tuningRun, modelVersionSource, modelVersionTarget, modelTask);
return true;
}
......@@ -197,7 +197,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, modelManage, modelVersion, modelVersionSource, modelTask);
messageWithSocket(tuningRun, modelVersionSource, pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(), modelVersionTarget, modelTask);
messageWithSocket(tuningRun, modelVersionSource, modelVersionTarget, modelTask);
return modelManage;
}
......@@ -208,13 +208,23 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, modelVersion, modelVersionSource, modelTask);
messageWithSocket(tuningRun, modelVersionSource, pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(), modelVersionTarget, modelTask);
messageWithSocket(tuningRun, modelVersionSource, modelVersionTarget, modelTask);
return modelVersionTarget;
}
@Transactional(rollbackFor = Exception.class)
@Override
public boolean updateStatusById(TuningRun tuningRun) {
ModelTask modelTask = this.modelTaskService.getById(tuningRun.getTaskId());
modelTask.setTaskStatus(tuningRun.getRunStatus());
modelTask.setCompleteTime(new Date());
modelTaskService.updateById(modelTask);
return this.updateById(tuningRun);
}
@SneakyThrows
private void messageWithSocket(TuningRun tuningRun, ModelVersion modelVersionSource, String targetModelVersionURl, ModelVersion modelVersionTarget, ModelTask modelTask) {
private void messageWithSocket(TuningRun tuningRun, ModelVersion modelVersionSource, ModelVersion modelVersionTarget, ModelTask modelTask) {
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri())) {
@Override
public void onOpen(ServerHandshake serverHandshake) {
......@@ -246,7 +256,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate"));
array.add(2);
array.add(targetModelVersionURl);
array.add(pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName());
array.add("none");
sendJson.put("data", array);
sendJson.put("event_data", "null");
......@@ -257,7 +267,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
log.info("发送服务端的消息:" + sendJson.toJSONString());
if (receiveMsg.equals("process_completed")) {
System.out.println("isSuccess4:" + System.currentTimeMillis());
updateStatus(receiveJson.getBoolean("success"), tuningRun, modelVersionTarget, targetModelVersionURl, modelTask);
updateStatus(receiveJson.getBoolean("success"), tuningRun, modelVersionTarget, modelTask);
}
}
......@@ -336,9 +346,9 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
modelTask.setTaskId(tuningRun.getTaskId());
}
public Boolean updateStatus(Boolean flag, TuningRun tuningRun, ModelVersion modelVersionTarget, String targetModelVersionURl, ModelTask modelTask) {
public Boolean updateStatus(Boolean flag, TuningRun tuningRun, ModelVersion modelVersionTarget, ModelTask modelTask) {
if (flag) {
modelVersionTarget.setModelUrl(targetModelVersionURl);
modelVersionTarget.setModelUrl(pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName());
modelVersionTarget.setStatus(1);
modelTask.setTaskStatus(1);
tuningRun.setPublishStatus(1);
......
......@@ -6,8 +6,6 @@ import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import java.util.Date;
/**
* KnowledgeManageVO视图对象。
*
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment