Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
L
lmp_server
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
lmp
lmp_server
Commits
b52fd131
Commit
b52fd131
authored
Dec 07, 2023
by
linpeiqin
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
增加发布到输出的内容
parent
f229451e
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
131 additions
and
26 deletions
+131
-26
PythonConfig.java
.../main/java/com/yice/webadmin/app/config/PythonConfig.java
+9
-0
TuningRunController.java
...com/yice/webadmin/app/controller/TuningRunController.java
+11
-9
ModelManageServiceImpl.java
...ice/webadmin/app/service/impl/ModelManageServiceImpl.java
+0
-2
ModelVersionServiceImpl.java
...ce/webadmin/app/service/impl/ModelVersionServiceImpl.java
+12
-0
TuningRunServiceImpl.java
.../yice/webadmin/app/service/impl/TuningRunServiceImpl.java
+86
-15
application-dev.yml
application-webadmin/src/main/resources/application-dev.yml
+4
-0
pom.xml
pom.xml
+9
-0
No files found.
application-webadmin/src/main/java/com/yice/webadmin/app/config/PythonConfig.java
View file @
b52fd131
...
...
@@ -22,6 +22,10 @@ public class PythonConfig {
* 模型训练基础目录
*/
private
String
modelTuningFileBaseDir
;
/**
* 模型训练基础目录
*/
private
String
modelOutputFileBaseDir
;
/**
* 数据集配置文件
*/
...
...
@@ -34,5 +38,10 @@ public class PythonConfig {
* python平台通用接口地址
*/
private
String
factoryInterface
;
/**
* python websocket地址
*/
private
String
pythonWebsocketUri
;
}
application-webadmin/src/main/java/com/yice/webadmin/app/controller/TuningRunController.java
View file @
b52fd131
...
...
@@ -33,9 +33,7 @@ import org.springframework.web.bind.annotation.*;
import
java.io.File
;
import
java.io.IOException
;
import
java.nio.file.Path
;
import
java.nio.file.Paths
;
import
java.text.SimpleDateFormat
;
import
java.util.List
;
import
java.util.Optional
;
...
...
@@ -80,6 +78,9 @@ public class TuningRunController {
if
(
errorMessage
!=
null
)
{
return
ResponseResult
.
error
(
ErrorCodeEnum
.
DATA_VALIDATED_FAILED
,
errorMessage
);
}
if
(
tuningRunDto
.
getTaskId
()
==
null
)
{
return
ResponseResult
.
error
(
ErrorCodeEnum
.
DATA_VALIDATED_FAILED
,
"请填写任务ID"
);
}
TuningRun
tuningRun
=
MyModelUtil
.
copyTo
(
tuningRunDto
,
TuningRun
.
class
);
TuningRun
tuningRunFilter
=
new
TuningRun
();
tuningRunFilter
.
setTaskId
(
tuningRun
.
getTaskId
());
...
...
@@ -118,6 +119,12 @@ public class TuningRunController {
return
ResponseResult
.
success
();
}
/**
* 获取预览命令。
*
* @param runId 运行ID。
* @return 应答结果对象。
*/
@GetMapping
(
"/getPreviewCommand"
)
public
ResponseResult
<
String
>
getPreviewCommand
(
@RequestParam
Long
runId
)
{
TuningRun
tuningRun
=
this
.
tuningRunService
.
getById
(
runId
);
...
...
@@ -128,7 +135,6 @@ public class TuningRunController {
JSONObject
jsonObject
=
(
JSONObject
)
JSON
.
parse
(
tuningRun
.
getConfiguration
());
JSONArray
datasetVersionNames
=
new
JSONArray
();
datasetVersionNames
.
add
(
datasetManage
.
getDatasetName
()
+
"_V"
+
datasetVersion
.
getDatasetVersion
());
SimpleDateFormat
sdf
=
new
SimpleDateFormat
(
"yyyy-MM-dd-hh-mm-ss"
);
JSONArray
array
=
new
JSONArray
();
array
.
add
(
"zh"
);
array
.
add
(
modelManage
.
getModelName
());
...
...
@@ -314,8 +320,7 @@ public class TuningRunController {
}
private
ResponseEntity
<
Resource
>
getFileByUrl
(
String
url
)
throws
IOException
{
Path
file
=
Paths
.
get
(
url
);
Resource
resource
=
new
UrlResource
(
file
.
toUri
());
// 使用UrlResource从文件路径构建一个Resource对象
Resource
resource
=
new
UrlResource
(
Paths
.
get
(
url
).
toUri
());
// 使用UrlResource从文件路径构建一个Resource对象
Optional
<
Resource
>
optionalResource
=
Optional
.
ofNullable
(
resource
);
if
(
optionalResource
.
isPresent
()
&&
optionalResource
.
get
().
exists
())
{
// 检查文件是否存在
HttpHeaders
headers
=
new
HttpHeaders
();
// 构建HTTP响应头
...
...
@@ -351,10 +356,7 @@ public class TuningRunController {
private
ResponseResult
<
Void
>
doDelete
(
Long
runId
)
{
String
errorMessage
;
// 验证关联Id的数据合法性
TuningRun
originalTuningRun
=
tuningRunService
.
getById
(
runId
);
if
(
originalTuningRun
==
null
)
{
// NOTE: 修改下面方括号中的话述
if
(
tuningRunService
.
getById
(
runId
)
==
null
)
{
errorMessage
=
"数据验证失败,当前 [对象] 并不存在,请刷新后重试!"
;
return
ResponseResult
.
error
(
ErrorCodeEnum
.
DATA_NOT_EXIST
,
errorMessage
);
}
...
...
application-webadmin/src/main/java/com/yice/webadmin/app/service/impl/ModelManageServiceImpl.java
View file @
b52fd131
...
...
@@ -188,8 +188,6 @@ public class ModelManageServiceImpl extends BaseService<ModelManage, Long> imple
if
(
modelVersion
.
getBusinessLabel
()
==
null
&&
modelManage
.
getBusinessLabel
()
!=
null
)
{
modelVersion
.
setBusinessLabel
(
modelManage
.
getBusinessLabel
());
}
modelVersion
.
setModelVersion
(
1
);
modelVersion
.
setIsCompress
(
0
);
this
.
modelVersionService
.
saveNew
(
modelVersion
);
return
reModelManage
;
}
...
...
application-webadmin/src/main/java/com/yice/webadmin/app/service/impl/ModelVersionServiceImpl.java
View file @
b52fd131
...
...
@@ -61,7 +61,19 @@ public class ModelVersionServiceImpl extends BaseService<ModelVersion, Long> imp
@Transactional
(
rollbackFor
=
Exception
.
class
)
@Override
public
ModelVersion
saveNew
(
ModelVersion
modelVersion
)
{
String
modelName
=
this
.
modelManageService
.
getById
(
modelVersion
.
getModelId
()).
getModelName
();
ModelVersion
modelVersionFilter
=
new
ModelVersion
();
modelVersionFilter
.
setModelId
(
modelVersion
.
getModelId
());
List
<
ModelVersion
>
modelVersionList
=
this
.
getModelVersionList
(
modelVersionFilter
,
"model_Version"
);
Integer
version
=
1
;
if
(
modelVersionList
!=
null
&&
modelVersionList
.
size
()
==
0
)
{
version
=
modelVersionList
.
get
(
modelVersionList
.
size
()
-
1
).
getModelVersion
()
+
1
;
}
modelVersion
.
setModelVersion
(
version
);
modelVersion
.
setIsCompress
(
0
);
modelVersion
.
setVersionName
(
modelName
+
"_V"
+
modelVersion
.
getModelVersion
());
modelVersionMapper
.
insert
(
this
.
buildDefaultValue
(
modelVersion
));
//此处应该调用精调运行发布的方法生成模型任务,不应该直接生成!!!!!!!!!!!!!!!!!
ModelTask
modelTask
=
new
ModelTask
();
modelTask
.
setModelVersion
(
modelVersion
.
getModelVersion
());
modelTask
.
setModelId
(
modelVersion
.
getModelId
());
...
...
application-webadmin/src/main/java/com/yice/webadmin/app/service/impl/TuningRunServiceImpl.java
View file @
b52fd131
package
com
.
yice
.
webadmin
.
app
.
service
.
impl
;
import
cn.hutool.core.collection.CollUtil
;
import
com.alibaba.fastjson.JSON
;
import
com.alibaba.fastjson.JSONArray
;
import
com.alibaba.fastjson.JSONObject
;
import
com.baomidou.mybatisplus.core.conditions.query.QueryWrapper
;
import
com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper
;
import
com.github.pagehelper.Page
;
...
...
@@ -10,6 +13,7 @@ import com.yice.common.core.object.CallResult;
import
com.yice.common.core.object.MyRelationParam
;
import
com.yice.common.core.util.MyModelUtil
;
import
com.yice.common.sequence.wrapper.IdGeneratorWrapper
;
import
com.yice.webadmin.app.config.PythonConfig
;
import
com.yice.webadmin.app.dao.TuningRunMapper
;
import
com.yice.webadmin.app.dto.RunPublishDto
;
import
com.yice.webadmin.app.model.ModelManage
;
...
...
@@ -19,11 +23,18 @@ import com.yice.webadmin.app.service.ModelManageService;
import
com.yice.webadmin.app.service.ModelVersionService
;
import
com.yice.webadmin.app.service.TuningRunService
;
import
com.yice.webadmin.app.service.TuningTaskService
;
import
lombok.SneakyThrows
;
import
lombok.extern.slf4j.Slf4j
;
import
org.java_websocket.client.WebSocketClient
;
import
org.java_websocket.drafts.Draft_6455
;
import
org.java_websocket.handshake.ServerHandshake
;
import
org.springframework.beans.factory.annotation.Autowired
;
import
org.springframework.stereotype.Component
;
import
org.springframework.stereotype.Service
;
import
org.springframework.transaction.annotation.Transactional
;
import
java.net.URI
;
import
java.net.URISyntaxException
;
import
java.util.List
;
/**
...
...
@@ -47,6 +58,8 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
@Autowired
private
IdGeneratorWrapper
idGenerator
;
@Autowired
private
PythonConfig
pythonConfig
;
/**
* 返回当前Service的主表Mapper对象。
...
...
@@ -159,28 +172,86 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
return
resultList
;
}
@SneakyThrows
@Transactional
(
rollbackFor
=
Exception
.
class
)
@Override
public
boolean
publishToModelVersion
(
RunPublishDto
runPublishDto
)
{
TuningRun
tuningRun
=
this
.
getById
(
runPublishDto
.
getRunId
());
ModelVersion
modelVersion
=
this
.
modelVersionService
.
getById
(
tuningRun
.
getModelVersionId
());
ModelManage
modelManage
=
this
.
modelManageService
.
getById
(
modelVersion
.
getModelId
());
new
WebSocketClient
(
new
URI
(
this
.
pythonConfig
.
getPythonWebsocketUri
()),
new
Draft_6455
())
{
@Override
public
void
onOpen
(
ServerHandshake
serverHandshake
)
{
log
.
info
(
"-------------与大模型建立连接-------------"
);
}
@Override
public
void
onMessage
(
String
message
)
{
log
.
info
(
"收到来自服务端的消息:"
+
message
);
JSONObject
receiveJson
=
(
JSONObject
)
JSON
.
parse
(
message
);
JSONObject
jsonObject
=
(
JSONObject
)
JSON
.
parse
(
tuningRun
.
getConfiguration
());
String
receiveMsg
=
receiveJson
.
getString
(
"msg"
);
JSONObject
sendJson
=
new
JSONObject
();
if
(
receiveMsg
.
equals
(
"send_hash"
))
{
sendJson
.
put
(
"fn_index"
,
44
);
sendJson
.
put
(
"session_hash"
,
runPublishDto
.
getRunId
().
toString
());
log
.
info
(
"发送服务端的消息:"
+
message
);
System
.
out
.
println
(
sendJson
.
toJSONString
());
this
.
send
(
sendJson
.
toJSONString
());
}
if
(
receiveMsg
.
equals
(
"send_data"
))
{
JSONArray
array
=
new
JSONArray
();
array
.
add
(
"zh"
);
array
.
add
(
modelManage
.
getModelName
());
array
.
add
(
modelVersion
.
getModelUrl
());
JSONArray
ddJson
=
new
JSONArray
();
ddJson
.
add
(
"train_"
+
tuningRun
.
getRunId
());
array
.
add
(
ddJson
);
array
.
add
(
tuningRun
.
getTrainMethod
());
array
.
add
(
jsonObject
.
get
(
"promptTemplate"
));
array
.
add
(
2
);
//路径需要修改,暂时不能确定,后面一条线解决
String
newModelUrl
=
pythonConfig
.
getModelOutputFileBaseDir
()
+
(
runPublishDto
.
getPublishWay
()
==
0
?
runPublishDto
.
getModelName
()
:
modelManage
.
getModelName
());
array
.
add
(
newModelUrl
);
array
.
add
(
"none"
);
sendJson
.
put
(
"data"
,
array
);
sendJson
.
put
(
"event_data"
,
"null"
);
sendJson
.
put
(
"fn_index"
,
44
);
sendJson
.
put
(
"session_hash"
,
runPublishDto
.
getRunId
().
toString
());
System
.
out
.
println
(
array
.
toJSONString
());
log
.
info
(
"发送服务端的消息:"
+
sendJson
.
toJSONString
());
this
.
send
(
sendJson
.
toJSONString
());
}
if
(
receiveMsg
.
equals
(
"process_completed"
))
{
this
.
close
();
}
}
@Override
public
void
onClose
(
int
i
,
String
s
,
boolean
b
)
{
log
.
info
(
"关闭连接:::"
+
"i = "
+
i
+
":::s = "
+
s
+
":::b = "
+
b
);
}
@Override
public
void
onError
(
Exception
e
)
{
log
.
error
(
"报错了:::"
+
e
.
getMessage
());
}
}.
connect
();
if
(
runPublishDto
.
getPublishWay
()
==
0
)
{
ModelManage
modelManage
=
new
ModelManage
();
modelManage
.
setModelDescribe
(
runPublishDto
.
getModelDescribe
());
modelManage
.
setModelName
(
runPublishDto
.
getModelName
());
modelManage
.
setModelType
(
runPublishDto
.
getModelType
());
this
.
modelManageService
.
saveAndCreateVersion
(
modelManage
,
new
ModelVersion
());
ModelManage
modelManage
S
=
new
ModelManage
();
modelManage
S
.
setModelDescribe
(
runPublishDto
.
getModelDescribe
());
modelManage
S
.
setModelName
(
runPublishDto
.
getModelName
());
modelManage
S
.
setModelType
(
runPublishDto
.
getModelType
());
this
.
modelManageService
.
saveAndCreateVersion
(
modelManage
S
,
new
ModelVersion
());
}
else
{
ModelVersion
modelVersion
=
new
ModelVersion
();
modelVersion
.
setModelId
(
runPublishDto
.
getModelId
());
List
<
ModelVersion
>
modelVersionList
=
this
.
modelVersionService
.
getModelVersionList
(
modelVersion
,
"model_version"
);
ModelVersion
modelVersion
S
=
new
ModelVersion
();
modelVersion
S
.
setModelId
(
runPublishDto
.
getModelId
());
List
<
ModelVersion
>
modelVersionList
=
this
.
modelVersionService
.
getModelVersionList
(
modelVersion
S
,
"model_version"
);
int
lastModelVersion
=
modelVersionList
.
get
(
modelVersionList
.
size
()
-
1
).
getModelVersion
();
modelVersion
.
setModelId
(
runPublishDto
.
getModelId
());
modelVersion
.
setModelVersion
(
lastModelVersion
+
1
);
ModelManage
modelManage
=
this
.
modelManageService
.
getById
(
runPublishDto
.
getModelId
());
modelManage
.
setModelDescribe
(
runPublishDto
.
getModelDescribe
());
this
.
modelManageService
.
updateById
(
modelManage
);
this
.
modelVersionService
.
saveNew
(
modelVersion
);
modelVersion
S
.
setModelId
(
runPublishDto
.
getModelId
());
modelVersion
S
.
setModelVersion
(
lastModelVersion
+
1
);
ModelManage
modelManage
S
=
this
.
modelManageService
.
getById
(
runPublishDto
.
getModelId
());
modelManage
S
.
setModelDescribe
(
runPublishDto
.
getModelDescribe
());
this
.
modelManageService
.
updateById
(
modelManage
S
);
this
.
modelVersionService
.
saveNew
(
modelVersion
S
);
}
TuningRun
tuningRun
=
this
.
getById
(
runPublishDto
.
getRunId
());
tuningRun
.
setPublishStatus
(
1
);
return
this
.
updateById
(
tuningRun
);
}
...
...
application-webadmin/src/main/resources/application-dev.yml
View file @
b52fd131
...
...
@@ -65,12 +65,16 @@ python:
datasetFileBaseDir
:
/home/linking/llms/code/LLaMA-Factory-0.3.2/lmp_data/
#模型训练文件基础路径
modelTuningFileBaseDir
:
/home/linking/llms/code/LLaMA-Factory-0.3.2/saves/
#模型训练文件合并后路径
modelOutputFileBaseDir
:
/home/linking/llms/models/
#数据集配置信息
datasetInfo
:
dataset_info.json
#数据集配置目录
datasetFileMenu
:
lmp_data
#python平台通用接口地址
factoryInterface
:
http://192.168.0.36:7860/run/predict
#python websocket 服务地址
pythonWebsocketUri
:
ws://192.168.0.36:7860/queue/join
knowledge
:
#知识库通用接口地址
...
...
pom.xml
View file @
b52fd131
...
...
@@ -59,6 +59,15 @@
<groupId>
org.springframework.boot
</groupId>
<artifactId>
spring-boot-starter-webflux
</artifactId>
</dependency>
<dependency>
<groupId>
org.springframework.boot
</groupId>
<artifactId>
spring-boot-starter-websocket
</artifactId>
</dependency>
<dependency>
<groupId>
org.java-websocket
</groupId>
<artifactId>
Java-WebSocket
</artifactId>
<version>
1.5.4
</version>
</dependency>
<!-- freemarker 模板引擎模块 -->
<dependency>
<groupId>
org.springframework.boot
</groupId>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment