Commit 4610982d authored by linpeiqin's avatar linpeiqin

状态优化

parent f17543bb
package com.yice.webadmin.app.config; package com.yice.webadmin.app.config;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.web.reactive.function.client.WebClient; import org.springframework.web.reactive.function.client.WebClient;
......
...@@ -16,7 +16,6 @@ import com.yice.webadmin.app.dto.ModelCompressDto; ...@@ -16,7 +16,6 @@ import com.yice.webadmin.app.dto.ModelCompressDto;
import com.yice.webadmin.app.dto.ModelTaskDto; import com.yice.webadmin.app.dto.ModelTaskDto;
import com.yice.webadmin.app.model.ModelCompress; import com.yice.webadmin.app.model.ModelCompress;
import com.yice.webadmin.app.model.ModelTask; import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelCompressService; import com.yice.webadmin.app.service.ModelCompressService;
import com.yice.webadmin.app.service.ModelTaskService; import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService; import com.yice.webadmin.app.service.ModelVersionService;
...@@ -52,30 +51,19 @@ public class ModelCompressController { ...@@ -52,30 +51,19 @@ public class ModelCompressController {
* 新增模型压缩数据,及其关联的从表数据。 * 新增模型压缩数据,及其关联的从表数据。
* *
* @param modelCompressDto 新增主表对象。 * @param modelCompressDto 新增主表对象。
* @param modelTaskDto 一对一模型任务从表Dto。
* @return 应答结果对象,包含新增对象主键Id。 * @return 应答结果对象,包含新增对象主键Id。
*/ */
@ApiOperationSupport(ignoreParameters = {"modelCompressDto.taskId", "modelCompressDto.searchString"}) @ApiOperationSupport(ignoreParameters = {"modelCompressDto.taskId", "modelCompressDto.searchString"})
@OperationLog(type = SysOperationLogType.ADD) @OperationLog(type = SysOperationLogType.ADD)
@PostMapping("/add") @PostMapping("/add")
public ResponseResult<Long> add( public ResponseResult<Long> add(
@MyRequestBody ModelCompressDto modelCompressDto, @MyRequestBody ModelCompressDto modelCompressDto) {
@MyRequestBody ModelTaskDto modelTaskDto) { String errorMessage = MyCommonUtil.getModelValidationError(modelCompressDto, false);
Long modelVersionId = modelCompressDto.getSourceVersionId(); if (errorMessage != null) {
ModelVersion modelVersion = this.modelVersionService.getById(modelVersionId); return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, errorMessage);
modelTaskDto.setVersionId(modelVersion.getVersionId());
modelTaskDto.setTaskType(2);
modelTaskDto.setModelId(modelVersion.getModelId());
modelTaskDto.setModelVersion(modelVersion.getModelVersion());
modelTaskDto.setVersionName(modelVersion.getVersionName());
ResponseResult<Tuple2<ModelCompress, JSONObject>> verifyResult =
this.doBusinessDataVerifyAndConvert(modelCompressDto, false, modelTaskDto);
if (!verifyResult.isSuccess()) {
return ResponseResult.errorFrom(verifyResult);
} }
Tuple2<ModelCompress, JSONObject> bizData = verifyResult.getData(); ModelCompress modelCompress = MyModelUtil.copyTo(modelCompressDto, ModelCompress.class);
ModelCompress modelCompress = bizData.getFirst(); modelCompress = modelCompressService.saveNew(modelCompress);
modelCompress = modelCompressService.saveNewWithRelation(modelCompress, bizData.getSecond());
return ResponseResult.success(modelCompress.getTaskId()); return ResponseResult.success(modelCompress.getTaskId());
} }
......
...@@ -75,23 +75,12 @@ public class ModelEstimateController { ...@@ -75,23 +75,12 @@ public class ModelEstimateController {
@PostMapping("/add") @PostMapping("/add")
public ResponseResult<Long> add( public ResponseResult<Long> add(
@MyRequestBody ModelEstimateDto modelEstimateDto) { @MyRequestBody ModelEstimateDto modelEstimateDto) {
modelEstimateDto.setTaskStatus(0); String errorMessage = MyCommonUtil.getModelValidationError(modelEstimateDto, false);
Long modelVersionId = modelEstimateDto.getModelVersionId(); if (errorMessage != null) {
ModelVersion modelVersion = this.modelVersionService.getById(modelVersionId); return ResponseResult.error(ErrorCodeEnum.DATA_VALIDATED_FAILED, errorMessage);
ModelTaskDto modelTaskDto = new ModelTaskDto();
modelTaskDto.setVersionId(modelVersion.getVersionId());
modelTaskDto.setTaskType(1);
modelTaskDto.setModelId(modelVersion.getModelId());
modelTaskDto.setModelVersion(modelVersion.getModelVersion());
modelTaskDto.setVersionName(modelVersion.getVersionName());
ResponseResult<Tuple2<ModelEstimate, JSONObject>> verifyResult =
this.doBusinessDataVerifyAndConvert(modelEstimateDto, false, modelTaskDto);
if (!verifyResult.isSuccess()) {
return ResponseResult.errorFrom(verifyResult);
} }
Tuple2<ModelEstimate, JSONObject> bizData = verifyResult.getData(); ModelEstimate modelEstimate = MyModelUtil.copyTo(modelEstimateDto, ModelEstimate.class);
ModelEstimate modelEstimate = bizData.getFirst(); modelEstimate = modelEstimateService.saveNew(modelEstimate);
modelEstimate = modelEstimateService.saveNewWithRelation(modelEstimate, bizData.getSecond());
return ResponseResult.success(modelEstimate.getTaskId()); return ResponseResult.success(modelEstimate.getTaskId());
} }
...@@ -155,7 +144,7 @@ public class ModelEstimateController { ...@@ -155,7 +144,7 @@ public class ModelEstimateController {
} }
modelEstimate.setTaskStatus(taskStatus); modelEstimate.setTaskStatus(taskStatus);
modelEstimate.setUpdateTime(new Date()); modelEstimate.setUpdateTime(new Date());
if (!modelEstimateService.updateById(modelEstimate)) { if (!modelEstimateService.updateStatusById(modelEstimate)) {
errorMessage = "任务状态信息提交错误!"; errorMessage = "任务状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage); return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
} }
......
package com.yice.webadmin.app.controller; package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.page.PageMethod; import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport; import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody; import com.yice.common.core.annotation.MyRequestBody;
...@@ -15,9 +14,11 @@ import com.yice.webadmin.app.dto.ModelDeployDto; ...@@ -15,9 +14,11 @@ import com.yice.webadmin.app.dto.ModelDeployDto;
import com.yice.webadmin.app.dto.ModelManageDto; import com.yice.webadmin.app.dto.ModelManageDto;
import com.yice.webadmin.app.dto.ModelTaskDto; import com.yice.webadmin.app.dto.ModelTaskDto;
import com.yice.webadmin.app.dto.ModelVersionDto; import com.yice.webadmin.app.dto.ModelVersionDto;
import com.yice.webadmin.app.model.*; import com.yice.webadmin.app.model.ModelDeploy;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelManageService; import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.TuningRunService; import com.yice.webadmin.app.service.TuningRunService;
import com.yice.webadmin.app.vo.ModelManageVo; import com.yice.webadmin.app.vo.ModelManageVo;
import io.swagger.annotations.Api; import io.swagger.annotations.Api;
...@@ -63,7 +64,7 @@ public class ModelManageController { ...@@ -63,7 +64,7 @@ public class ModelManageController {
} }
ModelManage modelManage = MyModelUtil.copyTo(modelManageDto, ModelManage.class); ModelManage modelManage = MyModelUtil.copyTo(modelManageDto, ModelManage.class);
ModelVersion modelVersion = MyModelUtil.copyTo(modelVersionDto, ModelVersion.class); ModelVersion modelVersion = MyModelUtil.copyTo(modelVersionDto, ModelVersion.class);
modelManage = this.tuningRunService.createToModel(modelManage,modelVersion); modelManage = this.tuningRunService.createToModel(modelManage, modelVersion);
return ResponseResult.success(modelManage.getModelId()); return ResponseResult.success(modelManage.getModelId());
} }
...@@ -108,7 +109,6 @@ public class ModelManageController { ...@@ -108,7 +109,6 @@ public class ModelManageController {
} }
/** /**
* 删除模型管理数据。 * 删除模型管理数据。
* *
......
...@@ -129,11 +129,11 @@ public class ModelVersionController { ...@@ -129,11 +129,11 @@ public class ModelVersionController {
public ResponseResult<String> change(@MyRequestBody Long versionId) throws IOException { public ResponseResult<String> change(@MyRequestBody Long versionId) throws IOException {
ModelVersion modelVersion = this.modelVersionService.getById(versionId); ModelVersion modelVersion = this.modelVersionService.getById(versionId);
JSONObject jsonObject = new JSONObject(); JSONObject jsonObject = new JSONObject();
jsonObject.put("new_model_name",modelVersion.getVersionName()); jsonObject.put("new_model_name", modelVersion.getVersionName());
jsonObject.put("new_model_path",modelVersion.getModelUrl()); jsonObject.put("new_model_path", modelVersion.getModelUrl());
jsonObject.put("controller_address",pythonConfig.getControllerAddress()); jsonObject.put("controller_address", pythonConfig.getControllerAddress());
String url = this.pythonConfig.getChatAddress() + "llm_model/change"; String url = this.pythonConfig.getChatAddress() + "llm_model/change";
String result = proxyPythonService.predictPost(url,jsonObject.toJSONString()); String result = proxyPythonService.predictPost(url, jsonObject.toJSONString());
JSONObject jo = JSON.parseObject(result); JSONObject jo = JSON.parseObject(result);
Integer code = jo.getIntValue("code"); Integer code = jo.getIntValue("code");
String msg = jo.getString("msg"); String msg = jo.getString("msg");
...@@ -144,6 +144,7 @@ public class ModelVersionController { ...@@ -144,6 +144,7 @@ public class ModelVersionController {
return ResponseResult.create(ErrorCodeEnum.SERVER_INTERNAL_ERROR, msg, data); return ResponseResult.create(ErrorCodeEnum.SERVER_INTERNAL_ERROR, msg, data);
} }
} }
/** /**
* 停止指定的LLM模型(Model Worker) * 停止指定的LLM模型(Model Worker)
* *
...@@ -152,7 +153,7 @@ public class ModelVersionController { ...@@ -152,7 +153,7 @@ public class ModelVersionController {
@PostMapping("/stop") @PostMapping("/stop")
public ResponseResult<String> stop() throws IOException { public ResponseResult<String> stop() throws IOException {
JSONObject jsonObject = new JSONObject(); JSONObject jsonObject = new JSONObject();
jsonObject.put("controller_address",pythonConfig.getControllerAddress()); jsonObject.put("controller_address", pythonConfig.getControllerAddress());
String url = this.pythonConfig.getChatAddress() + "llm_model/stop"; String url = this.pythonConfig.getChatAddress() + "llm_model/stop";
String result = proxyPythonService.predictPost(url, jsonObject.toJSONString()); String result = proxyPythonService.predictPost(url, jsonObject.toJSONString());
JSONObject jo = JSON.parseObject(result); JSONObject jo = JSON.parseObject(result);
...@@ -165,6 +166,7 @@ public class ModelVersionController { ...@@ -165,6 +166,7 @@ public class ModelVersionController {
return ResponseResult.create(ErrorCodeEnum.SERVER_INTERNAL_ERROR, msg, data); return ResponseResult.create(ErrorCodeEnum.SERVER_INTERNAL_ERROR, msg, data);
} }
} }
/** /**
* 列出当前已加载的模型。 * 列出当前已加载的模型。
* *
...@@ -173,10 +175,10 @@ public class ModelVersionController { ...@@ -173,10 +175,10 @@ public class ModelVersionController {
@PostMapping("/listModels") @PostMapping("/listModels")
public ResponseResult<JSONArray> listModels() throws IOException { public ResponseResult<JSONArray> listModels() throws IOException {
JSONObject jsonObject = new JSONObject(); JSONObject jsonObject = new JSONObject();
jsonObject.put("placeholder","string"); jsonObject.put("placeholder", "string");
jsonObject.put("controller_address",pythonConfig.getControllerAddress()); jsonObject.put("controller_address", pythonConfig.getControllerAddress());
String url = this.pythonConfig.getChatAddress() + "llm_model/list_models"; String url = this.pythonConfig.getChatAddress() + "llm_model/list_models";
String result = proxyPythonService.predictPost(url,jsonObject.toJSONString()); String result = proxyPythonService.predictPost(url, jsonObject.toJSONString());
JSONObject jo = JSON.parseObject(result); JSONObject jo = JSON.parseObject(result);
Integer code = jo.getIntValue("code"); Integer code = jo.getIntValue("code");
String msg = jo.getString("msg"); String msg = jo.getString("msg");
......
...@@ -49,6 +49,6 @@ public class ProxyPythonController { ...@@ -49,6 +49,6 @@ public class ProxyPythonController {
@OperationLog(type = SysOperationLogType.OTHER) @OperationLog(type = SysOperationLogType.OTHER)
@PostMapping("/predict") @PostMapping("/predict")
public ResponseResult<String> predict(@RequestBody String requestBody) throws IOException { public ResponseResult<String> predict(@RequestBody String requestBody) throws IOException {
return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(),requestBody)); return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(), requestBody));
} }
} }
...@@ -49,6 +49,6 @@ public class ReceivePythonController { ...@@ -49,6 +49,6 @@ public class ReceivePythonController {
@OperationLog(type = SysOperationLogType.OTHER) @OperationLog(type = SysOperationLogType.OTHER)
@PostMapping("/predict") @PostMapping("/predict")
public ResponseResult<String> predict(@RequestBody String requestBody) throws IOException { public ResponseResult<String> predict(@RequestBody String requestBody) throws IOException {
return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(),requestBody)); return ResponseResult.success(this.proxyPythonService.predictPost(pythonConfig.getFactoryInterface(), requestBody));
} }
} }
...@@ -55,12 +55,10 @@ public class TuningRunController { ...@@ -55,12 +55,10 @@ public class TuningRunController {
@Autowired @Autowired
private ModelVersionService modelVersionService; private ModelVersionService modelVersionService;
@Autowired @Autowired
private ModelManageService modelManageService; private ModelTaskService modelTaskService;
@Autowired @Autowired
private DatasetVersionService datasetVersionService; private DatasetVersionService datasetVersionService;
@Autowired @Autowired
private DatasetManageService datasetManageService;
@Autowired
private PythonConfig pythonConfig; private PythonConfig pythonConfig;
/** /**
...@@ -283,7 +281,7 @@ public class TuningRunController { ...@@ -283,7 +281,7 @@ public class TuningRunController {
tuningRun.setRunStatus(runStatus); tuningRun.setRunStatus(runStatus);
tuningRun.setRunTime(runTime); tuningRun.setRunTime(runTime);
tuningRun.setUpdateTime(new Date()); tuningRun.setUpdateTime(new Date());
if (!tuningRunService.updateById(tuningRun)) { if (!tuningRunService.updateStatusById(tuningRun)) {
errorMessage = "运行状态信息提交错误!"; errorMessage = "运行状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage); return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
} }
......
...@@ -4,7 +4,7 @@ import com.yice.common.core.base.dao.BaseDaoMapper; ...@@ -4,7 +4,7 @@ import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.webadmin.app.model.KnowledgeManage; import com.yice.webadmin.app.model.KnowledgeManage;
import org.apache.ibatis.annotations.Param; import org.apache.ibatis.annotations.Param;
import java.util.*; import java.util.List;
/** /**
* 知识库管理数据操作访问接口。 * 知识库管理数据操作访问接口。
...@@ -25,7 +25,7 @@ public interface KnowledgeManageMapper extends BaseDaoMapper<KnowledgeManage> { ...@@ -25,7 +25,7 @@ public interface KnowledgeManageMapper extends BaseDaoMapper<KnowledgeManage> {
* 获取过滤后的对象列表。 * 获取过滤后的对象列表。
* *
* @param knowledgeManageFilter 主表过滤对象。 * @param knowledgeManageFilter 主表过滤对象。
* @param orderBy 排序字符串,order by从句的参数。 * @param orderBy 排序字符串,order by从句的参数。
* @return 对象列表。 * @return 对象列表。
*/ */
List<KnowledgeManage> getKnowledgeManageList( List<KnowledgeManage> getKnowledgeManageList(
......
...@@ -13,15 +13,15 @@ ...@@ -13,15 +13,15 @@
<insert id="insertList"> <insert id="insertList">
INSERT INTO lmp_knowledge_manage INSERT INTO lmp_knowledge_manage
(knowledge_id, (knowledge_id,
create_user_id, create_user_id,
create_time, create_time,
update_user_id, update_user_id,
update_time, update_time,
knowledge_name, knowledge_name,
knowledge_describe) knowledge_describe)
VALUES VALUES
<foreach collection="list" index="index" item="item" separator="," > <foreach collection="list" index="index" item="item" separator=",">
(#{item.knowledgeId}, (#{item.knowledgeId},
#{item.createUserId}, #{item.createUserId},
#{item.createTime}, #{item.createTime},
...@@ -51,13 +51,14 @@ ...@@ -51,13 +51,14 @@
AND lmp_knowledge_manage.knowledge_name = #{knowledgeManageFilter.knowledgeName} AND lmp_knowledge_manage.knowledge_name = #{knowledgeManageFilter.knowledgeName}
</if> </if>
<if test="knowledgeManageFilter.searchString != null and knowledgeManageFilter.searchString != ''"> <if test="knowledgeManageFilter.searchString != null and knowledgeManageFilter.searchString != ''">
<bind name = "safeKnowledgeManageSearchString" value = "'%' + knowledgeManageFilter.searchString + '%'" /> <bind name="safeKnowledgeManageSearchString" value="'%' + knowledgeManageFilter.searchString + '%'"/>
AND IFNULL(lmp_knowledge_manage.knowledge_name,'') LIKE #{safeKnowledgeManageSearchString} AND IFNULL(lmp_knowledge_manage.knowledge_name,'') LIKE #{safeKnowledgeManageSearchString}
</if> </if>
</if> </if>
</sql> </sql>
<select id="getKnowledgeManageList" resultMap="BaseResultMap" parameterType="com.yice.webadmin.app.model.KnowledgeManage"> <select id="getKnowledgeManageList" resultMap="BaseResultMap"
parameterType="com.yice.webadmin.app.model.KnowledgeManage">
SELECT * FROM lmp_knowledge_manage SELECT * FROM lmp_knowledge_manage
<where> <where>
<include refid="filterRef"/> <include refid="filterRef"/>
......
...@@ -87,7 +87,7 @@ ...@@ -87,7 +87,7 @@
#{item.taskId}, #{item.taskId},
#{item.runId}, #{item.runId},
#{item.basePromptTemplate} #{item.basePromptTemplate}
) )
</foreach> </foreach>
</insert> </insert>
......
package com.yice.webadmin.app.dto; package com.yice.webadmin.app.dto;
import com.yice.common.core.validator.UpdateGroup; import com.yice.common.core.validator.UpdateGroup;
import io.swagger.annotations.ApiModel; import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty; import io.swagger.annotations.ApiModelProperty;
import lombok.Data; import lombok.Data;
import javax.validation.constraints.*; import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
/** /**
* KnowledgeManageDto对象。 * KnowledgeManageDto对象。
......
package com.yice.webadmin.app.model; package com.yice.webadmin.app.model;
import com.baomidou.mybatisplus.annotation.*; import com.baomidou.mybatisplus.annotation.TableField;
import com.yice.common.core.util.MyCommonUtil; import com.baomidou.mybatisplus.annotation.TableId;
import com.yice.common.core.base.model.BaseModel; import com.baomidou.mybatisplus.annotation.TableName;
import com.yice.common.core.base.mapper.BaseModelMapper; import com.yice.common.core.base.mapper.BaseModelMapper;
import com.yice.common.core.base.model.BaseModel;
import com.yice.common.core.util.MyCommonUtil;
import com.yice.webadmin.app.vo.KnowledgeManageVo; import com.yice.webadmin.app.vo.KnowledgeManageVo;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.mapstruct.*; import org.mapstruct.Mapper;
import org.mapstruct.factory.Mappers; import org.mapstruct.factory.Mappers;
/** /**
...@@ -50,5 +52,6 @@ public class KnowledgeManage extends BaseModel { ...@@ -50,5 +52,6 @@ public class KnowledgeManage extends BaseModel {
@Mapper @Mapper
public interface KnowledgeManageModelMapper extends BaseModelMapper<KnowledgeManageVo, KnowledgeManage> { public interface KnowledgeManageModelMapper extends BaseModelMapper<KnowledgeManageVo, KnowledgeManage> {
} }
public static final KnowledgeManageModelMapper INSTANCE = Mappers.getMapper(KnowledgeManageModelMapper.class); public static final KnowledgeManageModelMapper INSTANCE = Mappers.getMapper(KnowledgeManageModelMapper.class);
} }
...@@ -8,7 +8,6 @@ import com.yice.common.core.base.mapper.BaseModelMapper; ...@@ -8,7 +8,6 @@ import com.yice.common.core.base.mapper.BaseModelMapper;
import com.yice.common.core.base.model.BaseModel; import com.yice.common.core.base.model.BaseModel;
import com.yice.common.core.util.MyCommonUtil; import com.yice.common.core.util.MyCommonUtil;
import com.yice.webadmin.app.vo.ModelCompressVo; import com.yice.webadmin.app.vo.ModelCompressVo;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.mapstruct.Mapper; import org.mapstruct.Mapper;
......
package com.yice.webadmin.app.service; package com.yice.webadmin.app.service;
import com.yice.webadmin.app.model.*;
import com.yice.common.core.base.service.IBaseService; import com.yice.common.core.base.service.IBaseService;
import com.yice.webadmin.app.model.KnowledgeManage;
import java.util.*; import java.util.List;
/** /**
* 知识库管理数据操作服务接口。 * 知识库管理数据操作服务接口。
...@@ -60,7 +60,7 @@ public interface KnowledgeManageService extends IBaseService<KnowledgeManage, Lo ...@@ -60,7 +60,7 @@ public interface KnowledgeManageService extends IBaseService<KnowledgeManage, Lo
* 该查询会涉及到一对一从表的关联过滤,或一对多从表的嵌套关联过滤,因此性能不如单表过滤。 * 该查询会涉及到一对一从表的关联过滤,或一对多从表的嵌套关联过滤,因此性能不如单表过滤。
* 如果仅仅需要获取主表数据,请移步(getKnowledgeManageList),以便获取更好的查询性能。 * 如果仅仅需要获取主表数据,请移步(getKnowledgeManageList),以便获取更好的查询性能。
* *
* @param filter 主表过滤对象。 * @param filter 主表过滤对象。
* @param orderBy 排序参数。 * @param orderBy 排序参数。
* @return 查询结果集。 * @return 查询结果集。
*/ */
......
...@@ -5,7 +5,6 @@ import com.yice.common.core.base.service.IBaseService; ...@@ -5,7 +5,6 @@ import com.yice.common.core.base.service.IBaseService;
import com.yice.webadmin.app.model.ModelCompress; import com.yice.webadmin.app.model.ModelCompress;
import com.yice.webadmin.app.model.ModelTask; import com.yice.webadmin.app.model.ModelTask;
import java.net.URISyntaxException;
import java.util.List; import java.util.List;
/** /**
...@@ -31,15 +30,6 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long> ...@@ -31,15 +30,6 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
*/ */
void saveNewBatch(List<ModelCompress> modelCompressList); void saveNewBatch(List<ModelCompress> modelCompressList);
/**
* 保存新增主表对象及关联对象。
*
* @param modelCompress 新增主表对象。
* @param relationData 全部关联从表数据。
* @return 返回新增主表对象。
*/
ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData);
/** /**
* 更新数据对象。 * 更新数据对象。
* *
......
...@@ -87,4 +87,6 @@ public interface ModelEstimateService extends IBaseService<ModelEstimate, Long> ...@@ -87,4 +87,6 @@ public interface ModelEstimateService extends IBaseService<ModelEstimate, Long>
* @return 查询结果集。 * @return 查询结果集。
*/ */
List<ModelEstimate> getModelEstimateListWithRelation(ModelEstimate filter, ModelTask modelTaskFilter, String orderBy); List<ModelEstimate> getModelEstimateListWithRelation(ModelEstimate filter, ModelTask modelTaskFilter, String orderBy);
boolean updateStatusById(ModelEstimate modelEstimate);
} }
...@@ -83,6 +83,7 @@ public interface ModelManageService extends IBaseService<ModelManage, Long> { ...@@ -83,6 +83,7 @@ public interface ModelManageService extends IBaseService<ModelManage, Long> {
List<ModelManage> getModelManageListWithRelation(ModelManage filter, ModelVersion modelVersionFilter, ModelTask modelTaskFilter, ModelDeploy modelDeployFilter, String orderBy); List<ModelManage> getModelManageListWithRelation(ModelManage filter, ModelVersion modelVersionFilter, ModelTask modelTaskFilter, ModelDeploy modelDeployFilter, String orderBy);
ModelManage saveAndCreateVersion(ModelManage modelManage, ModelVersion modelVersion); ModelManage saveAndCreateVersion(ModelManage modelManage, ModelVersion modelVersion);
ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion); ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion);
} }
...@@ -8,9 +8,11 @@ import java.io.IOException; ...@@ -8,9 +8,11 @@ import java.io.IOException;
* @author linking * @author linking
* @date 2023-04-13 * @date 2023-04-13
*/ */
public interface ProxyPythonService{ public interface ProxyPythonService {
public String predictPost(String url, String requestBody) throws IOException; public String predictPost(String url, String requestBody) throws IOException;
public String predictGet(String url, String requestBody) throws IOException; public String predictGet(String url, String requestBody) throws IOException;
public String predictPostForFile(String url, String requestBody) throws IOException; public String predictPostForFile(String url, String requestBody) throws IOException;
} }
...@@ -82,4 +82,6 @@ public interface TuningRunService extends IBaseService<TuningRun, Long> { ...@@ -82,4 +82,6 @@ public interface TuningRunService extends IBaseService<TuningRun, Long> {
ModelManage createToModel(ModelManage modelManage, ModelVersion modelVersion); ModelManage createToModel(ModelManage modelManage, ModelVersion modelVersion);
ModelVersion createToModelVersion(ModelVersion modelVersion); ModelVersion createToModelVersion(ModelVersion modelVersion);
boolean updateStatusById(TuningRun tuningRun);
} }
...@@ -10,12 +10,10 @@ import com.yice.common.core.object.CallResult; ...@@ -10,12 +10,10 @@ import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam; import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil; import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper; import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.controller.DatasetManageController;
import com.yice.webadmin.app.dao.DatasetVersionMapper; import com.yice.webadmin.app.dao.DatasetVersionMapper;
import com.yice.webadmin.app.model.DatasetDetail; import com.yice.webadmin.app.model.DatasetDetail;
import com.yice.webadmin.app.model.DatasetManage; import com.yice.webadmin.app.model.DatasetManage;
import com.yice.webadmin.app.model.DatasetVersion; import com.yice.webadmin.app.model.DatasetVersion;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.DatasetDetailService; import com.yice.webadmin.app.service.DatasetDetailService;
import com.yice.webadmin.app.service.DatasetManageService; import com.yice.webadmin.app.service.DatasetManageService;
import com.yice.webadmin.app.service.DatasetVersionService; import com.yice.webadmin.app.service.DatasetVersionService;
...@@ -67,13 +65,13 @@ public class DatasetVersionServiceImpl extends BaseService<DatasetVersion, Long> ...@@ -67,13 +65,13 @@ public class DatasetVersionServiceImpl extends BaseService<DatasetVersion, Long>
DatasetManage reDatasetManage = this.datasetManageService.getById(datasetVersion.getDatasetId()); DatasetManage reDatasetManage = this.datasetManageService.getById(datasetVersion.getDatasetId());
DatasetVersion datasetVersionFilter = new DatasetVersion(); DatasetVersion datasetVersionFilter = new DatasetVersion();
datasetVersionFilter.setDatasetId(datasetVersion.getDatasetId()); datasetVersionFilter.setDatasetId(datasetVersion.getDatasetId());
List<DatasetVersion> datasetVersionList = this.getDatasetVersionList(datasetVersionFilter,"dataset_version"); List<DatasetVersion> datasetVersionList = this.getDatasetVersionList(datasetVersionFilter, "dataset_version");
Integer version = 1; Integer version = 1;
if (datasetVersionList != null && datasetVersionList.size() != 0) { if (datasetVersionList != null && datasetVersionList.size() != 0) {
version = datasetVersionList.get(datasetVersionList.size() - 1).getDatasetVersion() + 1; version = datasetVersionList.get(datasetVersionList.size() - 1).getDatasetVersion() + 1;
} }
datasetVersion.setDatasetVersion(version); datasetVersion.setDatasetVersion(version);
datasetVersion.setVersionName(reDatasetManage.getDatasetName()+"_V"+datasetVersion.getDatasetVersion()); datasetVersion.setVersionName(reDatasetManage.getDatasetName() + "_V" + datasetVersion.getDatasetVersion());
datasetVersion.setDatasetId(reDatasetManage.getDatasetId()); datasetVersion.setDatasetId(reDatasetManage.getDatasetId());
datasetVersion.setCleanStatus(0); datasetVersion.setCleanStatus(0);
datasetVersion.setDataVolume(0L); datasetVersion.setDataVolume(0L);
......
...@@ -8,17 +8,14 @@ import com.yice.common.core.base.service.BaseService; ...@@ -8,17 +8,14 @@ import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.MyRelationParam; import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil; import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper; import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.KnowledgeConfig;
import com.yice.webadmin.app.dao.KnowledgeManageMapper; import com.yice.webadmin.app.dao.KnowledgeManageMapper;
import com.yice.webadmin.app.model.KnowledgeManage; import com.yice.webadmin.app.model.KnowledgeManage;
import com.yice.webadmin.app.service.KnowledgeManageService; import com.yice.webadmin.app.service.KnowledgeManageService;
import com.yice.webadmin.app.service.ProxyPythonService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.io.IOException;
import java.util.List; import java.util.List;
/** /**
......
...@@ -12,13 +12,16 @@ import com.yice.common.sequence.wrapper.IdGeneratorWrapper; ...@@ -12,13 +12,16 @@ import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelEstimateMapper; import com.yice.webadmin.app.dao.ModelEstimateMapper;
import com.yice.webadmin.app.model.ModelEstimate; import com.yice.webadmin.app.model.ModelEstimate;
import com.yice.webadmin.app.model.ModelTask; import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelEstimateService; import com.yice.webadmin.app.service.ModelEstimateService;
import com.yice.webadmin.app.service.ModelTaskService; import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional; import org.springframework.transaction.annotation.Transactional;
import java.util.Date;
import java.util.List; import java.util.List;
/** /**
...@@ -37,6 +40,8 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i ...@@ -37,6 +40,8 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i
private ModelTaskService modelTaskService; private ModelTaskService modelTaskService;
@Autowired @Autowired
private IdGeneratorWrapper idGenerator; private IdGeneratorWrapper idGenerator;
@Autowired
private ModelVersionService modelVersionService;
/** /**
* 返回当前Service的主表Mapper对象。 * 返回当前Service的主表Mapper对象。
...@@ -57,7 +62,18 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i ...@@ -57,7 +62,18 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
@Override @Override
public ModelEstimate saveNew(ModelEstimate modelEstimate) { public ModelEstimate saveNew(ModelEstimate modelEstimate) {
modelEstimate.setTaskStatus(0);
modelEstimateMapper.insert(this.buildDefaultValue(modelEstimate)); modelEstimateMapper.insert(this.buildDefaultValue(modelEstimate));
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
ModelTask modelTask = new ModelTask();
modelTask.setVersionId(modelVersion.getVersionId());
modelTask.setTaskType(1);
modelTask.setTaskStatus(0);
modelTask.setModelId(modelVersion.getModelId());
modelTask.setModelVersion(modelVersion.getModelVersion());
modelTask.setVersionName(modelVersion.getVersionName());
modelTask.setTaskId(modelEstimate.getTaskId());
this.modelTaskService.saveNew(modelTask);
return modelEstimate; return modelEstimate;
} }
...@@ -172,6 +188,16 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i ...@@ -172,6 +188,16 @@ public class ModelEstimateServiceImpl extends BaseService<ModelEstimate, Long> i
return resultList; return resultList;
} }
@Transactional(rollbackFor = Exception.class)
@Override
public boolean updateStatusById(ModelEstimate modelEstimate) {
ModelTask modelTask = this.modelTaskService.getById(modelEstimate.getTaskId());
modelTask.setTaskStatus(modelEstimate.getTaskStatus());
modelTask.setCompleteTime(new Date());
modelTaskService.updateById(modelTask);
return this.updateById(modelEstimate);
}
private ModelEstimate buildDefaultValue(ModelEstimate modelEstimate) { private ModelEstimate buildDefaultValue(ModelEstimate modelEstimate) {
if (modelEstimate.getTaskId() == null) { if (modelEstimate.getTaskId() == null) {
modelEstimate.setTaskId(idGenerator.nextLongId()); modelEstimate.setTaskId(idGenerator.nextLongId());
......
...@@ -11,8 +11,14 @@ import com.yice.common.core.object.MyRelationParam; ...@@ -11,8 +11,14 @@ import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil; import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper; import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelManageMapper; import com.yice.webadmin.app.dao.ModelManageMapper;
import com.yice.webadmin.app.model.*; import com.yice.webadmin.app.model.ModelDeploy;
import com.yice.webadmin.app.service.*; import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelDeployService;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
...@@ -42,8 +48,6 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple ...@@ -42,8 +48,6 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
private ModelDeployService modelDeployService; private ModelDeployService modelDeployService;
@Autowired @Autowired
private IdGeneratorWrapper idGenerator; private IdGeneratorWrapper idGenerator;
@Autowired
private TuningRunService tuningRunService;
/** /**
* 返回当前Service的主表Mapper对象。 * 返回当前Service的主表Mapper对象。
...@@ -186,6 +190,7 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple ...@@ -186,6 +190,7 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
this.modelVersionService.saveNew(modelVersion); this.modelVersionService.saveNew(modelVersion);
return reModelManage; return reModelManage;
} }
@Override @Override
public ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion) { public ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion) {
ModelManage reModelManage = this.saveNew(modelManage); ModelManage reModelManage = this.saveNew(modelManage);
......
...@@ -6,14 +6,11 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; ...@@ -6,14 +6,11 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.Page; import com.github.pagehelper.Page;
import com.yice.common.core.base.dao.BaseDaoMapper; import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService; import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.CallResult; import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam; import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.ResponseResult;
import com.yice.common.core.util.MyModelUtil; import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper; import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelVersionMapper; import com.yice.webadmin.app.dao.ModelVersionMapper;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion; import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelManageService; import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService; import com.yice.webadmin.app.service.ModelTaskService;
...@@ -65,7 +62,7 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp ...@@ -65,7 +62,7 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
String modelName = this.modelManageService.getById(modelVersion.getModelId()).getModelName(); String modelName = this.modelManageService.getById(modelVersion.getModelId()).getModelName();
ModelVersion modelVersionFilter = new ModelVersion(); ModelVersion modelVersionFilter = new ModelVersion();
modelVersionFilter.setModelId(modelVersion.getModelId()); modelVersionFilter.setModelId(modelVersion.getModelId());
List<ModelVersion> modelVersionList = this.getModelVersionList(modelVersionFilter,"model_Version"); List<ModelVersion> modelVersionList = this.getModelVersionList(modelVersionFilter, "model_Version");
Integer version = 1; Integer version = 1;
if (modelVersionList != null && modelVersionList.size() != 0) { if (modelVersionList != null && modelVersionList.size() != 0) {
version = modelVersionList.get(modelVersionList.size() - 1).getModelVersion() + 1; version = modelVersionList.get(modelVersionList.size() - 1).getModelVersion() + 1;
......
package com.yice.webadmin.app.service.impl; package com.yice.webadmin.app.service.impl;
import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.service.ProxyPythonService; import com.yice.webadmin.app.service.ProxyPythonService;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.apache.http.HttpEntity; import org.apache.http.HttpEntity;
...@@ -12,7 +11,6 @@ import org.apache.http.entity.StringEntity; ...@@ -12,7 +11,6 @@ import org.apache.http.entity.StringEntity;
import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClients; import org.apache.http.impl.client.HttpClients;
import org.apache.http.util.EntityUtils; import org.apache.http.util.EntityUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import java.io.IOException; import java.io.IOException;
...@@ -51,6 +49,7 @@ public class ProxyPythonServiceImpl implements ProxyPythonService { ...@@ -51,6 +49,7 @@ public class ProxyPythonServiceImpl implements ProxyPythonService {
} }
return null; return null;
} }
public String predictPostForFile(String url, String requestBody) throws IOException { public String predictPostForFile(String url, String requestBody) throws IOException {
CloseableHttpClient httpClient = HttpClients.createDefault(); CloseableHttpClient httpClient = HttpClients.createDefault();
try { try {
...@@ -76,6 +75,7 @@ public class ProxyPythonServiceImpl implements ProxyPythonService { ...@@ -76,6 +75,7 @@ public class ProxyPythonServiceImpl implements ProxyPythonService {
} }
return null; return null;
} }
public String predictGet(String url, String requestBody) throws IOException { public String predictGet(String url, String requestBody) throws IOException {
CloseableHttpClient httpClient = HttpClients.createDefault(); CloseableHttpClient httpClient = HttpClients.createDefault();
try { try {
......
...@@ -186,7 +186,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -186,7 +186,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask(); ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, runPublishDto, modelVersionSource, modelTask); ModelVersion modelVersionTarget = saveAll(tuningRun, runPublishDto, modelVersionSource, modelTask);
messageWithSocket(tuningRun, modelVersionSource, pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(), modelVersionTarget, modelTask); messageWithSocket(tuningRun, modelVersionSource, modelVersionTarget, modelTask);
return true; return true;
} }
...@@ -197,7 +197,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -197,7 +197,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask(); ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, modelManage, modelVersion, modelVersionSource, modelTask); ModelVersion modelVersionTarget = saveAll(tuningRun, modelManage, modelVersion, modelVersionSource, modelTask);
messageWithSocket(tuningRun, modelVersionSource, pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(), modelVersionTarget, modelTask); messageWithSocket(tuningRun, modelVersionSource, modelVersionTarget, modelTask);
return modelManage; return modelManage;
} }
...@@ -208,13 +208,23 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -208,13 +208,23 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask(); ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, modelVersion, modelVersionSource, modelTask); ModelVersion modelVersionTarget = saveAll(tuningRun, modelVersion, modelVersionSource, modelTask);
messageWithSocket(tuningRun, modelVersionSource, pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(), modelVersionTarget, modelTask); messageWithSocket(tuningRun, modelVersionSource, modelVersionTarget, modelTask);
return modelVersionTarget; return modelVersionTarget;
} }
@Transactional(rollbackFor = Exception.class)
@Override
public boolean updateStatusById(TuningRun tuningRun) {
ModelTask modelTask = this.modelTaskService.getById(tuningRun.getTaskId());
modelTask.setTaskStatus(tuningRun.getRunStatus());
modelTask.setCompleteTime(new Date());
modelTaskService.updateById(modelTask);
return this.updateById(tuningRun);
}
@SneakyThrows @SneakyThrows
private void messageWithSocket(TuningRun tuningRun, ModelVersion modelVersionSource, String targetModelVersionURl, ModelVersion modelVersionTarget, ModelTask modelTask) { private void messageWithSocket(TuningRun tuningRun, ModelVersion modelVersionSource, ModelVersion modelVersionTarget, ModelTask modelTask) {
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri())) { new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri())) {
@Override @Override
public void onOpen(ServerHandshake serverHandshake) { public void onOpen(ServerHandshake serverHandshake) {
...@@ -246,7 +256,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -246,7 +256,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(tuningRun.getTrainMethod()); array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate")); array.add(jsonObject.get("promptTemplate"));
array.add(2); array.add(2);
array.add(targetModelVersionURl); array.add(pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName());
array.add("none"); array.add("none");
sendJson.put("data", array); sendJson.put("data", array);
sendJson.put("event_data", "null"); sendJson.put("event_data", "null");
...@@ -257,7 +267,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -257,7 +267,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
log.info("发送服务端的消息:" + sendJson.toJSONString()); log.info("发送服务端的消息:" + sendJson.toJSONString());
if (receiveMsg.equals("process_completed")) { if (receiveMsg.equals("process_completed")) {
System.out.println("isSuccess4:" + System.currentTimeMillis()); System.out.println("isSuccess4:" + System.currentTimeMillis());
updateStatus(receiveJson.getBoolean("success"), tuningRun, modelVersionTarget, targetModelVersionURl, modelTask); updateStatus(receiveJson.getBoolean("success"), tuningRun, modelVersionTarget, modelTask);
} }
} }
...@@ -336,9 +346,9 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -336,9 +346,9 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
modelTask.setTaskId(tuningRun.getTaskId()); modelTask.setTaskId(tuningRun.getTaskId());
} }
public Boolean updateStatus(Boolean flag, TuningRun tuningRun, ModelVersion modelVersionTarget, String targetModelVersionURl, ModelTask modelTask) { public Boolean updateStatus(Boolean flag, TuningRun tuningRun, ModelVersion modelVersionTarget, ModelTask modelTask) {
if (flag) { if (flag) {
modelVersionTarget.setModelUrl(targetModelVersionURl); modelVersionTarget.setModelUrl(pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName());
modelVersionTarget.setStatus(1); modelVersionTarget.setStatus(1);
modelTask.setTaskStatus(1); modelTask.setTaskStatus(1);
tuningRun.setPublishStatus(1); tuningRun.setPublishStatus(1);
......
...@@ -6,8 +6,6 @@ import io.swagger.annotations.ApiModelProperty; ...@@ -6,8 +6,6 @@ import io.swagger.annotations.ApiModelProperty;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import java.util.Date;
/** /**
* KnowledgeManageVO视图对象。 * KnowledgeManageVO视图对象。
* *
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment