Commit 5ec4476b authored by linpeiqin's avatar linpeiqin

修改模型导出调用的核心逻辑

parent 7297ac91
...@@ -4,6 +4,7 @@ import com.alibaba.fastjson.JSONObject; ...@@ -4,6 +4,7 @@ import com.alibaba.fastjson.JSONObject;
import com.github.pagehelper.page.PageMethod; import com.github.pagehelper.page.PageMethod;
import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport; import com.github.xiaoymin.knife4j.annotations.ApiOperationSupport;
import com.yice.common.core.annotation.MyRequestBody; import com.yice.common.core.annotation.MyRequestBody;
import com.yice.common.core.annotation.NoAuthInterface;
import com.yice.common.core.constant.ErrorCodeEnum; import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.*; import com.yice.common.core.object.*;
import com.yice.common.core.util.MyCommonUtil; import com.yice.common.core.util.MyCommonUtil;
...@@ -25,6 +26,7 @@ import lombok.extern.slf4j.Slf4j; ...@@ -25,6 +26,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*; import org.springframework.web.bind.annotation.*;
import java.util.Date;
import java.util.List; import java.util.List;
/** /**
...@@ -77,6 +79,27 @@ public class ModelCompressController { ...@@ -77,6 +79,27 @@ public class ModelCompressController {
return ResponseResult.success(modelCompress.getTaskId()); return ResponseResult.success(modelCompress.getTaskId());
} }
@NoAuthInterface
@OperationLog(type = SysOperationLogType.UPLOAD, saveResponse = false)
@PostMapping("/postCompressStatus")
public ResponseResult<Void> postCompressStatus(@MyRequestBody Long taskId,
@MyRequestBody Long runTime,
@MyRequestBody Integer taskStatus) {
String errorMessage;
ModelCompress modelCompress = this.modelCompressService.getById(taskId);
if (modelCompress == null) {
errorMessage = "数据验证失败,当前模型压缩并不存在,请联系管理员!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
modelCompress.setTaskStatus(taskStatus);
modelCompress.setUpdateTime(new Date());
if (!modelCompressService.updateById(modelCompress)) {
errorMessage = "型压缩信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
return ResponseResult.success();
}
/** /**
* 修改模型压缩数据,及其关联的从表数据。 * 修改模型压缩数据,及其关联的从表数据。
* *
......
...@@ -35,6 +35,7 @@ import java.io.File; ...@@ -35,6 +35,7 @@ import java.io.File;
import java.io.FileReader; import java.io.FileReader;
import java.io.IOException; import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Date;
import java.util.List; import java.util.List;
/** /**
...@@ -153,6 +154,7 @@ public class ModelEstimateController { ...@@ -153,6 +154,7 @@ public class ModelEstimateController {
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage); return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
} }
modelEstimate.setTaskStatus(taskStatus); modelEstimate.setTaskStatus(taskStatus);
modelEstimate.setUpdateTime(new Date());
if (!modelEstimateService.updateById(modelEstimate)) { if (!modelEstimateService.updateById(modelEstimate)) {
errorMessage = "任务状态信息提交错误!"; errorMessage = "任务状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage); return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
......
...@@ -153,15 +153,8 @@ public class ModelVersionController { ...@@ -153,15 +153,8 @@ public class ModelVersionController {
*/ */
@GetMapping("/lastModelVersion") @GetMapping("/lastModelVersion")
public ResponseResult<ModelVersionVo> lastModelVersion(@RequestParam Long modelId) { public ResponseResult<ModelVersionVo> lastModelVersion(@RequestParam Long modelId) {
ModelVersion modelVersionFilter = new ModelVersion(); ModelVersion modelVersion = this.modelVersionService.lastModelVersion(modelId);
modelVersionFilter.setModelId(modelId); return ResponseResult.success(ModelVersion.INSTANCE.fromModel(modelVersion));
List<ModelVersion> modelVersions = modelVersionService.getModelVersionList(modelVersionFilter, "model_version");
if (modelVersions == null) {
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST);
}
ModelVersion modelVersion = modelVersions.get(modelVersions.size() - 1);
ModelVersionVo modelVersionVo = ModelVersion.INSTANCE.fromModel(modelVersion);
return ResponseResult.success(modelVersionVo);
} }
private ResponseResult<Void> doDelete(Long versionId) { private ResponseResult<Void> doDelete(Long versionId) {
......
...@@ -33,6 +33,7 @@ import org.springframework.web.bind.annotation.*; ...@@ -33,6 +33,7 @@ import org.springframework.web.bind.annotation.*;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Date;
import java.util.List; import java.util.List;
/** /**
...@@ -281,6 +282,7 @@ public class TuningRunController { ...@@ -281,6 +282,7 @@ public class TuningRunController {
} }
tuningRun.setRunStatus(runStatus); tuningRun.setRunStatus(runStatus);
tuningRun.setRunTime(runTime); tuningRun.setRunTime(runTime);
tuningRun.setUpdateTime(new Date());
if (!tuningRunService.updateById(tuningRun)) { if (!tuningRunService.updateById(tuningRun)) {
errorMessage = "运行状态信息提交错误!"; errorMessage = "运行状态信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage); return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
...@@ -288,6 +290,32 @@ public class TuningRunController { ...@@ -288,6 +290,32 @@ public class TuningRunController {
return ResponseResult.success(); return ResponseResult.success();
} }
/**
* 提交发布状态信息。
*
* @return 应答结果对象,包含查询结果集。
*/
@NoAuthInterface
@OperationLog(type = SysOperationLogType.UPLOAD, saveResponse = false)
@PostMapping("/postPublishStatus")
public ResponseResult<Void> postPublishStatus(@MyRequestBody Long runId,
@MyRequestBody Long runTime,
@MyRequestBody Integer publishStatus) {
String errorMessage;
TuningRun tuningRun = this.tuningRunService.getById(runId);
if (tuningRun == null) {
errorMessage = "数据验证失败,当前微调运行并不存在,请联系管理员!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
tuningRun.setPublishStatus(publishStatus);
tuningRun.setUpdateTime(new Date());
if (!tuningRunService.updateById(tuningRun)) {
errorMessage = "运行发布信息提交错误!";
return ResponseResult.error(ErrorCodeEnum.DATA_NOT_EXIST, errorMessage);
}
return ResponseResult.success();
}
/** /**
* 查看指定精调任务运行对象详情。 * 查看指定精调任务运行对象详情。
* *
......
...@@ -60,6 +60,12 @@ public class ModelCompressDto { ...@@ -60,6 +60,12 @@ public class ModelCompressDto {
@ApiModelProperty(value = "新模型名称") @ApiModelProperty(value = "新模型名称")
private String targetModelName; private String targetModelName;
/**
* 任务状态。
*/
@ApiModelProperty(value = "任务状态")
private Integer taskStatus;
/** /**
* task_name LIKE搜索字符串。 * task_name LIKE搜索字符串。
*/ */
......
...@@ -8,6 +8,7 @@ import com.yice.common.core.base.mapper.BaseModelMapper; ...@@ -8,6 +8,7 @@ import com.yice.common.core.base.mapper.BaseModelMapper;
import com.yice.common.core.base.model.BaseModel; import com.yice.common.core.base.model.BaseModel;
import com.yice.common.core.util.MyCommonUtil; import com.yice.common.core.util.MyCommonUtil;
import com.yice.webadmin.app.vo.ModelCompressVo; import com.yice.webadmin.app.vo.ModelCompressVo;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data; import lombok.Data;
import lombok.EqualsAndHashCode; import lombok.EqualsAndHashCode;
import org.mapstruct.Mapper; import org.mapstruct.Mapper;
...@@ -68,6 +69,10 @@ public class ModelCompress extends BaseModel { ...@@ -68,6 +69,10 @@ public class ModelCompress extends BaseModel {
@TableField(exist = false) @TableField(exist = false)
private String targetModelName; private String targetModelName;
/**
* 任务状态。
*/
private Integer taskStatus;
/** /**
* task_name LIKE搜索字符串。 * task_name LIKE搜索字符串。
*/ */
......
...@@ -38,7 +38,7 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long> ...@@ -38,7 +38,7 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
* @param relationData 全部关联从表数据。 * @param relationData 全部关联从表数据。
* @return 返回新增主表对象。 * @return 返回新增主表对象。
*/ */
ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData) throws URISyntaxException; ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData);
/** /**
* 更新数据对象。 * 更新数据对象。
......
...@@ -81,4 +81,6 @@ public interface ModelVersionService extends IBaseService<ModelVersion, Long> { ...@@ -81,4 +81,6 @@ public interface ModelVersionService extends IBaseService<ModelVersion, Long> {
* @return 查询结果集。 * @return 查询结果集。
*/ */
List<ModelVersion> getModelVersionListWithRelation(ModelVersion filter, String orderBy); List<ModelVersion> getModelVersionListWithRelation(ModelVersion filter, String orderBy);
ModelVersion lastModelVersion(Long modelId);
} }
...@@ -18,6 +18,7 @@ import com.yice.webadmin.app.service.ModelCompressService; ...@@ -18,6 +18,7 @@ import com.yice.webadmin.app.service.ModelCompressService;
import com.yice.webadmin.app.service.ModelManageService; import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService; import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService; import com.yice.webadmin.app.service.ModelVersionService;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient; import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455; import org.java_websocket.drafts.Draft_6455;
...@@ -29,6 +30,8 @@ import org.springframework.transaction.annotation.Transactional; ...@@ -29,6 +30,8 @@ import org.springframework.transaction.annotation.Transactional;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/** /**
* 模型压缩数据操作服务类。 * 模型压缩数据操作服务类。
...@@ -53,7 +56,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -53,7 +56,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
@Autowired @Autowired
private PythonConfig pythonConfig; private PythonConfig pythonConfig;
private String modelVersionURl;
/** /**
* 返回当前Service的主表Mapper对象。 * 返回当前Service的主表Mapper对象。
...@@ -93,35 +96,24 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -93,35 +96,24 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
} }
@SneakyThrows
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
@Override @Override
public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData) throws URISyntaxException { public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData){
this.saveNew(modelCompress); String targetModelVersionURl;
this.saveOrUpdateOneToOneRelationData(modelCompress, relationData); Long taskId = idGenerator.nextLongId();
ModelVersion modelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId()); modelCompress.setTaskId(taskId);
ModelVersion modelVersionS = new ModelVersion(); ModelVersion sourceModelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId());
modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1);
if (modelCompress.getCreateMethod() == 0) { if (modelCompress.getCreateMethod() == 0) {
modelVersionS.setModelId(modelCompress.getTargetModelId()); ModelVersion lastModelVersion = this.modelVersionService.lastModelVersion(modelCompress.getTargetModelId());
ModelManage modelManageS = this.modelManageService.getById(modelCompress.getTargetModelId()); Integer newVersion = 1;
modelManageS.setModelDescribe(modelCompress.getTaskDescribe()); if (lastModelVersion != null) {
modelManageS.setModelType(0); newVersion = lastModelVersion.getModelVersion() + 1;
modelManageS.setIsBaseModel(0); }
this.modelManageService.updateById(modelManageS); ModelManage modelManage = this.modelManageService.getById(modelCompress.getTargetModelId());
ModelVersion modelVersionR = this.modelVersionService.saveNew(modelVersionS); targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion;
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelVersionR.getVersionName();
modelVersionR.setModelUrl(modelVersionURl);
this.modelVersionService.updateById(modelVersionR);
} else { } else {
ModelManage modelManageS = new ModelManage(); targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelCompress.getTargetModelName() + "_V1";
modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelCompress.getTargetModelName() + "_V1";
modelVersionS.setModelUrl(modelVersionURl);
this.modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} }
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) { new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override @Override
...@@ -136,40 +128,60 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -136,40 +128,60 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
JSONObject sendJson = new JSONObject(); JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) { if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", modelCompress.getTaskId().toString()); sendJson.put("session_hash", taskId);
log.info("发送服务端的消息:" + message); log.info("发送服务端的消息:" + message);
System.out.println(sendJson.toJSONString()); System.out.println(sendJson.toJSONString());
this.send(sendJson.toJSONString()); this.send(sendJson.toJSONString());
} } else if (receiveMsg.equals("send_data")) {
if (receiveMsg.equals("send_data")) {
JSONArray array = new JSONArray(); JSONArray array = new JSONArray();
array.add("zh"); array.add("zh");
array.add(modelVersion.getVersionName()); array.add(sourceModelVersion.getVersionName());
array.add(modelVersion.getModelUrl()); array.add(sourceModelVersion.getModelUrl());
array.add(new JSONArray()); array.add(new JSONArray());
array.add(""); array.add("lora");
array.add(""); array.add("chatglm3");
array.add(2); array.add(2);
array.add(modelVersionURl); array.add(targetModelVersionURl);
array.add("8"); array.add("8");
sendJson.put("data", array); sendJson.put("data", array);
sendJson.put("event_data", "null"); sendJson.put("event_data", "null");
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", modelCompress.getTaskId().toString()); sendJson.put("session_hash", taskId);
System.out.println(array.toJSONString()); System.out.println(array.toJSONString());
log.info("发送服务端的消息:" + sendJson.toJSONString()); log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString()); this.send(sendJson.toJSONString());
} } else if (receiveMsg.equals("process_completed")) {
if (receiveMsg.equals("process_completed")) { if (receiveJson.getBoolean("success") == true){
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1);
modelVersionS.setModelUrl(targetModelVersionURl);
saveNew(modelCompress);
saveOrUpdateOneToOneRelationData(modelCompress, relationData);
ModelManage modelManageS = new ModelManage();
if (modelCompress.getCreateMethod() == 0) {
modelVersionS.setModelId(modelCompress.getTargetModelId());
modelManageS = modelManageService.getById(modelCompress.getTargetModelId());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionService.saveNew(modelVersionS);
} else {
modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
}
}
this.close(); this.close();
} }
} }
@Override @Override
public void onClose(int i, String s, boolean b) { public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b); log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
} }
@Override @Override
public void onError(Exception e) { public void onError(Exception e) {
log.error("报错了:::" + e.getMessage()); log.error("报错了:::" + e.getMessage());
......
...@@ -6,8 +6,10 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper; ...@@ -6,8 +6,10 @@ import com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper;
import com.github.pagehelper.Page; import com.github.pagehelper.Page;
import com.yice.common.core.base.dao.BaseDaoMapper; import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService; import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.constant.ErrorCodeEnum;
import com.yice.common.core.object.CallResult; import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam; import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.ResponseResult;
import com.yice.common.core.util.MyModelUtil; import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper; import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelVersionMapper; import com.yice.webadmin.app.dao.ModelVersionMapper;
...@@ -179,6 +181,14 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp ...@@ -179,6 +181,14 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
return resultList; return resultList;
} }
@Override
public ModelVersion lastModelVersion(Long modelId) {
ModelVersion modelVersionFilter = new ModelVersion();
modelVersionFilter.setModelId(modelId);
List<ModelVersion> modelVersions = this.getModelVersionList(modelVersionFilter, "model_version");
return modelVersions.get(modelVersions.size() - 1);
}
/** /**
* 根据最新对象和原有对象的数据对比,判断关联的字典数据和多对一主表数据是否都是合法数据。 * 根据最新对象和原有对象的数据对比,判断关联的字典数据和多对一主表数据是否都是合法数据。
* *
......
...@@ -60,7 +60,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -60,7 +60,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
private IdGeneratorWrapper idGenerator; private IdGeneratorWrapper idGenerator;
@Autowired @Autowired
private PythonConfig pythonConfig; private PythonConfig pythonConfig;
private String modelVersionURl;
/** /**
* 返回当前Service的主表Mapper对象。 * 返回当前Service的主表Mapper对象。
...@@ -187,36 +187,25 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -187,36 +187,25 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
@Override @Override
public boolean publishToModelVersion(RunPublishDto runPublishDto) { public boolean publishToModelVersion(RunPublishDto runPublishDto) {
String targetModelVersionURl;
TuningRun tuningRun = this.getById(runPublishDto.getRunId()); TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
if (runPublishDto.getPublishWay() == 0) { if (runPublishDto.getPublishWay() == 0) {
ModelManage modelManageS = new ModelManage(); targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + runPublishDto.getModelName() + "_V1";
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
modelManageS.setIsBaseModel(0);
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + runPublishDto.getModelName() + "_V1";
modelVersionS.setModelUrl(modelVersionURl);
this.modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} else { } else {
modelVersionS.setModelId(runPublishDto.getModelId()); ModelVersion lastModelVersion = this.modelVersionService.lastModelVersion(runPublishDto.getModelId());
ModelManage modelManageS = this.modelManageService.getById(runPublishDto.getModelId()); Integer newVersion = 1;
modelManageS.setModelDescribe(runPublishDto.getModelDescribe()); if (lastModelVersion != null) {
this.modelManageService.updateById(modelManageS); newVersion = lastModelVersion.getModelVersion() + 1;
ModelVersion modelVersionR = this.modelVersionService.saveNew(modelVersionS); }
modelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelVersionR.getVersionName(); ModelManage modelManage = this.modelManageService.getById(runPublishDto.getModelId());
modelVersionR.setModelUrl(modelVersionURl); targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion;
this.modelVersionService.updateById(modelVersionR);
} }
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) { new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override @Override
public void onOpen(ServerHandshake serverHandshake) { public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------"); log.info("-------------与大模型建立连接-------------");
} }
@Override @Override
public void onMessage(String message) { public void onMessage(String message) {
log.info("收到来自服务端的消息:" + message); log.info("收到来自服务端的消息:" + message);
...@@ -227,8 +216,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -227,8 +216,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
if (receiveMsg.equals("send_hash")) { if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString()); sendJson.put("session_hash", runPublishDto.getRunId().toString());
log.info("发送服务端的消息:" + message);
System.out.println(sendJson.toJSONString());
this.send(sendJson.toJSONString()); this.send(sendJson.toJSONString());
} }
if (receiveMsg.equals("send_data")) { if (receiveMsg.equals("send_data")) {
...@@ -242,32 +229,52 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -242,32 +229,52 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(tuningRun.getTrainMethod()); array.add(tuningRun.getTrainMethod());
array.add(jsonObject.get("promptTemplate")); array.add(jsonObject.get("promptTemplate"));
array.add(2); array.add(2);
array.add(modelVersionURl); array.add(targetModelVersionURl);
array.add("none"); array.add("none");
sendJson.put("data", array); sendJson.put("data", array);
sendJson.put("event_data", "null"); sendJson.put("event_data", "null");
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString()); sendJson.put("session_hash", runPublishDto.getRunId().toString());
System.out.println(array.toJSONString());
log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString()); this.send(sendJson.toJSONString());
} }
log.info("发送服务端的消息:" + sendJson.toJSONString());
if (receiveMsg.equals("process_completed")) { if (receiveMsg.equals("process_completed")) {
if (receiveJson.getBoolean("success") == true){
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
modelVersionS.setModelUrl(targetModelVersionURl);
ModelManage modelManageS = new ModelManage();
if (runPublishDto.getPublishWay() == 0) {
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} else {
modelManageS = modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionService.saveNew(modelVersionS);
}
tuningRun.setPublishStatus(1);
} else {
tuningRun.setPublishStatus(-1);
}
this.close(); this.close();
} }
} }
@Override @Override
public void onClose(int i, String s, boolean b) { public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b); log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
} }
@Override @Override
public void onError(Exception e) { public void onError(Exception e) {
log.error("报错了:::" + e.getMessage()); log.error("报错了:::" + e.getMessage());
} }
}.connect(); }.connect();
tuningRun.setPublishStatus(1);
return this.updateById(tuningRun); return this.updateById(tuningRun);
} }
......
...@@ -55,6 +55,13 @@ public class ModelCompressVo extends BaseVo { ...@@ -55,6 +55,13 @@ public class ModelCompressVo extends BaseVo {
@ApiModelProperty(value = "目标模型") @ApiModelProperty(value = "目标模型")
private Long targetVersionId; private Long targetVersionId;
/**
* 任务状态。
*/
@ApiModelProperty(value = "任务状态")
private Integer taskStatus;
/** /**
* taskId 的一对一关联数据对象,数据对应类型为ModelTaskVo。 * taskId 的一对一关联数据对象,数据对应类型为ModelTaskVo。
*/ */
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment