Commit f45e3ad5 authored by linpeiqin's avatar linpeiqin

修改压缩和任务发布的逻辑,暂时验证通过了,后续应该要先存储再更新状态的方式,现在这种方式最终保存,再机械硬盘下时间过长

parent f54e1196
......@@ -9,10 +9,12 @@ import com.github.pagehelper.Page;
import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.TokenData;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.dao.ModelCompressMapper;
import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.ModelCompressService;
import com.yice.webadmin.app.service.ModelManageService;
......@@ -102,6 +104,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData){
String targetModelVersionURl;
Long taskId = idGenerator.nextLongId();
Long userID = TokenData.takeFromRequest().getUserId();
modelCompress.setTaskId(taskId);
ModelVersion sourceModelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId());
if (modelCompress.getCreateMethod() == 0) {
......@@ -128,9 +131,8 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44);
sendJson.put("session_hash", taskId);
log.info("发送服务端的消息:" + message);
System.out.println(sendJson.toJSONString());
sendJson.put("session_hash", String.valueOf(taskId));
log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString());
} else if (receiveMsg.equals("send_data")) {
JSONArray array = new JSONArray();
......@@ -146,35 +148,12 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
sendJson.put("data", array);
sendJson.put("event_data", "null");
sendJson.put("fn_index", 44);
sendJson.put("session_hash", taskId);
System.out.println(array.toJSONString());
sendJson.put("session_hash", String.valueOf(taskId));
log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString());
} else if (receiveMsg.equals("process_completed")) {
this.close();
if (receiveJson.getBoolean("success")){
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1);
modelVersionS.setModelUrl(targetModelVersionURl);
saveNew(modelCompress);
saveOrUpdateOneToOneRelationData(modelCompress, relationData);
ModelManage modelManageS = new ModelManage();
if (modelCompress.getCreateMethod() == 0) {
modelVersionS.setModelId(modelCompress.getTargetModelId());
modelManageS = modelManageService.getById(modelCompress.getTargetModelId());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionService.saveNew(modelVersionS);
} else {
modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
}
saveAll(modelCompress,relationData,targetModelVersionURl,userID);
}
}
}
......@@ -190,6 +169,32 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
return modelCompress;
}
@Transactional(rollbackFor = Exception.class)
public synchronized void saveAll(ModelCompress modelCompress, JSONObject relationData, String targetModelVersionURl, Long userID) {
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1);
modelVersionS.setModelUrl(targetModelVersionURl);
modelVersionS.setCreateUserId(userID);
modelVersionS.setUpdateUserId(userID);
modelCompress.setCreateUserId(userID);
modelCompress.setUpdateUserId(userID);
saveNew(modelCompress);
saveOrUpdateOneToOneRelationData(modelCompress, relationData,userID);
if (modelCompress.getCreateMethod() == 0) {
modelVersionS.setModelId(modelCompress.getTargetModelId());
modelVersionService.saveNew(modelVersionS);
} else {
ModelManage modelManageS = new ModelManage();
modelManageS.setCreateUserId(userID);
modelManageS.setUpdateUserId(userID);
modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
}
}
/**
* 更新数据对象。
*
......@@ -214,15 +219,17 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
if (modelCompress != null && !this.update(modelCompress, originalModelCompress)) {
return false;
}
this.saveOrUpdateOneToOneRelationData(originalModelCompress, relationData);
this.saveOrUpdateOneToOneRelationData(originalModelCompress, relationData,TokenData.takeFromRequest().getUserId());
return true;
}
private void saveOrUpdateOneToOneRelationData(ModelCompress modelCompress, JSONObject relationData) {
private void saveOrUpdateOneToOneRelationData(ModelCompress modelCompress, JSONObject relationData,Long userId) {
// 对于一对一新增或更新,如果主键值为空就新增,否则就更新,同时更新updateTime和updateUserId。
ModelTask modelTask = relationData.getObject("modelTask", ModelTask.class);
if (modelTask != null) {
modelTask.setTaskId(modelCompress.getTaskId());
modelTask.setCreateUserId(userId);
modelTask.setUpdateUserId(userId);
modelTaskService.saveNew(modelTask);
/*modelTaskService.saveNewOrUpdate(modelTask,
modelTaskService::saveNew, modelTaskService::update);*/
......
......@@ -76,13 +76,13 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
modelVersion.setVersionName(modelName + "_V" + modelVersion.getModelVersion());
modelVersionMapper.insert(this.buildDefaultValue(modelVersion));
//此处应该调用精调运行发布的方法生成模型任务,不应该直接生成!!!!!!!!!!!!!!!!!
ModelTask modelTask = new ModelTask();
/*ModelTask modelTask = new ModelTask();
modelTask.setModelVersion(modelVersion.getModelVersion());
modelTask.setModelId(modelVersion.getModelId());
modelTask.setTaskType(0);
modelTask.setVersionId(modelVersion.getVersionId());
modelTask.setVersionName(modelVersion.getVersionName());
this.modelTaskService.saveNew(modelTask);
this.modelTaskService.saveNew(modelTask);*/
return modelVersion;
}
......
......@@ -11,6 +11,7 @@ import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.TokenData;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.PythonConfig;
......@@ -27,7 +28,6 @@ import com.yice.webadmin.app.service.TuningTaskService;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
......@@ -35,7 +35,7 @@ import org.springframework.transaction.annotation.Transactional;
import java.net.URI;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.TimeUnit;
/**
* 精调任务运行数据操作服务类。
......@@ -61,9 +61,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Autowired
private PythonConfig pythonConfig;
private AtomicInteger isSuccess = new AtomicInteger(0);
/**
* 返回当前Service的主表Mapper对象。
*
......@@ -190,6 +187,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Override
public boolean publishToModelVersion(RunPublishDto runPublishDto) {
String targetModelVersionURl;
Long userID = TokenData.takeFromRequest().getUserId();
TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
if (runPublishDto.getPublishWay() == 0) {
......@@ -203,7 +201,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelManage modelManage = this.modelManageService.getById(runPublishDto.getModelId());
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion;
}
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri())) {
@Override
public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------");
......@@ -217,11 +215,13 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
String receiveMsg = receiveJson.getString("msg");
JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) {
System.out.println("isSuccess1:" + System.currentTimeMillis());
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("send_data")) {
System.out.println("isSuccess2:" + System.currentTimeMillis());
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelVersion.getVersionName());
......@@ -242,54 +242,50 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
}
log.info("发送服务端的消息:" + sendJson.toJSONString());
if (receiveMsg.equals("process_completed")) {
// this.close();
if (receiveJson.getBoolean("success")) {
isSuccess.set(1);
}
System.out.println("isSuccess4:" + System.currentTimeMillis());
saveAll(receiveJson.getBoolean("success"),tuningRun, targetModelVersionURl, runPublishDto, userID);
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
}
}.connect();
System.out.println("isSuccess.get():" + isSuccess.get());
if (isSuccess.get() == 1) {
saveAll(tuningRun, targetModelVersionURl, runPublishDto);
tuningRun.setPublishStatus(1);
} else {
tuningRun.setPublishStatus(-1);
}
return this.updateById(tuningRun);
}.connectBlocking(5000, TimeUnit.MILLISECONDS);
System.out.println("isSuccess5:" + System.currentTimeMillis());
return true;
}
@Transactional(rollbackFor = Exception.class)
public void saveAll(TuningRun tuningRun, String targetModelVersionURl, RunPublishDto runPublishDto) {
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
modelVersionS.setModelUrl(targetModelVersionURl);
ModelManage modelManageS = new ModelManage();
if (runPublishDto.getPublishWay() == 0) {
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
public synchronized Boolean saveAll(Boolean flag, TuningRun tuningRun, String targetModelVersionURl, RunPublishDto runPublishDto, Long userID) {
if (flag) {
tuningRun.setPublishStatus(1);
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
modelVersionS.setModelUrl(targetModelVersionURl);
modelVersionS.setCreateUserId(userID);
modelVersionS.setUpdateUserId(userID);
ModelManage modelManageS = new ModelManage();
modelManageS.setCreateUserId(userID);
modelManageS.setUpdateUserId(userID);
if (runPublishDto.getPublishWay() == 0) {
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} else {
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionService.saveNew(modelVersionS);
}
} else {
modelManageS = modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionService.saveNew(modelVersionS);
tuningRun.setPublishStatus(-1);
}
return updateById(tuningRun);
}
/**
......
......@@ -739,7 +739,8 @@ public class MyModelUtil {
*/
public static <M> void fillCommonsForInsert(M data) {
Field createdByField = ReflectUtil.getField(data.getClass(), CREATE_USER_ID_FIELD_NAME);
if (createdByField != null) {
Object createdByFieldValue = ReflectUtil.getFieldValue(data, createdByField);
if (createdByField != null && createdByFieldValue == null) {
ReflectUtil.setFieldValue(data, createdByField, TokenData.takeFromRequest().getUserId());
}
Field createTimeField = ReflectUtil.getField(data.getClass(), CREATE_TIME_FIELD_NAME);
......@@ -747,7 +748,8 @@ public class MyModelUtil {
ReflectUtil.setFieldValue(data, createTimeField, new Date());
}
Field updatedByField = ReflectUtil.getField(data.getClass(), UPDATE_USER_ID_FIELD_NAME);
if (updatedByField != null) {
Object updatedByFieldValue = ReflectUtil.getFieldValue(data, updatedByField);
if (updatedByField != null && updatedByFieldValue == null) {
ReflectUtil.setFieldValue(data, updatedByField, TokenData.takeFromRequest().getUserId());
}
Field updateTimeField = ReflectUtil.getField(data.getClass(), UPDATE_TIME_FIELD_NAME);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment