Commit b18488c2 authored by linpeiqin's avatar linpeiqin

校准模型创建时的状态

parent 6363ed58
......@@ -15,12 +15,10 @@ import com.yice.webadmin.app.dto.ModelDeployDto;
import com.yice.webadmin.app.dto.ModelManageDto;
import com.yice.webadmin.app.dto.ModelTaskDto;
import com.yice.webadmin.app.dto.ModelVersionDto;
import com.yice.webadmin.app.model.ModelDeploy;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.TuningRunService;
import com.yice.webadmin.app.vo.ModelManageVo;
import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j;
......@@ -44,7 +42,8 @@ public class ModelManageController {
@Autowired
private ModelManageService modelManageService;
@Autowired
private ModelVersionService modelVersionService;
private TuningRunService tuningRunService;
/**
* 新增模型管理数据,及其关联的从表数据。
......@@ -64,7 +63,7 @@ public class ModelManageController {
}
ModelManage modelManage = MyModelUtil.copyTo(modelManageDto, ModelManage.class);
ModelVersion modelVersion = MyModelUtil.copyTo(modelVersionDto, ModelVersion.class);
modelManage = modelManageService.saveAndCreateVersion(modelManage, modelVersion);
modelManage = this.tuningRunService.createToModel(modelManage,modelVersion);
return ResponseResult.success(modelManage.getModelId());
}
......
......@@ -18,6 +18,7 @@ import com.yice.webadmin.app.dto.ModelVersionDto;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.ProxyPythonService;
import com.yice.webadmin.app.service.TuningRunService;
import com.yice.webadmin.app.vo.ModelVersionVo;
import io.swagger.annotations.Api;
import lombok.extern.slf4j.Slf4j;
......@@ -45,6 +46,9 @@ public class ModelVersionController {
private PythonConfig pythonConfig;
@Autowired
private ProxyPythonService proxyPythonService;
@Autowired
private TuningRunService tuningRunService;
/**
* 新增模型版本数据。
......@@ -66,7 +70,7 @@ public class ModelVersionController {
if (!callResult.isSuccess()) {
return ResponseResult.errorFrom(callResult);
}
modelVersion = modelVersionService.saveNew(modelVersion);
modelVersion = this.tuningRunService.createToModelVersion(modelVersion);
return ResponseResult.success(modelVersion.getVersionId());
}
......
......@@ -83,4 +83,6 @@ public interface ModelManageService extends IBaseService<ModelManage, Long> {
List<ModelManage> getModelManageListWithRelation(ModelManage filter, ModelVersion modelVersionFilter, ModelTask modelTaskFilter, ModelDeploy modelDeployFilter, String orderBy);
ModelManage saveAndCreateVersion(ModelManage modelManage, ModelVersion modelVersion);
ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion);
}
......@@ -2,6 +2,8 @@ package com.yice.webadmin.app.service;
import com.yice.common.core.base.service.IBaseService;
import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.model.TuningRun;
import java.util.List;
......@@ -76,4 +78,8 @@ public interface TuningRunService extends IBaseService<TuningRun, Long> {
List<TuningRun> getTuningRunListWithRelation(TuningRun filter, String orderBy);
boolean publishToModelVersion(RunPublishDto runPublishDto);
ModelManage createToModel(ModelManage modelManage, ModelVersion modelVersion);
ModelVersion createToModelVersion(ModelVersion modelVersion);
}
......@@ -11,14 +11,8 @@ import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.dao.ModelManageMapper;
import com.yice.webadmin.app.model.ModelDeploy;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelTask;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.service.ModelDeployService;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelTaskService;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
......@@ -48,6 +42,8 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
private ModelDeployService modelDeployService;
@Autowired
private IdGeneratorWrapper idGenerator;
@Autowired
private TuningRunService tuningRunService;
/**
* 返回当前Service的主表Mapper对象。
......@@ -191,6 +187,16 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
this.modelVersionService.saveNew(modelVersion);
return reModelManage;
}
@Transactional
@Override
public ModelVersion saveAndCreateVersionV(ModelManage modelManage, ModelVersion modelVersion) {
ModelManage reModelManage = this.saveNew(modelManage);
modelVersion.setModelId(reModelManage.getModelId());
if (modelVersion.getBusinessLabel() == null && modelManage.getBusinessLabel() != null) {
modelVersion.setBusinessLabel(modelManage.getBusinessLabel());
}
return this.modelVersionService.saveNew(modelVersion);
}
private ModelManage buildDefaultValue(ModelManage modelManage) {
if (modelManage.getModelId() == null) {
......
......@@ -77,14 +77,7 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
}
modelVersion.setVersionName(modelName + "_V" + modelVersion.getModelVersion());
modelVersionMapper.insert(this.buildDefaultValue(modelVersion));
//此处应该调用精调运行发布的方法生成模型任务,不应该直接生成!!!!!!!!!!!!!!!!!
/*ModelTask modelTask = new ModelTask();
modelTask.setModelVersion(modelVersion.getModelVersion());
modelTask.setModelId(modelVersion.getModelId());
modelTask.setTaskType(0);
modelTask.setVersionId(modelVersion.getVersionId());
modelTask.setVersionName(modelVersion.getVersionName());
this.modelTaskService.saveNew(modelTask);*/
return modelVersion;
}
......
......@@ -11,20 +11,13 @@ import com.yice.common.core.base.dao.BaseDaoMapper;
import com.yice.common.core.base.service.BaseService;
import com.yice.common.core.object.CallResult;
import com.yice.common.core.object.MyRelationParam;
import com.yice.common.core.object.TokenData;
import com.yice.common.core.util.MyModelUtil;
import com.yice.common.sequence.wrapper.IdGeneratorWrapper;
import com.yice.webadmin.app.config.PythonConfig;
import com.yice.webadmin.app.dao.TuningRunMapper;
import com.yice.webadmin.app.dto.RunPublishDto;
import com.yice.webadmin.app.model.ModelManage;
import com.yice.webadmin.app.model.ModelVersion;
import com.yice.webadmin.app.model.TuningRun;
import com.yice.webadmin.app.model.TuningTask;
import com.yice.webadmin.app.service.ModelManageService;
import com.yice.webadmin.app.service.ModelVersionService;
import com.yice.webadmin.app.service.TuningRunService;
import com.yice.webadmin.app.service.TuningTaskService;
import com.yice.webadmin.app.model.*;
import com.yice.webadmin.app.service.*;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.java_websocket.client.WebSocketClient;
......@@ -34,6 +27,7 @@ import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.net.URI;
import java.util.Date;
import java.util.List;
import java.util.concurrent.TimeUnit;
......@@ -55,11 +49,13 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
private ModelManageService modelManageService;
@Autowired
private ModelVersionService modelVersionService;
@Autowired
private IdGeneratorWrapper idGenerator;
@Autowired
private PythonConfig pythonConfig;
@Autowired
private ModelTaskService modelTaskService;
/**
* 返回当前Service的主表Mapper对象。
......@@ -182,25 +178,40 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
return resultList;
}
@SneakyThrows
@Override
public boolean publishToModelVersion(RunPublishDto runPublishDto) {
String targetModelVersionURl;
Long userID = TokenData.takeFromRequest().getUserId();
TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
if (runPublishDto.getPublishWay() == 0) {
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + runPublishDto.getModelName() + "_V1";
} else {
ModelVersion lastModelVersion = this.modelVersionService.lastModelVersion(runPublishDto.getModelId());
Integer newVersion = 1;
if (lastModelVersion != null) {
newVersion = lastModelVersion.getModelVersion() + 1;
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, runPublishDto, modelVersionSource,modelTask);
messageWithSocket(tuningRun,modelVersionSource,pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(),modelVersionTarget,modelTask);
return true;
}
ModelManage modelManage = this.modelManageService.getById(runPublishDto.getModelId());
targetModelVersionURl = pythonConfig.getModelOutputFileBaseDir() + modelManage.getModelName() + "_V" + newVersion;
@Override
public ModelManage createToModel(ModelManage modelManage, ModelVersion modelVersion) {
TuningRun tuningRun = this.getById(modelVersion.getRunId());
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, modelManage, modelVersion,modelVersionSource,modelTask);
messageWithSocket(tuningRun,modelVersionSource,pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(),modelVersionTarget,modelTask);
return modelManage;
}
@Override
public ModelVersion createToModelVersion(ModelVersion modelVersion) {
TuningRun tuningRun = this.getById(modelVersion.getRunId());
ModelVersion modelVersionSource = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelTask modelTask = new ModelTask();
ModelVersion modelVersionTarget = saveAll(tuningRun, modelVersion,modelVersionSource,modelTask);
messageWithSocket(tuningRun,modelVersionSource,pythonConfig.getModelOutputFileBaseDir() + modelVersionTarget.getVersionName(),modelVersionTarget,modelTask);
return modelVersionTarget;
}
@SneakyThrows
private void messageWithSocket(TuningRun tuningRun,ModelVersion modelVersionSource,String targetModelVersionURl,ModelVersion modelVersionTarget,ModelTask modelTask){
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri())) {
@Override
public void onOpen(ServerHandshake serverHandshake) {
......@@ -217,15 +228,15 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
if (receiveMsg.equals("send_hash")) {
System.out.println("isSuccess1:" + System.currentTimeMillis());
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
sendJson.put("session_hash", tuningRun.getRunId().toString());
this.send(sendJson.toJSONString());
}
if (receiveMsg.equals("send_data")) {
System.out.println("isSuccess2:" + System.currentTimeMillis());
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelVersion.getVersionName());
array.add(modelVersion.getModelUrl());
array.add(modelVersionSource.getVersionName());
array.add(modelVersionSource.getModelUrl());
JSONArray ddJson = new JSONArray();
ddJson.add("train_" + tuningRun.getRunId());
array.add(ddJson);
......@@ -237,57 +248,109 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
sendJson.put("data", array);
sendJson.put("event_data", "null");
sendJson.put("fn_index", 44);
sendJson.put("session_hash", runPublishDto.getRunId().toString());
sendJson.put("session_hash", tuningRun.getRunId().toString());
this.send(sendJson.toJSONString());
}
log.info("发送服务端的消息:" + sendJson.toJSONString());
if (receiveMsg.equals("process_completed")) {
System.out.println("isSuccess4:" + System.currentTimeMillis());
saveAll(receiveJson.getBoolean("success"),tuningRun, targetModelVersionURl, runPublishDto, userID);
updateStatus(receiveJson.getBoolean("success"), tuningRun, modelVersionTarget, targetModelVersionURl,modelTask);
}
}
@Override
public void onClose(int i, String s, boolean b) {
log.info("关闭连接:::" + "i = " + i + ":::s = " + s + ":::b = " + b);
}
@Override
public void onError(Exception e) {
log.error("报错了:::" + e.getMessage());
}
}.connectBlocking(5000, TimeUnit.MILLISECONDS);
System.out.println("isSuccess5:" + System.currentTimeMillis());
return true;
}
@Transactional(rollbackFor = Exception.class)
public synchronized Boolean saveAll(Boolean flag, TuningRun tuningRun, String targetModelVersionURl, RunPublishDto runPublishDto, Long userID) {
if (flag) {
tuningRun.setPublishStatus(1);
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
modelVersionS.setModelUrl(targetModelVersionURl);
modelVersionS.setCreateUserId(userID);
modelVersionS.setUpdateUserId(userID);
ModelManage modelManageS = new ModelManage();
modelManageS.setCreateUserId(userID);
modelManageS.setUpdateUserId(userID);
public ModelVersion saveAll(TuningRun tuningRun, RunPublishDto runPublishDto, ModelVersion modelVersionSource,ModelTask modelTask) {
ModelVersion modelVersion = new ModelVersion();
this.initModelVersion(modelVersion,tuningRun,modelVersionSource);
ModelManage modelManag = new ModelManage();
if (runPublishDto.getPublishWay() == 0) {
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManag.setModelDescribe(runPublishDto.getModelDescribe());
modelManag.setModelName(runPublishDto.getModelName());
modelManag.setModelType(0);
modelManag.setIsBaseModel(0);
modelManageService.saveAndCreateVersionV(modelManag, modelVersion);
} else {
modelVersion.setModelId(runPublishDto.getModelId());
modelVersionService.saveNew(modelVersion);
}
this.initModelTask(modelTask,tuningRun,modelVersion);
this.modelTaskService.saveNew(modelTask);
return modelVersion;
}
@Transactional(rollbackFor = Exception.class)
public ModelVersion saveAll(TuningRun tuningRun, ModelManage modelManageS,ModelVersion modelVersion, ModelVersion modelVersionSource,ModelTask modelTask) {
this.initModelVersion(modelVersion,tuningRun,modelVersionSource);
modelManageS.setModelType(0);
modelManageS.setIsBaseModel(0);
modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
} else {
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionService.saveNew(modelVersionS);
modelManageService.saveAndCreateVersionV(modelManageS, modelVersion);
this.initModelTask(modelTask,tuningRun,modelVersion);
this.modelTaskService.saveNew(modelTask);
return modelVersion;
}
@Transactional(rollbackFor = Exception.class)
public ModelVersion saveAll(TuningRun tuningRun, ModelVersion modelVersion, ModelVersion modelVersionSource,ModelTask modelTask) {
this.initModelVersion(modelVersion,tuningRun,modelVersionSource);
modelVersionService.saveNew(modelVersion);
modelTask.setModelVersion(modelVersion.getModelVersion());
this.initModelTask(modelTask,tuningRun,modelVersion);
this.modelTaskService.saveNew(modelTask);
return modelVersion;
}
private void initModelVersion(ModelVersion modelVersion,TuningRun tuningRun,ModelVersion modelVersionSource){
modelVersion.setRunId(tuningRun.getRunId());
modelVersion.setTaskId(tuningRun.getTaskId());
modelVersion.setBaseModel(modelVersionSource.getVersionName());
modelVersion.setBaseId(modelVersionSource.getVersionId());
modelVersion.setIsCompress(0);
modelVersion.setStatus(0);
modelVersion.setBasePromptTemplate(modelVersionSource.getBasePromptTemplate());
modelVersion.setRunName(tuningRun.getRunName());
modelVersion.setModelTrainingMethod(tuningRun.getTrainMethod());
TuningTask tuningTask = this.tuningTaskService.getById(tuningRun.getTaskId());
modelVersion.setTrainingTask(tuningTask.getTaskName());
}
private void initModelTask(ModelTask modelTask,TuningRun tuningRun,ModelVersion modelVersion){
modelTask.setModelVersion(modelVersion.getModelVersion());
modelTask.setModelId(modelVersion.getModelId());
modelTask.setTaskType(0);
modelTask.setTaskStatus(3);
modelTask.setVersionId(modelVersion.getVersionId());
modelTask.setVersionName(modelVersion.getVersionName());
modelTask.setTaskId(tuningRun.getTaskId());
}
@Transactional(rollbackFor = Exception.class)
public Boolean updateStatus(Boolean flag, TuningRun tuningRun, ModelVersion modelVersionTarget, String targetModelVersionURl,ModelTask modelTask) {
if (flag) {
modelVersionTarget.setModelUrl(targetModelVersionURl);
modelVersionTarget.setStatus(1);
modelTask.setTaskStatus(1);
tuningRun.setPublishStatus(1);
} else {
modelTask.setTaskStatus(-1);
modelVersionTarget.setStatus(-1);
tuningRun.setPublishStatus(-1);
}
modelTask.setCompleteTime(new Date());
modelTaskService.updateById(modelTask);
modelVersionService.updateById(modelVersionTarget);
return updateById(tuningRun);
}
/**
* 根据最新对象和原有对象的数据对比,判断关联的字典数据和多对一主表数据是否都是合法数据。
*
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment