Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
L
lmp_server
Project
Project
Details
Activity
Releases
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
lmp
lmp_server
Commits
7297ac91
Commit
7297ac91
authored
Dec 08, 2023
by
linpeiqin
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
模型压缩逻辑初步提交
parent
f477a48f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
96 additions
and
26 deletions
+96
-26
ModelCompressService.java
...a/com/yice/webadmin/app/service/ModelCompressService.java
+3
-1
ModelCompressServiceImpl.java
...e/webadmin/app/service/impl/ModelCompressServiceImpl.java
+92
-22
TuningRunServiceImpl.java
.../yice/webadmin/app/service/impl/TuningRunServiceImpl.java
+1
-3
No files found.
application-webadmin/src/main/java/com/yice/webadmin/app/service/ModelCompressService.java
View file @
7297ac91
...
...
@@ -5,6 +5,7 @@ import com.yice.common.core.base.service.IBaseService;
import
com.yice.webadmin.app.model.ModelCompress
;
import
com.yice.webadmin.app.model.ModelTask
;
import
java.net.URISyntaxException
;
import
java.util.List
;
/**
...
...
@@ -37,7 +38,7 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
* @param relationData 全部关联从表数据。
* @return 返回新增主表对象。
*/
ModelCompress
saveNewWithRelation
(
ModelCompress
modelCompress
,
JSONObject
relationData
);
ModelCompress
saveNewWithRelation
(
ModelCompress
modelCompress
,
JSONObject
relationData
)
throws
URISyntaxException
;
/**
* 更新数据对象。
...
...
@@ -87,4 +88,5 @@ public interface ModelCompressService extends IBaseService<ModelCompress, Long>
* @return 查询结果集。
*/
List
<
ModelCompress
>
getModelCompressListWithRelation
(
ModelCompress
filter
,
ModelTask
modelTaskFilter
,
String
orderBy
);
}
application-webadmin/src/main/java/com/yice/webadmin/app/service/impl/ModelCompressServiceImpl.java
View file @
7297ac91
package
com
.
yice
.
webadmin
.
app
.
service
.
impl
;
import
cn.hutool.core.collection.CollUtil
;
import
com.alibaba.fastjson.JSON
;
import
com.alibaba.fastjson.JSONArray
;
import
com.alibaba.fastjson.JSONObject
;
import
com.baomidou.mybatisplus.core.conditions.update.UpdateWrapper
;
import
com.github.pagehelper.Page
;
...
...
@@ -9,20 +11,23 @@ import com.yice.common.core.base.service.BaseService;
import
com.yice.common.core.object.MyRelationParam
;
import
com.yice.common.core.util.MyModelUtil
;
import
com.yice.common.sequence.wrapper.IdGeneratorWrapper
;
import
com.yice.webadmin.app.config.PythonConfig
;
import
com.yice.webadmin.app.dao.ModelCompressMapper
;
import
com.yice.webadmin.app.model.ModelCompress
;
import
com.yice.webadmin.app.model.ModelManage
;
import
com.yice.webadmin.app.model.ModelTask
;
import
com.yice.webadmin.app.model.ModelVersion
;
import
com.yice.webadmin.app.model.*
;
import
com.yice.webadmin.app.service.ModelCompressService
;
import
com.yice.webadmin.app.service.ModelManageService
;
import
com.yice.webadmin.app.service.ModelTaskService
;
import
com.yice.webadmin.app.service.ModelVersionService
;
import
lombok.extern.slf4j.Slf4j
;
import
org.java_websocket.client.WebSocketClient
;
import
org.java_websocket.drafts.Draft_6455
;
import
org.java_websocket.handshake.ServerHandshake
;
import
org.springframework.beans.factory.annotation.Autowired
;
import
org.springframework.stereotype.Service
;
import
org.springframework.transaction.annotation.Transactional
;
import
java.net.URI
;
import
java.net.URISyntaxException
;
import
java.util.List
;
/**
...
...
@@ -45,6 +50,10 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
private
ModelVersionService
modelVersionService
;
@Autowired
private
ModelManageService
modelManageService
;
@Autowired
private
PythonConfig
pythonConfig
;
private
String
modelVersionURl
;
/**
* 返回当前Service的主表Mapper对象。
...
...
@@ -83,29 +92,89 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
}
}
@Transactional
(
rollbackFor
=
Exception
.
class
)
@Override
public
ModelCompress
saveNewWithRelation
(
ModelCompress
modelCompress
,
JSONObject
relationData
)
{
public
ModelCompress
saveNewWithRelation
(
ModelCompress
modelCompress
,
JSONObject
relationData
)
throws
URISyntaxException
{
this
.
saveNew
(
modelCompress
);
this
.
saveOrUpdateOneToOneRelationData
(
modelCompress
,
relationData
);
ModelVersion
modelVersion
=
this
.
modelVersionService
.
getById
(
modelCompress
.
getSourceVersionId
());
ModelVersion
modelVersionS
=
new
ModelVersion
();
modelVersionS
.
setTaskId
(
modelCompress
.
getTaskId
());
modelVersionS
.
setIsCompress
(
1
);
if
(
modelCompress
.
getCreateMethod
()
==
0
)
{
ModelVersion
modelVersion
=
new
ModelVersion
();
modelVersion
.
setModelId
(
modelCompress
.
getTargetModelId
());
List
<
ModelVersion
>
modelVersionList
=
this
.
modelVersionService
.
getModelVersionList
(
modelVersion
,
"model_version"
);
int
lastModelVersion
=
modelVersionList
.
get
(
modelVersionList
.
size
()
-
1
).
getModelVersion
();
modelVersion
.
setModelId
(
modelCompress
.
getTargetModelId
());
modelVersion
.
setIsCompress
(
1
);
modelVersion
.
setModelVersion
(
lastModelVersion
+
1
);
this
.
modelVersionService
.
saveNew
(
modelVersion
);
modelVersionS
.
setModelId
(
modelCompress
.
getTargetModelId
());
ModelManage
modelManageS
=
this
.
modelManageService
.
getById
(
modelCompress
.
getTargetModelId
());
modelManageS
.
setModelDescribe
(
modelCompress
.
getTaskDescribe
());
modelManageS
.
setModelType
(
0
);
modelManageS
.
setIsBaseModel
(
0
);
this
.
modelManageService
.
updateById
(
modelManageS
);
ModelVersion
modelVersionR
=
this
.
modelVersionService
.
saveNew
(
modelVersionS
);
modelVersionURl
=
pythonConfig
.
getModelOutputFileBaseDir
()
+
modelVersionR
.
getVersionName
();
modelVersionR
.
setModelUrl
(
modelVersionURl
);
this
.
modelVersionService
.
updateById
(
modelVersionR
);
}
else
{
ModelManage
modelManage
=
new
ModelManage
();
ModelVersion
modelVersion
=
new
ModelVersion
();
modelManage
.
setModelName
(
modelCompress
.
getTargetModelName
());
modelManage
.
setModelType
(
0
);
modelVersion
.
setIsCompress
(
1
);
modelManage
.
setModelDescribe
(
"模型压缩"
);
this
.
modelManageService
.
saveAndCreateVersion
(
modelManage
,
modelVersion
);
ModelManage
modelManageS
=
new
ModelManage
();
modelManageS
.
setModelName
(
modelCompress
.
getTargetModelName
());
modelManageS
.
setModelDescribe
(
modelCompress
.
getTaskDescribe
());
modelManageS
.
setModelType
(
0
);
modelManageS
.
setIsBaseModel
(
0
);
modelVersionURl
=
pythonConfig
.
getModelOutputFileBaseDir
()
+
modelCompress
.
getTargetModelName
()
+
"_V1"
;
modelVersionS
.
setModelUrl
(
modelVersionURl
);
this
.
modelManageService
.
saveAndCreateVersion
(
modelManageS
,
modelVersionS
);
}
this
.
saveNew
(
modelCompress
);
this
.
saveOrUpdateOneToOneRelationData
(
modelCompress
,
relationData
);
new
WebSocketClient
(
new
URI
(
this
.
pythonConfig
.
getPythonWebsocketUri
()),
new
Draft_6455
())
{
@Override
public
void
onOpen
(
ServerHandshake
serverHandshake
)
{
log
.
info
(
"-------------与大模型建立连接-------------"
);
}
@Override
public
void
onMessage
(
String
message
)
{
log
.
info
(
"收到来自服务端的消息:"
+
message
);
JSONObject
receiveJson
=
(
JSONObject
)
JSON
.
parse
(
message
);
String
receiveMsg
=
receiveJson
.
getString
(
"msg"
);
JSONObject
sendJson
=
new
JSONObject
();
if
(
receiveMsg
.
equals
(
"send_hash"
))
{
sendJson
.
put
(
"fn_index"
,
44
);
sendJson
.
put
(
"session_hash"
,
modelCompress
.
getTaskId
().
toString
());
log
.
info
(
"发送服务端的消息:"
+
message
);
System
.
out
.
println
(
sendJson
.
toJSONString
());
this
.
send
(
sendJson
.
toJSONString
());
}
if
(
receiveMsg
.
equals
(
"send_data"
))
{
JSONArray
array
=
new
JSONArray
();
array
.
add
(
"zh"
);
array
.
add
(
modelVersion
.
getVersionName
());
array
.
add
(
modelVersion
.
getModelUrl
());
array
.
add
(
new
JSONArray
());
array
.
add
(
""
);
array
.
add
(
""
);
array
.
add
(
2
);
array
.
add
(
modelVersionURl
);
array
.
add
(
"8"
);
sendJson
.
put
(
"data"
,
array
);
sendJson
.
put
(
"event_data"
,
"null"
);
sendJson
.
put
(
"fn_index"
,
44
);
sendJson
.
put
(
"session_hash"
,
modelCompress
.
getTaskId
().
toString
());
System
.
out
.
println
(
array
.
toJSONString
());
log
.
info
(
"发送服务端的消息:"
+
sendJson
.
toJSONString
());
this
.
send
(
sendJson
.
toJSONString
());
}
if
(
receiveMsg
.
equals
(
"process_completed"
))
{
this
.
close
();
}
}
@Override
public
void
onClose
(
int
i
,
String
s
,
boolean
b
)
{
log
.
info
(
"关闭连接:::"
+
"i = "
+
i
+
":::s = "
+
s
+
":::b = "
+
b
);
}
@Override
public
void
onError
(
Exception
e
)
{
log
.
error
(
"报错了:::"
+
e
.
getMessage
());
}
}.
connect
();
return
modelCompress
;
}
...
...
@@ -198,6 +267,7 @@ public class ModelCompressServiceImpl extends BaseService<ModelCompress, Long> i
return
resultList
;
}
private
ModelCompress
buildDefaultValue
(
ModelCompress
modelCompress
)
{
if
(
modelCompress
.
getTaskId
()
==
null
)
{
modelCompress
.
setTaskId
(
idGenerator
.
nextLongId
());
...
...
application-webadmin/src/main/java/com/yice/webadmin/app/service/impl/TuningRunServiceImpl.java
View file @
7297ac91
...
...
@@ -189,7 +189,6 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
public
boolean
publishToModelVersion
(
RunPublishDto
runPublishDto
)
{
TuningRun
tuningRun
=
this
.
getById
(
runPublishDto
.
getRunId
());
ModelVersion
modelVersion
=
this
.
modelVersionService
.
getById
(
tuningRun
.
getModelVersionId
());
ModelManage
modelManage
=
this
.
modelManageService
.
getById
(
modelVersion
.
getModelId
());
ModelVersion
modelVersionS
=
new
ModelVersion
();
modelVersionS
.
setRunId
(
tuningRun
.
getRunId
());
modelVersionS
.
setTaskId
(
tuningRun
.
getTaskId
());
...
...
@@ -243,8 +242,7 @@ public class TuningRunServiceImpl extends BaseService<TuningRun, Long> implement
array
.
add
(
tuningRun
.
getTrainMethod
());
array
.
add
(
jsonObject
.
get
(
"promptTemplate"
));
array
.
add
(
2
);
String
newModelUrl
=
modelVersionURl
;
array
.
add
(
newModelUrl
);
array
.
add
(
modelVersionURl
);
array
.
add
(
"none"
);
sendJson
.
put
(
"data"
,
array
);
sendJson
.
put
(
"event_data"
,
"null"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment