Commit bad23ce0 authored by linpeiqin

增加多数据集训练及评估,简化配置文件

parent 51746da1
......@@ -2,6 +2,7 @@ package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ArrayNode;
......@@ -34,9 +35,8 @@ import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.*;
import java.util.stream.Collectors;
/**
* 模型评估操作控制器类。
......@@ -94,11 +94,19 @@ public class ModelEstimateController {
public ResponseResult<JSONArray> getPreviewCommand(@RequestParam Long taskId) {
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(modelEstimate.getDatasetVersionId());
TuningRun tuningRun = this.tuningRunService.getById(modelVersion.getRunId());
JSONObject jsonObject = (JSONObject) JSON.parse(modelEstimate.getConfiguration());
JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetVersion.getVersionName());
Set<Long> datasetVersionIds = Arrays.stream(modelEstimate.getDatasetVersionIds().split(",")).mapToLong(Long::parseLong).boxed().collect(Collectors.toSet());
datasetVersionService.getInList(datasetVersionIds).stream()
.map(DatasetVersion::getVersionName)
.forEach(versionName -> {
try {
datasetVersionNames.add(versionName);
} catch (JSONException e) {
e.printStackTrace();
}
});
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelVersion.getVersionName());
......@@ -195,19 +203,21 @@ public class ModelEstimateController {
}
ModelEstimate modelEstimate = this.modelEstimateService.getById(taskId);
ModelVersion modelVersion = this.modelVersionService.getById(modelEstimate.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(modelEstimate.getDatasetVersionId());
Set<Long> datasetVersionIds = Arrays.stream(modelEstimate.getDatasetVersionIds().split(",")).mapToLong(Long::parseLong).boxed().collect(Collectors.toSet());
List<DatasetVersion> datasetVersions = datasetVersionService.getInList(datasetVersionIds);
List<String> dataList = new ArrayList<>();
int i = 0;
String url = this.pythonConfig.getModelEstimateFileBaseDir() + modelVersion.getVersionName() + File.separator + "evl_" + taskId + File.separator + "generated_predictions.jsonl";
//获取评估输出详情
List<JSONObject> jsonObjects = this.getFileJsonArray(url);
File file = new File(datasetVersion.getFileUrl()); // 指定文件路径
ObjectMapper objectMapper = new ObjectMapper();
ArrayNode arrayNode = (ArrayNode) objectMapper.readTree(file); // 读取JSON文件内容并转换为ArrayNode对象
List<String> dataList = new ArrayList<>();
int i = 0;
for (JSONObject jsonNode : jsonObjects) { // 遍历JSON数组并取出每个元素(ObjectNode)中的数据
jsonNode.put("input", arrayNode.get(i).get("instruction").textValue() + arrayNode.get(i).get("input").textValue());
dataList.add(jsonNode.toJSONString());
i++;
for (DatasetVersion datasetVersion : datasetVersions) {
ArrayNode arrayNode = (ArrayNode) new ObjectMapper().readTree(new File(datasetVersion.getFileUrl())); // 读取JSON文件内容并转换为ArrayNode对象
for (JSONObject jsonNode : jsonObjects) { // 遍历JSON数组并取出每个元素(ObjectNode)中的数据
jsonNode.put("datasetName",datasetVersion.getVersionName());
jsonNode.put("input", arrayNode.get(i).get("instruction").textValue() + arrayNode.get(i).get("input").textValue());
dataList.add(jsonNode.toJSONString());
i++;
}
}
int total = dataList.size(); // 获取总数据量
int page = pageParam.getPageNum();
......
......@@ -2,6 +2,7 @@ package com.yice.webadmin.app.controller;
import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONException;
import com.alibaba.fastjson.JSONObject;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.pagehelper.page.PageMethod;
......@@ -33,8 +34,11 @@ import org.springframework.web.bind.annotation.*;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Date;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* 精调任务运行操作控制器类。
......@@ -145,10 +149,18 @@ public class TuningRunController {
public ResponseResult<JSONArray> getPreviewCommand(@RequestParam Long runId) {
TuningRun tuningRun = this.tuningRunService.getById(runId);
ModelVersion modelVersion = this.modelVersionService.getById(tuningRun.getModelVersionId());
DatasetVersion datasetVersion = this.datasetVersionService.getById(tuningRun.getDatasetVersionId());
JSONObject jsonObject = (JSONObject) JSON.parse(tuningRun.getConfiguration());
JSONArray datasetVersionNames = new JSONArray();
datasetVersionNames.add(datasetVersion.getVersionName());
Set<Long> datasetVersionIds = Arrays.stream(tuningRun.getDatasetVersionIds().split(",")).mapToLong(Long::parseLong).boxed().collect(Collectors.toSet());
datasetVersionService.getInList(datasetVersionIds).stream()
.map(DatasetVersion::getVersionName)
.forEach(versionName -> {
try {
datasetVersionNames.add(versionName);
} catch (JSONException e) {
e.printStackTrace();
}
});
JSONArray array = new JSONArray();
array.add("zh");
array.add(modelVersion.getVersionName());
......
......@@ -10,7 +10,7 @@
<result column="task_name" jdbcType="VARCHAR" property="taskName"/>
<result column="task_describe" jdbcType="VARCHAR" property="taskDescribe"/>
<result column="model_version_id" jdbcType="BIGINT" property="modelVersionId"/>
<result column="dataset_version_id" jdbcType="BIGINT" property="datasetVersionId"/>
<result column="dataset_version_ids" jdbcType="VARCHAR" property="datasetVersionIds"/>
<result column="scoring_mode" jdbcType="TINYINT" property="scoringMode"/>
<result column="configuration" jdbcType="LONGVARCHAR" property="configuration"/>
<result column="task_status" jdbcType="TINYINT" property="taskStatus"/>
......@@ -26,7 +26,7 @@
task_name,
task_describe,
model_version_id,
dataset_version_id,
dataset_version_ids,
scoring_mode,
configuration,
task_status)
......@@ -40,7 +40,7 @@
#{item.taskName},
#{item.taskDescribe},
#{item.modelVersionId},
#{item.datasetVersionId},
#{item.datasetVersionIds},
#{item.scoringMode},
#{item.configuration},
#{item.taskStatus})
......
......@@ -10,7 +10,7 @@
<result column="model_id" jdbcType="BIGINT" property="modelId"/>
<result column="model_version_id" jdbcType="BIGINT" property="modelVersionId"/>
<result column="run_time" jdbcType="BIGINT" property="runTime"/>
<result column="dataset_version_id" jdbcType="BIGINT" property="datasetVersionId"/>
<result column="dataset_version_ids" jdbcType="VARCHAR" property="datasetVersionIds"/>
<result column="split_ratio" jdbcType="TINYINT" property="splitRatio"/>
<result column="train_mode" jdbcType="VARCHAR" property="trainMode"/>
<result column="train_method" jdbcType="VARCHAR" property="trainMethod"/>
......@@ -32,7 +32,7 @@
model_id,
model_version_id,
run_time,
dataset_version_id,
dataset_version_ids,
split_ratio,
train_mode,
train_method,
......@@ -52,7 +52,7 @@
#{item.modelId},
#{item.modelVersionId},
#{item.runTime},
#{item.datasetVersionId},
#{item.datasetVersionIds},
#{item.splitRatio},
#{item.trainMode},
#{item.trainMethod},
......
......@@ -43,10 +43,10 @@ public class ModelEstimateDto {
private Long modelVersionId;
/**
* 评估数据集版本ID。
* 评估数据集版本IDs
*/
@ApiModelProperty(value = "评估数据集版本ID")
private Long datasetVersionId;
@ApiModelProperty(value = "评估数据集版本IDs")
private String datasetVersionIds;
/**
* 打分模式。
......
......@@ -67,10 +67,10 @@ public class TuningRunDto {
private Long runTime;
/**
* 数据集版本ID。
* 数据集版本IDs
*/
@ApiModelProperty(value = "数据集版本ID")
private Long datasetVersionId;
@ApiModelProperty(value = "数据集版本IDs")
private String datasetVersionIds;
/**
* 拆分比例。
......
......@@ -61,20 +61,9 @@ public class ModelEstimate extends BaseModel {
private Map<String, Object> modelVersionIdDictMap;
/**
* 评估数据集版本ID。
* 评估数据集版本IDs
*/
private Long datasetVersionId;
/**
* 评估数据集版本字典。
*/
@RelationDict(
masterIdField = "datasetVersionId",
slaveModelClass = DatasetVersion.class,
slaveIdField = "versionId",
slaveNameField = "versionName")
@TableField(exist = false)
private Map<String, Object> datasetVersionIdDictMap;
private String datasetVersionIds;
/**
* 打分模式。
......
......@@ -68,9 +68,9 @@ public class TuningRun extends BaseModel {
private Long runTime;
/**
* 数据集版本ID。
* 数据集版本IDs
*/
private Long datasetVersionId;
private String datasetVersionIds;
/**
* 拆分比例。
......
......@@ -47,7 +47,7 @@ public class ModelEstimateVo extends BaseVo {
* 评估数据集版本IDs(多个ID以逗号分隔)。
*/
@ApiModelProperty(value = "评估数据集版本ID")
private Long datasetVersionId;
private String datasetVersionIds;
/**
* 打分模式。
*/
......
......@@ -68,10 +68,10 @@ public class TuningRunVo extends BaseVo {
private Long runTime;
/**
* 数据集版本ID。
* 数据集版本IDs
*/
@ApiModelProperty(value = "数据集版本ID")
private Long datasetVersionId;
@ApiModelProperty(value = "数据集版本IDs")
private String datasetVersionIds;
/**
* 拆分比例。
......
......@@ -51,7 +51,7 @@ application:
# 初始化密码。
defaultUserPassword: 123456
# 缺省的文件上传根目录。
uploadFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/lmp_data/
uploadFileBaseDir: ${python.baseDir}/code/lmp_training_api/lmp_data/
# 跨域的IP(http://192.168.10.10:8086)白名单列表,多个IP之间逗号分隔(* 表示全部信任,空白表示禁用跨域信任)。
credentialIpList: "*"
......@@ -61,29 +61,35 @@ application:
excludeLogin: false
python:
#基础路径
baseDir: z:/home/linking/llms
#训练服务地址
trainAddress: 192.168.0.36:7860
#推理服务地址
predictAddress: 192.168.0.36:7861
#数据集文件基础路径
datasetFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/lmp_data/
datasetFileBaseDir: ${python.baseDir}/code/lmp_training_api/lmp_data/
#模型训练文件基础路径
modelTuningFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
modelTuningFileBaseDir: ${python.baseDir}/code/lmp_training_api/saves/
#模型训练文件合并后路径
modelOutputFileBaseDir: /home/linking/llms/models/
modelOutputFileBaseDir: ${python.baseDir}/models/
#模型评估文件基础路径
modelEstimateFileBaseDir: /home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
modelEstimateFileBaseDir: ${python.baseDir}/code/lmp_training_api/saves/
#数据集配置信息
datasetInfo: dataset_info.json
#数据集配置目录
datasetFileMenu: lmp_data
#python平台通用接口地址
factoryInterface: http://192.168.0.36:7860/run/predict
factoryInterface: http://${python.trainAddress}/run/predict
#python websocket 服务地址
pythonWebsocketUri: ws://192.168.0.36:7860/queue/join
pythonWebsocketUri: ws://${python.trainAddress}/queue/join
#输出控制地址
controllerAddress: http://127.0.0.1:20001
#对话基础路径
chatAddress: http://192.168.0.36:8000/
llm-model:
#模型管理基础路径
llmModelInterface: http://192.168.0.36:7861/llm_model/
llmModelInterface: http://${python.predictAddress}/llm_model/
#模型停止
stop: stop
#模型部署
......@@ -94,12 +100,12 @@ llm-model:
listRunningModels: list_running_models
other:
#其他管理接口
otherInterface: http://192.168.0.36:7861/other/
otherInterface: http://${python.predictAddress}/other/
#获取gpu信息
getGpuInfo: get_gpu_info
knowledge:
#知识库通用接口地址
knowledgeInterface: http://192.168.0.36:7861/knowledge_base/
knowledgeInterface: http://${python.predictAddress}/knowledge_base/
#创建知识库
create: create_knowledge_base
#获取知识库列表
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment