Commit f45e3ad5 authored by linpeiqin's avatar linpeiqin

修改压缩和任务发布的逻辑,暂时验证通过了,后续应该要先存储再更新状态的方式,现在这种方式最终保存,再机械硬盘下时间过长

parent f54e1196
...@@ -9,10 +9,12 @@ import com.github.pagehelper.Page; ...@@ -9,10 +9,12 @@ import com.github.pagehelper.Page;
import com.yice.common.core.base.dao.BaseDaoMapper; import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService; import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.MyRelationParam; import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.TokenData;
import com.yice.common.core.util.MyModelUtil; import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper; import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.PythonConfig; import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.dao.ModelCompressMapper; import com.yice.webadmin.app.dao.ModelCompressMapper;
import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.model.*; import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.ModelCompressService; import com.yice.webadmin.app.service.ModelCompressService;
import com.yice.webadmin.app.service.ModelManageService; import com.yice.webadmin.app.service.ModelManageService;
...@@ -102,6 +104,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -102,6 +104,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData){ public ModelCompress saveNewWithRelation(ModelCompress modelCompress, JSONObject relationData){
String targetModelVersionURl; String targetModelVersionURl;
Long taskId = idGenerator.nextLongId(); Long taskId = idGenerator.nextLongId();
Long userID = TokenData.takeFromRequest().getUserId();
modelCompress.setTaskId(taskId); modelCompress.setTaskId(taskId);
ModelVersion sourceModelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId()); ModelVersion sourceModelVersion = this.modelVersionService.getById(modelCompress.getSourceVersionId());
if (modelCompress.getCreateMethod() == 0) { if (modelCompress.getCreateMethod() == 0) {
...@@ -128,9 +131,8 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -128,9 +131,8 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
JSONObject sendJson = new JSONObject(); JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) { if (receiveMsg.equals("send_hash")) {
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", taskId); sendJson.put("session_hash", String.valueOf(taskId));
log.info("发送服务端的消息:" + message); log.info("发送服务端的消息:" + sendJson.toJSONString());
System.out.println(sendJson.toJSONString());
this.send(sendJson.toJSONString()); this.send(sendJson.toJSONString());
} else if (receiveMsg.equals("send_data")) { } else if (receiveMsg.equals("send_data")) {
JSONArray array = new JSONArray(); JSONArray array = new JSONArray();
...@@ -146,29 +148,46 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -146,29 +148,46 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
sendJson.put("data", array); sendJson.put("data", array);
sendJson.put("event_data", "null"); sendJson.put("event_data", "null");
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", taskId); sendJson.put("session_hash", String.valueOf(taskId));
System.out.println(array.toJSONString());
log.info("发送服务端的消息:" + sendJson.toJSONString()); log.info("发送服务端的消息:" + sendJson.toJSONString());
this.send(sendJson.toJSONString()); this.send(sendJson.toJSONString());
} else if (receiveMsg.equals("process_completed")) { } else if (receiveMsg.equals("process_completed")) {
this.close();
if (receiveJson.getBoolean("success")){ if (receiveJson.getBoolean("success")){
saveAll(modelCompress,relationData,targetModelVersionURl,userID);
}
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
}
}.connect();
return modelCompress;
}
@Transactional(rollbackFor = Exception.class)
public synchronized void saveAll(ModelCompress modelCompress, JSONObject relationData, String targetModelVersionURl, Long userID) {
ModelVersion modelVersionS = new ModelVersion(); ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setTaskId(modelCompress.getTaskId()); modelVersionS.setTaskId(modelCompress.getTaskId());
modelVersionS.setIsCompress(1); modelVersionS.setIsCompress(1);
modelVersionS.setModelUrl(targetModelVersionURl); modelVersionS.setModelUrl(targetModelVersionURl);
modelVersionS.setCreateUserId(userID);
modelVersionS.setUpdateUserId(userID);
modelCompress.setCreateUserId(userID);
modelCompress.setUpdateUserId(userID);
saveNew(modelCompress); saveNew(modelCompress);
saveOrUpdateOneToOneRelationData(modelCompress, relationData); saveOrUpdateOneToOneRelationData(modelCompress, relationData,userID);
ModelManage modelManageS = new ModelManage();
if (modelCompress.getCreateMethod() == 0) { if (modelCompress.getCreateMethod() == 0) {
modelVersionS.setModelId(modelCompress.getTargetModelId()); modelVersionS.setModelId(modelCompress.getTargetModelId());
modelManageS = modelManageService.getById(modelCompress.getTargetModelId());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionService.saveNew(modelVersionS); modelVersionService.saveNew(modelVersionS);
} else { } else {
ModelManage modelManageS = new ModelManage();
modelManageS.setCreateUserId(userID);
modelManageS.setUpdateUserId(userID);
modelManageS.setModelName(modelCompress.getTargetModelName()); modelManageS.setModelName(modelCompress.getTargetModelName());
modelManageS.setModelDescribe(modelCompress.getTaskDescribe()); modelManageS.setModelDescribe(modelCompress.getTaskDescribe());
modelManageS.setModelType(0); modelManageS.setModelType(0);
...@@ -176,20 +195,6 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -176,20 +195,6 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS); modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} }
} }
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
}
}.connect();
return modelCompress;
}
/** /**
* 更新数据对象。 * 更新数据对象。
* *
...@@ -214,15 +219,17 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i ...@@ -214,15 +219,17 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
if (modelCompress != null && !this.update(modelCompress, originalModelCompress)) { if (modelCompress != null && !this.update(modelCompress, originalModelCompress)) {
return false; return false;
} }
this.saveOrUpdateOneToOneRelationData(originalModelCompress, relationData); this.saveOrUpdateOneToOneRelationData(originalModelCompress, relationData,TokenData.takeFromRequest().getUserId());
return true; return true;
} }
private void saveOrUpdateOneToOneRelationData(ModelCompress modelCompress, JSONObject relationData) { private void saveOrUpdateOneToOneRelationData(ModelCompress modelCompress, JSONObject relationData,Long userId) {
// 对于一对一新增或更新,如果主键值为空就新增,否则就更新,同时更新updateTime和updateUserId。 // 对于一对一新增或更新,如果主键值为空就新增,否则就更新,同时更新updateTime和updateUserId。
ModelTask modelTask = relationData.getObject("modelTask", ModelTask.class); ModelTask modelTask = relationData.getObject("modelTask", ModelTask.class);
if (modelTask != null) { if (modelTask != null) {
modelTask.setTaskId(modelCompress.getTaskId()); modelTask.setTaskId(modelCompress.getTaskId());
modelTask.setCreateUserId(userId);
modelTask.setUpdateUserId(userId);
modelTaskService.saveNew(modelTask); modelTaskService.saveNew(modelTask);
/*modelTaskService.saveNewOrUpdate(modelTask, /*modelTaskService.saveNewOrUpdate(modelTask,
modelTaskService::saveNew, modelTaskService::update);*/ modelTaskService::saveNew, modelTaskService::update);*/
......
...@@ -76,13 +76,13 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp ...@@ -76,13 +76,13 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
modelVersion.setVersionName(modelName + "_V" + modelVersion.getModelVersion()); modelVersion.setVersionName(modelName + "_V" + modelVersion.getModelVersion());
modelVersionMapper.insert(this.buildDefaultValue(modelVersion)); modelVersionMapper.insert(this.buildDefaultValue(modelVersion));
//此处应该调用精调运行发布的方法生成模型任务,不应该直接生成!!!!!!!!!!!!!!!!! //此处应该调用精调运行发布的方法生成模型任务,不应该直接生成!!!!!!!!!!!!!!!!!
ModelTask modelTask = new ModelTask(); /*ModelTask modelTask = new ModelTask();
modelTask.setModelVersion(modelVersion.getModelVersion()); modelTask.setModelVersion(modelVersion.getModelVersion());
modelTask.setModelId(modelVersion.getModelId()); modelTask.setModelId(modelVersion.getModelId());
modelTask.setTaskType(0); modelTask.setTaskType(0);
modelTask.setVersionId(modelVersion.getVersionId()); modelTask.setVersionId(modelVersion.getVersionId());
modelTask.setVersionName(modelVersion.getVersionName()); modelTask.setVersionName(modelVersion.getVersionName());
this.modelTaskService.saveNew(modelTask); this.modelTaskService.saveNew(modelTask);*/
return modelVersion; return modelVersion;
} }
......
...@@ -11,6 +11,7 @@ import com.yice.common.core.base.dao.BaseDaoMapper; ...@@ -11,6 +11,7 @@ import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService; import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.CallResult; import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam; import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.TokenData;
import com.yice.common.core.util.MyModelUtil; import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper; import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.PythonConfig; import com.yice.webadmin.app.config.PythonConfig;
...@@ -27,7 +28,6 @@ import com.yice.webadmin.app.service.TuningTaskService; ...@@ -27,7 +28,6 @@ import com.yice.webadmin.app.service.TuningTaskService;
import lombok.SneakyThrows; import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient; import org.java_websocket.client.WebSocketClient;
import org.java_websocket.drafts.Draft_6455;
import org.java_websocket.handshake.ServerHandshake; import org.java_websocket.handshake.ServerHandshake;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
...@@ -35,7 +35,7 @@ import org.springframework.transaction.annotation.Transactional; ...@@ -35,7 +35,7 @@ import org.springframework.transaction.annotation.Transactional;
import java.net.URI; import java.net.URI;
import java.util.List; import java.util.List;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.TimeUnit;
/** /**
* 精调任务运行数据操作服务类。 * 精调任务运行数据操作服务类。
...@@ -61,9 +61,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -61,9 +61,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Autowired @Autowired
private PythonConfig pythonConfig; private PythonConfig pythonConfig;
private AtomicInteger isSuccess = new AtomicInteger(0);
/** /**
* 返回当前Service的主表Mapper对象。 * 返回当前Service的主表Mapper对象。
* *
...@@ -190,6 +187,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -190,6 +187,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Override @Override
public boolean publishToModelVersion(RunPublishDto runPublishDto) { public boolean publishToModelVersion(RunPublishDto runPublishDto) {
String targetModelVersionURl; String targetModelVersionURl;
Long userID = TokenData.takeFromRequest().getUserId();
TuningRun tuningRun = this.getById(runPublishDto.getRunId()); TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId()); ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
if (runPublishDto.getPublishWay() == 0) { if (runPublishDto.getPublishWay() == 0) {
...@@ -203,7 +201,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -203,7 +201,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
ModelManage modelManage = this.modelManageService.getById(runPublishDto.getModelId()); ModelManage modelManage = this.modelManageService.getById(runPublishDto.getModelId());
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion; targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion;
} }
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) { new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri())) {
@Override @Override
public void onOpen(ServerHandshake serverHandshake) { public void onOpen(ServerHandshake serverHandshake) {
log.info("-------------与大模型建立连接-------------"); log.info("-------------与大模型建立连接-------------");
...@@ -217,11 +215,13 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -217,11 +215,13 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
String receiveMsg = receiveJson.getString("msg"); String receiveMsg = receiveJson.getString("msg");
JSONObject sendJson = new JSONObject(); JSONObject sendJson = new JSONObject();
if (receiveMsg.equals("send_hash")) { if (receiveMsg.equals("send_hash")) {
System.out.println("isSuccess1:" + System.currentTimeMillis());
sendJson.put("fn_index", 44); sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString()); sendJson.put("session_hash", runPublishDto.getRunId().toString());
this.send(sendJson.toJSONString()); this.send(sendJson.toJSONString());
} }
if (receiveMsg.equals("send_data")) { if (receiveMsg.equals("send_data")) {
System.out.println("isSuccess2:" + System.currentTimeMillis());
JSONArray array = new JSONArray(); JSONArray array = new JSONArray();
array.add("zh"); array.add("zh");
array.add(modelVersion.getVersionName()); array.add(modelVersion.getVersionName());
...@@ -242,54 +242,50 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement ...@@ -242,54 +242,50 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
} }
log.info("发送服务端的消息:" + sendJson.toJSONString()); log.info("发送服务端的消息:" + sendJson.toJSONString());
if (receiveMsg.equals("process_completed")) { if (receiveMsg.equals("process_completed")) {
// this.close(); System.out.println("isSuccess4:" + System.currentTimeMillis());
if (receiveJson.getBoolean("success")) { saveAll(receiveJson.getBoolean("success"),tuningRun, targetModelVersionURl, runPublishDto, userID);
isSuccess.set(1);
} }
} }
}
@Override @Override
public void onClose(int i, String s, boolean b) { public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b); log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
} }
@Override @Override
public void onError(Exception e) { public void onError(Exception e) {
log.error("报错了:::" + e.getMessage()); log.error("报错了:::" + e.getMessage());
} }
}.connect(); }.connectBlocking(5000, TimeUnit.MILLISECONDS);
System.out.println("isSuccess.get():" + isSuccess.get()); System.out.println("isSuccess5:" + System.currentTimeMillis());
if (isSuccess.get() == 1) { return true;
saveAll(tuningRun, targetModelVersionURl, runPublishDto);
tuningRun.setPublishStatus(1);
} else {
tuningRun.setPublishStatus(-1);
}
return this.updateById(tuningRun);
} }
@Transactional(rollbackFor = Exception.class) @Transactional(rollbackFor = Exception.class)
public void saveAll(TuningRun tuningRun, String targetModelVersionURl, RunPublishDto runPublishDto) { public synchronized Boolean saveAll(Boolean flag, TuningRun tuningRun, String targetModelVersionURl, RunPublishDto runPublishDto, Long userID) {
if (flag) {
tuningRun.setPublishStatus(1);
ModelVersion modelVersionS = new ModelVersion(); ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId()); modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId()); modelVersionS.setTaskId(tuningRun.getTaskId());
modelVersionS.setModelUrl(targetModelVersionURl); modelVersionS.setModelUrl(targetModelVersionURl);
modelVersionS.setCreateUserId(userID);
modelVersionS.setUpdateUserId(userID);
ModelManage modelManageS = new ModelManage(); ModelManage modelManageS = new ModelManage();
modelManageS.setCreateUserId(userID);
modelManageS.setUpdateUserId(userID);
if (runPublishDto.getPublishWay() == 0) { if (runPublishDto.getPublishWay() == 0) {
modelManageS.setModelDescribe(runPublishDto.getModelDescribe()); modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName()); modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType()); modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0); modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS); modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} else { } else {
modelManageS = modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setIsBaseModel(0);
modelManageService.updateById(modelManageS);
modelVersionS.setModelId(runPublishDto.getModelId()); modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionService.saveNew(modelVersionS); modelVersionService.saveNew(modelVersionS);
} }
} else {
tuningRun.setPublishStatus(-1);
}
return updateById(tuningRun);
} }
/** /**
......
...@@ -739,7 +739,8 @@ public class MyModelUtil { ...@@ -739,7 +739,8 @@ public class MyModelUtil {
*/ */
public static <M> void fillCommonsForInsert(M data) { public static <M> void fillCommonsForInsert(M data) {
Field createdByField = ReflectUtil.getField(data.getClass(), CREATE_USER_ID_FIELD_NAME); Field createdByField = ReflectUtil.getField(data.getClass(), CREATE_USER_ID_FIELD_NAME);
if (createdByField != null) { Object createdByFieldValue = ReflectUtil.getFieldValue(data, createdByField);
if (createdByField != null && createdByFieldValue == null) {
ReflectUtil.setFieldValue(data, createdByField, TokenData.takeFromRequest().getUserId()); ReflectUtil.setFieldValue(data, createdByField, TokenData.takeFromRequest().getUserId());
} }
Field createTimeField = ReflectUtil.getField(data.getClass(), CREATE_TIME_FIELD_NAME); Field createTimeField = ReflectUtil.getField(data.getClass(), CREATE_TIME_FIELD_NAME);
...@@ -747,7 +748,8 @@ public class MyModelUtil { ...@@ -747,7 +748,8 @@ public class MyModelUtil {
ReflectUtil.setFieldValue(data, createTimeField, new Date()); ReflectUtil.setFieldValue(data, createTimeField, new Date());
} }
Field updatedByField = ReflectUtil.getField(data.getClass(), UPDATE_USER_ID_FIELD_NAME); Field updatedByField = ReflectUtil.getField(data.getClass(), UPDATE_USER_ID_FIELD_NAME);
if (updatedByField != null) { Object updatedByFieldValue = ReflectUtil.getFieldValue(data, updatedByField);
if (updatedByField != null && updatedByFieldValue == null) {
ReflectUtil.setFieldValue(data, updatedByField, TokenData.takeFromRequest().getUserId()); ReflectUtil.setFieldValue(data, updatedByField, TokenData.takeFromRequest().getUserId());
} }
Field updateTimeField = ReflectUtil.getField(data.getClass(), UPDATE_TIME_FIELD_NAME); Field updateTimeField = ReflectUtil.getField(data.getClass(), UPDATE_TIME_FIELD_NAME);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment