Commit 830cbcf6 authored by linpeiqin's avatar linpeiqin

修改一些小BUG

parent b834a96a
......@@ -27,6 +27,7 @@
<result column="is_compress" jdbcType="TINYINT" property="isCompress"/>
<result column="task_id" jdbcType="BIGINT" property="taskId"/>
<result column="run_id" jdbcType="BIGINT" property="runId"/>
<result column="base_prompt_template" jdbcType="VARCHAR" property="basePromptTemplate"/>
</resultMap>
<insert id="insertList">
......@@ -55,7 +56,8 @@
model_url,
is_compress,
task_id,
run_id
run_id,
base_prompt_template
)
VALUES
<foreach collection="list" index="index" item="item" separator=",">
......@@ -83,7 +85,9 @@
#{item.modelUrl},
#{item.isCompress},
#{item.taskId},
#{item.runId})
#{item.runId},
#{item.basePromptTemplate}
)
</foreach>
</insert>
......
......@@ -144,4 +144,10 @@ public class ModelVersionDto {
*/
@ApiModelProperty(value = "训练任务ID")
private Long taskId;
/**
* 默认的提示词模板。
*/
@ApiModelProperty(value = "默认的提示词模板")
private String basePromptTemplate;
}
......@@ -129,6 +129,12 @@ public class ModelVersion extends BaseModel {
*/
private Integer isCompress;
/**
* 默认的提示词模板。
*/
private String basePromptTemplate;
@RelationOneToOne(
masterIdField = "modelId",
slaveModelClass = ModelManage.class,
......
......@@ -60,6 +60,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
private IdGeneratorWrapper idGenerator;
@Autowired
private PythonConfig pythonConfig;
private String modelName;
/**
* 返回当前Service的主表Mapper对象。
......@@ -179,6 +180,23 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
TuningRun tuningRun = this.getById(runPublishDto.getRunId());
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
ModelManage modelManage = this.modelManageService.getById(modelVersion.getModelId());
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setRunId(tuningRun.getRunId());
modelVersionS.setTaskId(tuningRun.getTaskId());
if (runPublishDto.getPublishWay() == 0) {
ModelManage modelManageS = new ModelManage();
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
this.modelManageService.saveAndCreateVersion(modelManageS, modelVersionS);
modelName = runPublishDto.getModelName() + "_V1";
} else {
modelVersionS.setModelId(runPublishDto.getModelId());
ModelManage modelManageS = this.modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
this.modelManageService.updateById(modelManageS);
modelName = this.modelVersionService.saveNew(modelVersionS).getVersionName();
}
new WebSocketClient(new URI(this.pythonConfig.getPythonWebsocketUri()), new Draft_6455()) {
@Override
public void onOpen(ServerHandshake serverHandshake) {
......@@ -210,7 +228,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array.add(jsonObject.get("promptTemplate"));
array.add(2);
//路径需要修改,暂时不能确定,后面一条线解决
String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + (runPublishDto.getPublishWay() == 0 ? runPublishDto.getModelName() : modelManage.getModelName());
String newModelUrl = pythonConfig.getModelOutputFileBaseDir() + modelName;
array.add(newModelUrl);
array.add("none");
sendJson.put("data",array);
......@@ -234,24 +252,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
log.error("报错了:::" + e.getMessage());
}
}.connect();
if (runPublishDto.getPublishWay() == 0) {
ModelManage modelManageS = new ModelManage();
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
modelManageS.setModelName(runPublishDto.getModelName());
modelManageS.setModelType(runPublishDto.getModelType());
this.modelManageService.saveAndCreateVersion(modelManageS, new ModelVersion());
} else {
ModelVersion modelVersionS = new ModelVersion();
modelVersionS.setModelId(runPublishDto.getModelId());
List<ModelVersion> modelVersionList = this.modelVersionService.getModelVersionList(modelVersionS, "model_version");
int lastModelVersion = modelVersionList.get(modelVersionList.size() - 1).getModelVersion();
modelVersionS.setModelId(runPublishDto.getModelId());
modelVersionS.setModelVersion(lastModelVersion + 1);
ModelManage modelManageS = this.modelManageService.getById(runPublishDto.getModelId());
modelManageS.setModelDescribe(runPublishDto.getModelDescribe());
this.modelManageService.updateById(modelManageS);
this.modelVersionService.saveNew(modelVersionS);
}
tuningRun.setPublishStatus(1);
return this.updateById(tuningRun);
}
......
......@@ -151,4 +151,10 @@ public class ModelVersionVo extends BaseVo {
*/
@ApiModelProperty(value = "训练任务ID")
private Long taskId;
/**
* 默认的提示词模板。
*/
@ApiModelProperty(value = "默认的提示词模板")
private String basePromptTemplate;
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment