Commit 5ec4476b authored by linpeiqin's avatar linpeiqin

修改模型导出调用的核心逻辑

parent 7297ac91
......@@ -4,6 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody;
import com.yice.common.core.annotation.NoAuthInterface;
import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.*;
import com.yice.common.core.util.MyCommonUtil;
......@@ -25,6 +26,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
import java.util.Date;
import java.util.List;
/**
......@@ -77,6 +79,27 @@ public class ModelCompressController {
return ResponseResult.success(modelCompress.getTaskId());
}
@NoAuthInterface
@OperationLog(type = SysOperationLogType.UPLOAD, saveResponse = false)
@PostMapping("/postCompressStatus")
public ResponseResult<Void> postCompressStatus(@MyRequestBody Long taskId,
@MyRequestBody Long runTime,
@MyRequestBody Integer taskStatus) {
String errorMessage;
ModelCompress modelCompress = this.modelCompressService.getById(taskId);
if (modelCompress == null) {
errorMessage = "数据验证失败,当前模型压缩并不存在,请联系管理员!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
modelCompress.setTaskStatus(taskStatus);
modelCompress.setUpdateTime(new Date());
if (!modelCompressService.updateById(modelCompress)) {
errorMessage = "型压缩信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
return ResponseResult.success();
}
/**
* 修改模型压缩数据,及其关联的从表数据。
*
......
......@@ -35,6 +35,7 @@ import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
/**
......@@ -153,6 +154,7 @@ public class ModelEstimateController {
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
modelEstimate.setTaskStatus(taskStatus);
modelEstimate.setUpdateTime(new Date());
if (!modelEstimateService.updateById(modelEstimate)) {
errorMessage = "任务状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
......
......@@ -153,15 +153,8 @@ public class ModelVersionController {
*/
@GetMapping("/lastModelVersion")
public ResponseResult<ModelVersionVo> lastModelVersion(@RequestParam Long modelId) {
ModelVersion modelVersionFilter = new ModelVersion();
modelVersionFilter.setModelId(modelId);
List<ModelVersion> modelVersions = modelVersionService.getModelVersionList(modelVersionFilter, "model_version");
if (modelVersions == null) {
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST);
}
ModelVersion modelVersion = modelVersions.get(modelVersions.size() - 1);
ModelVersionVo modelVersionVo = ModelVersion.INSTANCE.fromModel(modelVersion);
return ResponseResult.success(modelVersionVo);
ModelVersion modelVersion = this.modelVersionService.lastModelVersion(modelId);
return ResponseResult.success(ModelVersion.INSTANCE.fromModel(modelVersion));
}
private ResponseResult<Void> doDelete(Long versionId) {
......
......@@ -33,6 +33,7 @@ import org.springframework.web.bind.annotation.*;
import java.io.File;
import java.io.IOException;
import java.util.Date;
import java.util.List;
/**
......@@ -281,6 +282,7 @@ public class TuningRunController {
}
tuningRun.setRunStatus(runStatus);
tuningRun.setRunTime(runTime);
tuningRun.setUpdateTime(new Date());
if (!tuningRunService.updateById(tuningRun)) {
errorMessage = "运行状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
......@@ -288,6 +290,32 @@ public class TuningRunController {
return ResponseResult.success();
}
/**
* 提交发布状态信息。
*
* @return 应答结果对象,包含查询结果集。
*/
@NoAuthInterface
@OperationLog(type = SysOperationLogType.UPLOAD, saveResponse = false)
@PostMapping("/postPublishStatus")
public ResponseResult<Void> postPublishStatus(@MyRequestBody Long runId,
@MyRequestBody Long runTime,
@MyRequestBody Integer publishStatus) {
String errorMessage;
TuningRun tuningRun = this.tuningRunService.getById(runId);
if (tuningRun == null) {
errorMessage = "数据验证失败,当前微调运行并不存在,请联系管理员!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
tuningRun.setPublishStatus(publishStatus);
tuningRun.setUpdateTime(new Date());
if (!tuningRunService.updateById(tuningRun)) {
errorMessage = "运行发布信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
return ResponseResult.success();
}
/**
* 查看指定精调任务运行对象详情。
*
......
......@@ -60,6 +60,12 @@ public class ModelCompressDto {
@ApiModelProperty(value = "新模型名称")
private String targetModelName;
/**
* 任务状态。
*/
@ApiModelProperty(value = "任务状态")
private Integer taskStatus;
/**
* task_name LIKE搜索字符串。
*/
......
......@@ -8,6 +8,7 @@ import com.yice.common.core.base.mapper.BaseModelMapper;
import com.yice.common.core.base.model.BaseModel;
import com.yice.common.core.util.MyCommonUtil;
import com.yice.webadmin.app.vo.ModelCompressVo;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.EqualsAndHashCode;
import org.mapstruct.Mapper;
......@@ -68,6 +69,10 @@ public class ModelCompress extends BaseModel {
@TableField(exist = false)
private String targetModelName;
/**
* 任务状态。
*/
private Integer taskStatus;
/**
* task_name LIKE搜索字符串。
*/
......
......@@ -38,7 +38,7 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
* @param relationData 全部关联从表数据。
* @return 返回新增主表对象。
*/
ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData) throws URISyntaxException;
ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData);
/**
* 更新数据对象。
......
......@@ -81,4 +81,6 @@ public interface ModelVersionService extends IBaseService<ModelVersion, Long> {
* @return 查询结果集。
*/
List<ModelVersion> getModelVersionListWithRelation(ModelVersion filter, String orderBy);
ModelVersion lastModelVersion(Long modelId);
}
......@@ -18,6 +18,7 @@ import com.yice.webadmin.app.service.ModelCompressService;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
......@@ -29,6 +30,8 @@ import org.springframework.transaction.annotation.Transactional;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* 模型压缩数据操作服务类。
......@@ -53,7 +56,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
@Autowired
private PythonConfig pythonConfig;
private String modelVersionURl;
/**
* 返回当前Service的主表Mapper对象。
......@@ -93,35 +96,24 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
}
@SneakyThrows
@Transactional(rollbackFor = Exception.class)
@Override
public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData) throws URISyntaxException {
this.saveNew(modelCompress);
this.saveOrUpdateOneToOneRelationData(modelCompress, relationData);
ModelVersion modelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId());
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1);
public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData){
String targetModelVersionURl;
Long taskId = idGenerator.nextLongId();
modelCompress.setTaskId(taskId);
ModelVersion sourceModelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId());
if (modelCompress.getCreateMethod() == 0) {
modelVersionS.setModelId(modelCompress.getTargetModelId());
ModelManage modelManageS = this.modelManageService.getById(modelCompress.getTargetModelId());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
this.modelManageService.updateById(modelManageS);
ModelVersion modelVersionR = this.modelVersionService.saveNew(modelVersionS);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelVersionR.getVersionName();
modelVersionR.setModelUrl(modelVersionURl);
this.modelVersionService.updateById(modelVersionR);
ModelVersion lastModelVersion = this.modelVersionService.lastModelVersion(modelCompress.getTargetModelId());
Integer newVersion = 1;
if (lastModelVersion != null) {
newVersion = lastModelVersion.getModelVersion() + 1;
}
ModelManage modelManage = this.modelManageService.getById(modelCompress.getTargetModelId());
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion;
} else {
ModelManage modelManageS = new ModelManage();
modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelCompress.getTargetModelName() + "_V1";
modelVersionS.setModelUrl(modelVersionURl);
this.modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelCompress.getTargetModelName() + "_V1";
}
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override
......@@ -136,40 +128,60 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44);
sendJson.put("session_hash", modelCompress.getTaskId().toString());
sendJson.put("session_hash", taskId);
log.info("发送服务端的消息:" + message);
System.out.println(sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("send_data")) {
} else if (receiveMsg.equals("send_data")) {
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelVersion.getVersionName());
array.add(modelVersion.getModelUrl());
array.add(sourceModelVersion.getVersionName());
array.add(sourceModelVersion.getModelUrl());
array.add(new JSONArray());
array.add("");
array.add("");
array.add("lora");
array.add("chatglm3");
array.add(2);
array.add(modelVersionURl);
array.add(targetModelVersionURl);
array.add("8");
sendJson.put("data", array);
sendJson.put("event_data", "null");
sendJson.put("fn_index", 44);
sendJson.put("session_hash", modelCompress.getTaskId().toString());
sendJson.put("session_hash", taskId);
System.out.println(array.toJSONString());
log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("process_completed")) {
} else if (receiveMsg.equals("process_completed")) {
if (receiveJson.getBoolean("success") == true){
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1);
modelVersionS.setModelUrl(targetModelVersionURl);
saveNew(modelCompress);
saveOrUpdateOneToOneRelationData(modelCompress, relationData);
ModelManage modelManageS = new ModelManage();
if (modelCompress.getCreateMethod() == 0) {
modelVersionS.setModelId(modelCompress.getTargetModelId());
modelManageS = modelManageService.getById(modelCompress.getTargetModelId());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionService.saveNew(modelVersionS);
} else {
modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
}
}
this.close();
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
......
......@@ -6,8 +6,10 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.Page;
import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.ResponseResult;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelVersionMapper;
......@@ -179,6 +181,14 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
return resultList;
}
@Override
public ModelVersion lastModelVersion(Long modelId) {
ModelVersion modelVersionFilter = new ModelVersion();
modelVersionFilter.setModelId(modelId);
List<ModelVersion> modelVersions = this.getModelVersionList(modelVersionFilter, "model_version");
return modelVersions.get(modelVersions.size() - 1);
}
/**
* 根据最新对象和原有对象的数据对比,判断关联的字典数据和多对一主表数据是否都是合法数据。
*
......
......@@ -60,7 +60,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
private IdGeneratorWrapper idGenerator;
@Autowired
private PythonConfig pythonConfig;
private String modelVersionURl;
/**
* 返回当前Service的主表Mapper对象。
......@@ -187,36 +187,25 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Transactional(rollbackFor = Exception.class)
@Override
public boolean publishToModelVersion(RunPublishDto runPublishDto) {
String targetModelVersionURl;
TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
if (runPublishDto.getPublishWay() == 0) {
ModelManage modelManageS = new ModelManage();
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
modelManageS.setIsBaseModel(0);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + runPublishDto.getModelName() + "_V1";
modelVersionS.setModelUrl(modelVersionURl);
this.modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + runPublishDto.getModelName() + "_V1";
} else {
modelVersionS.setModelId(runPublishDto.getModelId());
ModelManage modelManageS = this.modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
this.modelManageService.updateById(modelManageS);
ModelVersion modelVersionR = this.modelVersionService.saveNew(modelVersionS);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelVersionR.getVersionName();
modelVersionR.setModelUrl(modelVersionURl);
this.modelVersionService.updateById(modelVersionR);
ModelVersion lastModelVersion = this.modelVersionService.lastModelVersion(runPublishDto.getModelId());
Integer newVersion = 1;
if (lastModelVersion != null) {
newVersion = lastModelVersion.getModelVersion() + 1;
}
ModelManage modelManage = this.modelManageService.getById(runPublishDto.getModelId());
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion;
}
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override
public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------");
}
@Override
public void onMessage(String message) {
log.info("收到来自服务端的消息:" + message);
......@@ -227,8 +216,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
log.info("发送服务端的消息:" + message);
System.out.println(sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("send_data")) {
......@@ -242,32 +229,52 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate"));
array.add(2);
array.add(modelVersionURl);
array.add(targetModelVersionURl);
array.add("none");
sendJson.put("data", array);
sendJson.put("event_data", "null");
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
System.out.println(array.toJSONString());
log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString());
}
log.info("发送服务端的消息:" + sendJson.toJSONString());
if (receiveMsg.equals("process_completed")) {
if (receiveJson.getBoolean("success") == true){
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
modelVersionS.setModelUrl(targetModelVersionURl);
ModelManage modelManageS = new ModelManage();
if (runPublishDto.getPublishWay() == 0) {
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} else {
modelManageS = modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionService.saveNew(modelVersionS);
}
tuningRun.setPublishStatus(1);
} else {
tuningRun.setPublishStatus(-1);
}
this.close();
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
}
}.connect();
tuningRun.setPublishStatus(1);
return this.updateById(tuningRun);
}
......
......@@ -55,6 +55,13 @@ public class ModelCompressVo extends BaseVo {
@ApiModelProperty(value = "目标模型")
private Long targetVersionId;
/**
* 任务状态。
*/
@ApiModelProperty(value = "任务状态")
private Integer taskStatus;
/**
* taskId 的一对一关联数据对象,数据对应类型为ModelTaskVo。
*/
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment