142 Commits

Author SHA1 Message Date
9db2f96557 更新 .gitea/workflows/prod_build_manual.yaml 2026-04-24 10:20:13 +08:00
zcr
cc2404831d Merge branch 'develop' 2026-04-17 14:15:57 +08:00
zcr
6892361050 修复design印花部分 overall 模式印花平铺起始从印花图片中心开始
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-15 17:36:29 +08:00
zcr
f0b73d5fc1 修复design印花部分 mask_inv_print 提取错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-15 17:23:00 +08:00
zcr
ea7522a45d Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/utils/synthesis_item.py
#	app/service/prompt_generation/chatgpt_for_translation.py
2026-04-14 10:18:04 +08:00
zcr
7543d6b346 feat: 更新flux2 klein 的输出示例 ; fix:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-14 10:16:30 +08:00
zcr
3ca4003e30 feat: 更新flux2 klein 的输出示例 ; fix: 2026-03-30 17:22:14 +08:00
zcr
59e8a88a01 feat: 更新flux2 klein 的输出示例 ; fix: 2026-03-30 17:14:18 +08:00
zcr
ac6f74438d feat: 更新分割模型参数 ; fix: 2026-03-27 15:10:16 +08:00
zcr
e0d519bfb3 feat: 更新分割模型参数 ; fix: 2026-03-27 15:10:16 +08:00
zcr
3414f2c1aa feat: 更新分割模型参数 ; fix: 2026-03-27 14:59:27 +08:00
zcr
160bf1a6b1 feat: 更新分割模型参数 ; fix: 2026-03-27 14:56:32 +08:00
zcr
79eb3fb859 feat: flux2 增加状态码 ; fix: 2026-03-25 23:17:03 +08:00
zcr
4395d67288 feat: flux2 增加状态码 ; fix: 2026-03-25 23:17:02 +08:00
zcr
674514ec11 feat: brand dna logo生成替换flux2klein ; fix: 2026-03-25 23:16:53 +08:00
zcr
e9ca1d301b feat: 新增flux2klein作为moodboard的localbase 模型 ; fix: 2026-03-25 23:16:48 +08:00
zcr
a4d55fdb14 feat: flux2 增加状态码 ; fix: 2026-03-25 10:29:03 +08:00
zcr
7f2f79d029 feat: flux2 增加状态码 ; fix: 2026-03-24 14:35:39 +08:00
zcr
6d9e96305b feat: brand dna logo生成替换flux2klein ; fix: 2026-03-23 11:21:50 +08:00
zcr
d93c50ce2b feat: 新增flux2klein作为moodboard的localbase 模型 ; fix: 2026-03-23 10:46:16 +08:00
zcr
316c2fef67 feat:
fix: 删除计数中间件
2026-03-13 11:22:57 +08:00
zcr
e25f49a776 feat:
fix: 删除计数中间件
2026-03-13 11:22:12 +08:00
zcr
33b4dd4a7f feat:
fix: 翻译 模型ip更换
2026-03-05 15:20:40 +08:00
zcr
ac8ca4dd46 feat:
fix: 翻译 模型ip更换
2026-03-05 15:17:47 +08:00
zcr
db88d9b813 feat:
fix: 翻译 模型ip更换
2026-03-05 15:13:40 +08:00
zcr
6ea9837f83 feat:
fix: sam 模型ip更换
2026-03-05 15:06:27 +08:00
zcr
7e48420ba7 feat:
fix: sam 模型ip更换
2026-03-05 15:06:19 +08:00
zcr
13002eefda feat:
fix:  others 旋转功能修复
2026-03-05 14:02:01 +08:00
zcr
09e25f423e feat:
fix:  others 旋转功能修复
2026-03-05 14:01:29 +08:00
zcr
bcc82ba065 feat:
fix:  替换项目中所有mmcv的依赖
2026-02-27 15:45:28 +08:00
zcr
dcc88adfc0 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  替换项目中所有mmcv的依赖
2026-02-27 15:26:07 +08:00
zcr
e4fd7b2fb9 feat:
fix:  替换项目中所有mmcv的依赖
2026-02-10 13:04:14 +08:00
ba93d33a17 更新 .gitea/workflows/prod_build_scheduled.yaml 2026-02-10 11:44:59 +08:00
292da1de2b 更新 .gitea/workflows/prod_build_manual.yaml 2026-02-10 11:44:48 +08:00
c6ebfae942 更新 .gitea/workflows/ltx_develop_build_manual.yaml 2026-02-10 11:44:36 +08:00
4dd8416911 更新 .gitea/workflows/develop_build_scheduled.yaml 2026-02-10 11:44:27 +08:00
zcr
bafcb68028 feat:
fix:  替换项目中所有mmcv的依赖
2026-02-10 11:34:09 +08:00
zcr
c03b7e263e feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  替换项目中所有mmcv的依赖
2026-02-10 11:17:31 +08:00
zcr
200414e5ad feat: 停用flux2 img2product 复用sdxl img2product
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-02-09 17:33:07 +08:00
23a6a30cc4 更新 .gitea/workflows/develop_build_manual.yaml 2026-02-09 15:28:47 +08:00
4d0688afd5 删除 .gitea/workflows/develop_build_commit.yaml 2026-02-09 15:28:22 +08:00
zcr
9a00fce0eb feat: 印花逻辑修改 默认不处理除overall以外所有印花类型
fix:
2026-02-03 16:44:01 +08:00
zcr
4656eeee91 feat: 印花逻辑修改 默认不处理除overall以外所有印花类型
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-02-03 16:43:33 +08:00
zcr
f017d7e212 feat:
fix: 修复sketch类型为others时 跳过 上印花 导致的尺寸与分割尺寸不一致问题, 修复others分割出后片的问题
2026-02-03 16:23:05 +08:00
zcr
fe25f5878b feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 修复sketch类型为others时 跳过 上印花 导致的尺寸与分割尺寸不一致问题, 修复others分割出后片的问题
2026-02-03 16:22:47 +08:00
zcr
c1b80c58f1 feat:
fix: 队列名修复
2026-02-02 15:37:31 +08:00
zcr
2cc17a1210 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 队列名修复
2026-02-02 15:37:01 +08:00
zcr
be92d48abb feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 回溯镜像旋转逻辑
2026-01-30 15:45:57 +08:00
zcr
57be559cf2 feat:
fix:  修复类别为other时出现的pipeline item缺失
2026-01-29 16:26:16 +08:00
zcr
f8382f280f feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  修复类别为other时出现的pipeline item缺失
2026-01-29 16:25:43 +08:00
zcr
c24862507f feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  slogan 服务迁移
2026-01-28 15:37:03 +08:00
zcr
d5452098f3 feat:
fix: 移除打印
2026-01-27 13:53:21 +08:00
zcr
315e298ba8 feat:
fix:
2026-01-27 13:53:17 +08:00
zcr
ec26c8b507 feat:
fix:  印花overall 角度异常
2026-01-27 13:53:04 +08:00
zcr
e02ca351b6 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  印花overall 角度异常
2026-01-27 13:42:34 +08:00
zcr
c987f498bc feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-27 11:28:36 +08:00
zcr
3aa8dfa0f4 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 移除打印
2026-01-27 10:12:23 +08:00
zcr
265f4de50e feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 更新端口
2026-01-26 16:32:30 +08:00
zcr
a996a1853d feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 更新端口
2026-01-26 16:11:10 +08:00
zcr
1cbd019ffd feat: 更新翻译模型
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:56:42 +08:00
zcr
e2a49e2f3a feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:26:15 +08:00
zcr
66037c94e6 feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:23:49 +08:00
zcr
754e8d7735 feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:21:51 +08:00
zcr
cdaeb6daac feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:19:28 +08:00
zcr
863d9287dc fix: 参数对齐
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
(cherry picked from commit ddef6af1cf)
2026-01-26 14:56:49 +08:00
zcr
ddef6af1cf fix: 参数对齐 2026-01-26 14:49:57 +08:00
zcr
fdffb1e724 Merge branch 'develop' 2026-01-24 22:05:57 +08:00
zcr
ecf10611c2 fix: merge 模式下 镜像和旋转功能与前端对其
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 14:43:10 +08:00
zcr
f78809b22a fix: merge 模式下 镜像和旋转功能与前端对其
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 14:03:35 +08:00
63a2f5e007 更新 .gitea/workflows/prod_build_scheduled.yaml 2026-01-24 03:20:19 +08:00
aeb67f366a 更新 .gitea/workflows/prod_build_manual.yaml 2026-01-24 03:20:03 +08:00
zcr
c244e313ae Merge branch 'develop' 2026-01-24 03:14:24 +08:00
zcr
15934085e0 fix: 修复design merge 模式 ,旋转sketch位置计算错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 03:03:26 +08:00
zcr
40b41d02a4 fix: 修复design merge 模式 ,旋转sketch位置计算错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 03:01:34 +08:00
1a1fd46f81 添加 .gitea/workflows/prod_build_manual.yaml 2026-01-24 02:55:56 +08:00
zcr
dcd8e26f0f fix: 修复design merge 模式 ,旋转sketch位置计算错误 2026-01-24 02:46:52 +08:00
zcr
fd94a3b4f0 Merge branch 'develop' 2026-01-24 02:44:51 +08:00
zcr
682c589238 fix: 修复design merge 模式 ,旋转sketch位置计算错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 02:44:13 +08:00
3f9309235a 2026.01.23 生产部署
All checks were successful
定时 AiDA python prod 分支构建部署 / scheduled_deploy (push) Successful in 2m20s
2026-01-23 21:06:19 +08:00
zcr
a578aa4fc5 暂时移除design 缓存
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-23 18:09:21 +08:00
zcr
ebd665b241 Merge branch 'develop'
# Conflicts:
#	app/core/config.py
2026-01-23 17:37:49 +08:00
zcr
ec649152e3 移除keypoint 缓存
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-23 17:34:51 +08:00
833e1bc924 暂时停用定时部署 2026-01-23 10:41:43 +08:00
zcr
7ed5911336 服务迁移测试
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-22 13:41:47 +08:00
zcr
b09538e294 feat: 新增design模式 merge,回参增加mask
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-15 14:13:56 +08:00
zcr
313863a6a7 fix: design 预处理 读取四通道图片背景变黑问题
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 15:36:28 +08:00
zcr
9ca1a2ba1f fix: design 单品未传design_type
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 14:58:31 +08:00
litianxiang
fb46a9521d Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 13:57:28 +08:00
litianxiang
b90688f835 更改增量更新日志级别 2026-01-13 13:57:15 +08:00
zcr
7e30779aec feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:43:30 +08:00
zcr
f7294f5966 feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:32:18 +08:00
zcr
0ac5a4e0a8 Merge remote-tracking branch 'origin/develop' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 16:18:15 +08:00
zcr
40b57b749c feat: 新增design模式 merge,前端CV python 合成 2026-01-12 16:18:04 +08:00
litianxiang
b8a538a8a1 fix:增量更新向量问题修改
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:59:06 +08:00
litianxiang
29b4f43a27 debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:34:56 +08:00
litianxiang
69dc20207d debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:03:58 +08:00
litianxiang
18979af604 debug:推荐接口返回redis值
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:01:26 +08:00
litianxiang
74406f9be4 推荐接口更新向量接口注册
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:59:01 +08:00
litianxiang
df99e3ac76 新增查看redis内容接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:51:37 +08:00
litianxiang
19346c2eb7 Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 09:51:52 +08:00
litianxiang
2af9cbfe78 fix:推荐接口 2026-01-12 09:49:07 +08:00
zcr
fe12b5697d fix: design 镜像默认值修改,旋转方向和前端保持一致
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:40:49 +08:00
zcr
c04d4877b0 fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:12:53 +08:00
zcr
91016e6cae fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:08:16 +08:00
zcr
0f4bb260ad fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:06:39 +08:00
zcr
c792106f02 fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 15:42:42 +08:00
zcr
deac5a4cab fix: design item sketch旋转参数为none
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 12:31:34 +08:00
zcr
15682036b3 feat : 新增seg anything 接口 ,接口文档补充
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 17:39:27 +08:00
zcr
9ba3a0ca49 feat : 新增seg anything 接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 17:33:54 +08:00
zcr
f6963070fb feat : 支持上下左右同时镜像
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 13:47:44 +08:00
zcr
12f5ca3ca3 feat : design 示例说明
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 10:44:02 +08:00
zcr
19110f51bf feat : design 示例说明
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 10:29:31 +08:00
2b7e4013ee 更新 .gitea/workflows/develop_build_manual.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 2m14s
2026-01-08 10:23:45 +08:00
zcr
e04636ce21 feat : design overall print 新增平铺间距和旋转角度
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-07 17:03:02 +08:00
zcr
2a50e7040e feat : design overall print 新增平铺间距和旋转角度
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-07 16:22:19 +08:00
zcr
a6f3bda9f7 feat : design 单品新增 镜像旋转功能
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-06 12:21:10 +08:00
zcr
c18f45e549 feat : design 单品新增 镜像旋转功能
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-06 12:00:58 +08:00
zcr
4951fab71a 代码整理
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:49:22 +08:00
zcr
aa57478852 新推荐接口first commit
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:35:32 +08:00
zcr
2a6c48d937 新推荐接口first commit
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:23:36 +08:00
litianxiang
fed3fcdf85 新推荐接口first commit 2025-12-30 17:18:12 +08:00
5b2bb3ce7c 添加 .gitea/workflows/ltx_develop_build_manual.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 2m15s
2025-12-29 11:01:49 +08:00
6739a92d28 2025.12.19 生产部署
All checks were successful
定时 AiDA python prod 分支构建部署 / scheduled_deploy (push) Successful in 22s
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 28s
2025-12-19 17:47:54 +08:00
f23b99d326 更新 .gitea/workflows/prod_build_scheduled.yaml 2025-12-19 17:46:35 +08:00
10d41cd32f 新增生产部署actions文件 2025-12-19 17:46:03 +08:00
zcr
bb7b85bfb8 Merge branch 'develop'
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 15s
# Conflicts:
#	.gitea/workflows/develop_build_scheduled.yaml
#	app/service/design_fast/design_generate.py
2025-12-16 14:37:05 +08:00
6ecb6be59c 更新 .gitea/workflows/develop_build_scheduled.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 17s
2025-11-28 17:42:22 +08:00
64285cd5f3 更新 .gitea/workflows/develop_build_scheduled.yaml 2025-11-28 17:41:22 +08:00
fe6a5fb029 更新 .gitea/workflows/develop_build_scheduled.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 13s
2025-11-28 17:31:29 +08:00
5217847d49 上传文件至「.gitea/workflows」
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 14s
2025-11-28 17:23:07 +08:00
0a9fc51310 更新 .gitea/workflows/develop_build_commit.yaml 2025-11-28 17:19:47 +08:00
cf052f9632 上传文件至「.gitea/workflows」 2025-11-28 17:18:52 +08:00
19a8ea9a93 更新 .gitea/workflows/develop_build_commit.yaml 2025-11-28 17:11:35 +08:00
09ff2f1ab7 更新 .gitea/workflows/develop_build_commit.yaml 2025-11-28 17:10:04 +08:00
109a23197a 上传文件至「.gitea/workflows」 2025-11-28 17:02:38 +08:00
zhh
2135a180be feat(新功能):
fix(修复bug):  删除java端debug callback api url
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-21 23:17:53 +08:00
zhh
09032c0564 Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/design_generate.py
2025-11-21 23:01:34 +08:00
zhh
167faa10c8 feat(新功能): fix(修复bug): 取消design-v2 java端测试接口(重构): test(增加测试): 2025-09-27 00:02:23 +08:00
zhh
0a048bf37f Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/design_generate.py
2025-09-26 23:31:42 +08:00
zhh
05045dda76 feat(新功能): fix(修复bug): : refactor(重构): test(增加测试): 徐佩design测试 2025-09-23 11:38:33 +08:00
zhh
30f9a99df2 Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/pipeline/split.py
2025-09-22 17:56:18 +08:00
zhh
3932b8359a feat(新功能):
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):  mask 使用原尺寸测试
2025-09-17 16:43:26 +08:00
61 changed files with 4127 additions and 1561 deletions

View File

@@ -1,2 +1,6 @@
seg_cache
test
.venv
__pycache__/
*.pyc
.git/

View File

@@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Dev
steps:
- name: 1.检出代码
@@ -35,6 +35,4 @@ jobs:
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d --build --remove-orphans 2>&1
docker image prune -f 2>&1
docker-compose up -d 2>&1

View File

@@ -1,15 +1,15 @@
name: 定时 AiDA python develop 分支构建部署
on:
# 使用 schedule 触发器,遵循标准的 Cron 格式 (分钟 小时-8 日期 月份 星期)
schedule:
- cron: '30 9 * * *'
# schedule:
# - cron: '30 9 * * *'
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Dev
steps:
- name: 1.检出代码

View File

@@ -1,23 +1,19 @@
name: git commit AiDA python develop 分支构建部署
name: 手动 AiDA python develop 分支构建部署
on:
workflow_dispatch:
push:
branches:
- develop
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
if: "contains(github.event.head_commit.message, '[run build]')"
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Dev
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'develop'
ref: 'dev-ltx'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
@@ -28,7 +24,7 @@ jobs:
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
- name: 3.重启docker-compose
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}

View File

@@ -0,0 +1,40 @@
name: 定时 AiDA python prod 分支构建部署
on:
workflow_dispatch:
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/AiDA_Prod
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'master'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
script: |
# 进入项目目录
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d 2>&1
docker image prune -f 2>&1

View File

@@ -0,0 +1,42 @@
name: 定时 AiDA python prod 分支构建部署
on:
# 使用 schedule 触发器,遵循标准的 Cron 格式 (分钟 小时-8 日期 月份 星期)
schedule:
- cron: '07 13 23 1 *'
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Prod
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'master'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
script: |
# 进入项目目录
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d 2>&1
docker image prune -f 2>&1

View File

@@ -20,7 +20,6 @@
$ conda activate trinity_client_aida
$ pip install -r requirements.txt
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
1. 启动服务器

View File

@@ -1,25 +1,34 @@
import io
import logging
import os
import sys
import time
from typing import List
from collections import defaultdict
import numpy as np
import pymysql
import torch
from PIL import Image
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from fastapi.responses import JSONResponse
from minio import Minio
from torchvision import models, transforms
from app.core.mysql_config import DB_CONFIG
from app.core.new_config import settings
import pymysql
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
from minio import Minio
import torch
from torchvision import models, transforms
from PIL import Image
import os
from fastapi.responses import JSONResponse
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
# MinIO 配置
minio_client = Minio(
"www.minio.aida.com.hk:12024",
access_key="admin",
secret_key="Aidlab123123!",
secure=True
)
transform = transforms.Compose([
transforms.Resize((224, 224)),
@@ -58,8 +67,8 @@ def extract_feature_vector_from_resnet(sketch_path: str) -> np.ndarray:
# 预加载
BRAND_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
BRAND_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}brand_feature.npy', allow_pickle=True).item()
SYSTEM_FEATURES = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_feature_dict.npy', allow_pickle=True).item()
def save_sketch_to_iid():
@@ -67,11 +76,11 @@ def save_sketch_to_iid():
sketch_path: iid
for iid, sketch_path in enumerate(SYSTEM_FEATURES.keys(), start=1)
}
np.save(f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
np.save(f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy", sketch_to_iid)
def load_sketch_to_iid():
path = f"{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
path = f"{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy"
if os.path.exists(path):
return np.load(path, allow_pickle=True).item()
save_sketch_to_iid()
@@ -81,7 +90,7 @@ def load_sketch_to_iid():
sketch_to_iid = load_sketch_to_iid()
def get_new_category(gender: str, sketch_category: str) -> str:
def getNewCategory(gender: str, sketch_category: str) -> str:
return f"{gender.lower()}_{sketch_category.lower()}"
@@ -94,8 +103,8 @@ def get_category_from_path(path: str) -> str:
def load_brand_matrix():
"""单独加载 brand_matrix 和 brand_index_map"""
mat_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_matrix.npy"
idx_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
mat_path = f"{RECOMMEND_PATH_PREFIX}brand_matrix.npy"
idx_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
try:
matrix = np.load(mat_path)
index_map = np.load(idx_path, allow_pickle=True).item()
@@ -104,19 +113,11 @@ def load_brand_matrix():
index_map = {}
return matrix, index_map
def cosine_similarity(vec1, vec2):
"""计算余弦相似度(增加零值处理)"""
norm = np.linalg.norm(vec1) * np.linalg.norm(vec2)
return np.dot(vec1, vec2) / (norm + 1e-10) if norm != 0 else 0.0
def getNewCategory(gender, sketch_category):
print(gender)
print(sketch_category)
return "None"
def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
# 1. 收集品牌-分类-特征
brand_feature = defaultdict(lambda: defaultdict(list))
@@ -163,11 +164,11 @@ def calculate_brand_matrix(sketch_data, brand_id: int) -> np.ndarray:
brand_matrix[row_idx, sketch_index[iid]] = cos_sim
# 7. 持久化
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
np.save(f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
np.save(f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy", brand_matrix)
np.save(f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy", brand_index_map)
# 返回该品牌对应行
return brand_matrix[row_idx:row_idx + 1]
return brand_matrix[row_idx:row_idx+1]
@router.get("/brand_dna_initialize/{brand_id}")
@@ -179,9 +180,11 @@ async def brand_dna_initialize(brand_id: int):
cursor.execute("""
SELECT id, img_url, gender, category
FROM product_image_attribute
WHERE library_id IN (SELECT library_id
WHERE library_id IN (
SELECT library_id
FROM brand_rel_library
WHERE brand_id = %s)
WHERE brand_id = %s
)
""", (brand_id,))
sketch_data = cursor.fetchall()

View File

@@ -1,9 +1,11 @@
import json
import logging
import requests
from fastapi import APIRouter, HTTPException, BackgroundTasks
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel
from app.core.config import settings
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel, SAMRequestModel
from app.schemas.response_template import ResponseModel
from app.service.design_fast.design_generate import design_generate, design_generate_v2
from app.service.design_fast.model_process_service import model_transpose
@@ -15,16 +17,29 @@ logger = logging.getLogger()
@router.post("/design")
def design(request_data: DesignModel):
"""
objects.items.transparent:
- **objects.items.transparent**:
```json
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
mask_url 为空"" -> 单件衣服透明
mask_url"mask_url" -> 区域透明
```
- **mask_url** 为"" -> 单件衣服透明
- **mask_url** 非空"mask_url" -> 区域透明
- **transpose** 镜像模式 ,:"top_bottom""left_right"
- **rotate** 45,
创建一个具有以下参数的请求体:
- ** design 参数变更:
design detail 请求参数中 basic -> preview_submit 替换为design_type 可选参数 default ,merge (移除preview和submit)
design_type 参数说明:
defuault模式下 请求参数不变
merge模式下 items -> 每个item需要新增 merge_image_path , merge_image_path为前端处理 print color等操作后的单件结果图
**
- 创建一个具有以下参数的请求体:
示例参数:
```json
{
"objects": [
{
@@ -56,7 +71,7 @@ def design(request_data: DesignModel):
]
},
"layer_order": true,
"preview_submit": "submit",
"design_type": "preview",
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
@@ -65,14 +80,19 @@ def design(request_data: DesignModel):
},
"items": [
{
"businessId": 2377945,
"color": "209 196 171",
"image_id": 189410,
"businessId": 2115382,
"color": "",
"image_id": 61686,
"offset": [
0,
0
],
"path": "aida-collection-element/89/Sketchboard/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png",
"path": "aida-sys-image/images/female/dress/0628000564.jpg",
"transpose": [
1,
1
],
"rotate": 45,
"print": {
"element": {
"element_angle_list": [],
@@ -81,85 +101,30 @@ def design(request_data: DesignModel):
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 12,
"resize_scale": [
1.0,
1.0
"location": [
[
53.0,
118.5
]
],
"seg_mask_url": "aida-clothing/mask/mask_8e96ddb0-e466-11f0-8de2-0242ac130002.png",
"type": "Outwear"
},
{
"businessId": 2377946,
"color": "122 152 139",
"image_id": 81868,
"offset": [
0,
0
"print_angle_list": [
0.0
],
"path": "aida-sys-image/images/female/blouse/0825001443.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 11,
"resize_scale": [
1.0,
1.0
"print_path_list": [
"aida-users/89/print/02d57aa8-f342-4e1d-b02c-b278f94dcfe6-3-89.png"
],
"seg_mask_url": "aida-clothing/mask/mask_8f0fab78-e466-11f0-8de2-0242ac130002.png",
"type": "Blouse"
},
{
"businessId": 2377947,
"color": "111 78 63",
"gradient": "aida-gradient/517c3a4d-aed7-4423-aa99-7b60d3577df1.png",
"image_id": 116494,
"offset": [
0,
0
"print_scale_list": [
[
0.5,
0.5
]
],
"path": "aida-sys-image/images/female/skirt/0825000219.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
"gap": [
[
10,
10
]
]
},
"single": {
"location": [],
@@ -173,8 +138,8 @@ def design(request_data: DesignModel):
1.0,
1.0
],
"seg_mask_url": "aida-clothing/mask/mask_8f6191fe-e466-11f0-8de2-0242ac130002.png",
"type": "Skirt"
"seg_mask_url": "aida-clothing/mask/mask_9698b428-eb93-11f0-9327-0242c0a80003.png",
"type": "Dress"
},
{
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
@@ -186,6 +151,7 @@ def design(request_data: DesignModel):
],
"process_id": "89"
}
```
"""
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
# data = generate(request_data=request_data)
@@ -421,6 +387,55 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
return ResponseModel()
@router.post("/seg_anything")
async def seg_anything(request_data: SAMRequestModel):
"""
**Segment Anything 交互式分割接口**
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
### 参数说明:
- **bucket**: minio bucket name
- **object_name**: minio object name
- **image_path**: 图片在服务器或云端的相对路径。
- **type**: 推理类型
- **box**: 框选矩形点位信息
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
- 1: **前景点** (代表想要分割出的区域)
- 0: **背景点** (代表想要排除的区域)
### 请求体示例:
```json
point
{
"bucket": "test",
"object_name": "7068-400a-ac94-c01647fa5f6f.png",
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"point",
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
"labels": [1, 1, 0, 1]
}
box
{
"bucket": "test",
"object_name": "7068-400a-ac94-c01647fa5f6f.png",
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"box",
"box": [350, 286, 544, 520]
}
```
"""
try:
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = requests.post(f"http://{settings.B_4_X_4090_SERVICE_HOST}:10075/predict", json=request_data.dict())
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
return ResponseModel(data=json.loads(data.content))
except Exception as e:
logger.warning(f"seg_anything Run Exception @@@@@@:{e}")
# @router.post('/get_progress')
# def get_progress(request_data: DesignProgressModel):
# """

View File

@@ -1,9 +1,12 @@
import json
import logging
import httpx
import requests
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel
from app.core.config import settings
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel, Flux2ToProductImgModel, GenerateSloganImageModel, GenerateImageFlux2KleinModel
from app.schemas.pose_transform import BatchPoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
@@ -20,6 +23,61 @@ logger = logging.getLogger()
'''generate image'''
# flux2 klein
@router.post("/generate_image_flux2_klein")
async def generate_image_flux2_klein(request_item: GenerateImageFlux2KleinModel):
"""
创建一个具有以下参数的请求体:
- **bucket_name**: OSS桶名 (必填)
- **object_name**: OSS对象名文件路径(必填)
- **width**: 图片宽度默认1024像素 (非必填,1024)
- **height**: 图片高度默认1024像素 (非必填,默认1024)
- **prompt**: 文本提示词,用于模型推理等场景 (非必填,默认"")
- **steps**: 推理步数,控制模型生成过程的迭代次数 (非必填,默认4)
- **guidance**: 引导系数,调节提示词对生成结果的影响程度 (非必填,默认 4.0 )
### 示例参数:
```
{
"bucket_name": "aida-users",
"object_name": "89/moodboard/5fdc698c-cb9b-4b36-afa9ce4-1-89.png",
"prompt": "a single item of sketch of dress, 4k, white background"
}
```
### 输出示例:
```
{
"code": 200,
"msg": "OK!",
"data": {
"output_path": "aida-users/89/moodboard/5fdc698c-cb9b-4b36-afa9ce4-1-89.png"
}
}
```
"""
try:
logger.info(f"generate_image_flux2_gen_img request: {json.dumps(request_item.model_dump(), indent=4)}")
async with httpx.AsyncClient(timeout=120) as client:
resp = await client.post(
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
json=request_item.model_dump(),
)
if resp.status_code == 200:
result = resp.json()
logger.info(f"flux2_gen_img response: {json.dumps(result, indent=4)}")
return ResponseModel(data=result)
else:
error = resp.json()
logger.info(f"flux2_gen_img response: {json.dumps(error, indent=4)}")
return ResponseModel(data=error, msg="ERROR!", code=500)
except Exception as e:
logger.warning(f"generate_image_flux2_gen_img Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
# sdxl
@router.post("/generate_image")
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
"""
@@ -154,6 +212,62 @@ def generate_single_logo_image(tasks_id: str):
return ResponseModel(data=data['data'])
"""slogan """
@router.post("/generate_slogan")
async def generate_slogan(request_data: GenerateSloganImageModel):
"""
### 请求体示例:
```json
{
"num_point": 16,
"image_url": "aida-slogan/6886785f-0aac-4052-b6fd-7ae20a841d8d.png",
"prompt": "123",
"tasks_id": "string-89"
}
```
"""
try:
logger.info(f"generate_slogan request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = requests.post(f"http://{settings.A6000_SERVICE_HOST}:10020/api/slogan", json=request_data.dict())
logger.info(f"generate_slogan response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
return ResponseModel(data=json.loads(data.content))
except Exception as e:
logger.warning(f"generate_slogan Run Exception @@@@@@:{e}")
"""product image flux2.0"""
# @router.post("/img_to_product")
# async def img_to_product(request_data: Flux2ToProductImgModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_path**: 被生成图片的S3或minio url地址
# - **infer_step**: 推理步数
#
# ### 请求体示例:
# ```json
# point
# {
# "prompt": "Create realistic studio photo with real people model standing and wearing this garment, in white studio, Keep original model if present, or generate appropriate model, Standing pose, facing camera.",
# "image_path":"aida-results/result_38151e0a-f83b-11f0-89f6-0242ac130002.png",
# "infer_step":4,
# "tasks_id":"123456-123"
# }
# ```
# """
# try:
# logger.info(f"img_to_product request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
# data = requests.post(f"http://{settings.A6000_SERVICE_HOST}:10090/api/v1/to_product", json=request_data.dict())
# logger.info(f"img_to_product response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
# return ResponseModel(data=json.loads(data.content))
# except Exception as e:
# logger.warning(f"img_to_product Run Exception @@@@@@:{e}")
'''product image'''
@@ -178,7 +292,7 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
}
"""
try:
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:

View File

@@ -0,0 +1,116 @@
import logging
import sys
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from concurrent.futures import ThreadPoolExecutor
import threading
from app.schemas.response_template import ResponseModel
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
logger = logging.getLogger()
router = APIRouter()
# 使用线程池执行器来运行长时间任务
executor = ThreadPoolExecutor(max_workers=1)
# 用于跟踪任务状态
task_status = {"running": False}
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
"""在后台线程中运行导入任务"""
original_argv = None
try:
task_status["running"] = True
# 保存原始 sys.argv
original_argv = sys.argv.copy()
# 模拟命令行参数
sys.argv = [
"import_sys_sketch_to_milvus.py",
"--batch-size", str(batch_size),
"--retry-times", str(retry_times),
]
if limit is not None:
sys.argv.extend(["--limit", str(limit)])
if offset > 0:
sys.argv.extend(["--offset", str(offset)])
if skip_create_collection:
sys.argv.append("--skip-create-collection")
import_main()
task_status["running"] = False
logger.info("导入任务完成")
except Exception as e:
task_status["running"] = False
logger.error(f"导入任务失败: {e}", exc_info=True)
raise
finally:
# 恢复原始 sys.argv
if original_argv is not None:
sys.argv = original_argv
@router.post("/import-sys-sketch", response_model=ResponseModel)
async def import_sys_sketch(
batch_size: int = Query(1000, description="批量处理大小默认1000"),
retry_times: int = Query(3, description="失败重试次数默认3"),
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
offset: int = Query(0, description="起始偏移量默认0"),
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
):
"""
从 t_sys_file 导入系统图向量到 Milvus
该接口会异步执行导入任务,任务在后台运行。
"""
try:
# 检查是否有任务正在运行
if task_status["running"]:
raise HTTPException(
status_code=409,
detail="已有导入任务正在运行,请等待完成后再试"
)
# 在后台线程中执行任务
executor.submit(
run_import_task,
batch_size,
retry_times,
limit,
offset,
skip_create_collection
)
return ResponseModel(
code=200,
msg="导入任务已启动,正在后台执行",
data={
"status": "started",
"batch_size": batch_size,
"retry_times": retry_times,
"limit": limit,
"offset": offset,
"skip_create_collection": skip_create_collection
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"启动导入任务失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
async def get_import_status():
"""
获取导入任务状态
"""
return ResponseModel(
code=200,
msg="OK",
data={
"running": task_status["running"]
}
)

85
app/api/api_precompute.py Normal file
View File

@@ -0,0 +1,85 @@
import logging
from fastapi import APIRouter, HTTPException
from concurrent.futures import ThreadPoolExecutor
from app.schemas.response_template import ResponseModel
from app.service.recommendation_system.precompute import run_precompute
logger = logging.getLogger()
router = APIRouter()
# 使用线程池执行器来运行长时间任务
executor = ThreadPoolExecutor(max_workers=1)
# 用于跟踪任务状态
task_status = {"running": False}
def run_precompute_task():
"""在后台线程中运行预计算任务"""
try:
task_status["running"] = True
logger.info("开始执行预计算任务...")
run_precompute()
task_status["running"] = False
logger.info("预计算任务完成")
except Exception as e:
task_status["running"] = False
logger.error(f"预计算任务失败: {e}", exc_info=True)
raise
@router.post("/precompute", response_model=ResponseModel)
async def precompute():
"""
运行预计算任务
该接口会异步执行预计算任务,包括:
1. 优化数据库表结构
2. 历史数据迁移
3. 初始用户偏好向量生成
任务在后台运行。
"""
try:
# 检查是否有任务正在运行
if task_status["running"]:
raise HTTPException(
status_code=409,
detail="已有预计算任务正在运行,请等待完成后再试"
)
# 在后台线程中执行任务
executor.submit(run_precompute_task)
return ResponseModel(
code=200,
msg="预计算任务已启动,正在后台执行",
data={
"status": "started",
"tasks": [
"优化数据库表结构",
"历史数据迁移",
"初始用户偏好向量生成"
]
}
)
except HTTPException:
raise
except Exception as e:
logger.error(f"启动预计算任务失败: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
@router.get("/precompute/status", response_model=ResponseModel)
async def get_precompute_status():
"""
获取预计算任务状态
"""
return ResponseModel(
code=200,
msg="OK",
data={
"running": task_status["running"]
}
)

View File

@@ -1,206 +1,206 @@
import io
import logging
import math
import sys
import time
from typing import List
import numpy as np
from typing import List, Optional
from fastapi import HTTPException, APIRouter, Query
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
from app.service.recommendation_system.incremental_listener import start_background_listener
from app.service.recommendation_system.milvus_client import create_collection
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
# """
# :param user_id: 4
# :param category: female_skirt
# :param num_recommendations: 1
# :return:
# [
# "aida-sys-image/images/female/skirt/903000017.jpg"
# ]
# """
# try:
# start_time = time.time()
# cache_key = (user_id, category)
# # === 新增:用户存在性检查 ===
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
# user_exists_feat = user_id in matrix_data["user_index_feature"]
#
# # 任一矩阵不存在用户则返回随机推荐
# if not (user_exists_inter and user_exists_feat):
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
# return get_random_recommendations(category, num_recommendations)
#
# # 检查缓存
# if cache_key in matrix_data["cached_scores"]:
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
# else:
# # 实时计算逻辑(同原代码)
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
#
# category_iids = matrix_data["category_to_iids"].get(category, [])
# valid_sketch_idxs_inter = [
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
# if iid in category_iids
# ]
#
# # 处理交互分数
# raw_inter_scores = []
# if user_idx_inter is not None and valid_sketch_idxs_inter:
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
# processed_inter = raw_inter_scores * 0.7
#
# # 处理特征分数
# valid_sketch_idxs_feature = [
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
# if iid in category_iids
# ]
# raw_feat_scores = []
# if user_idx_feature is not None and valid_sketch_idxs_feature:
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
# processed_feat = raw_feat_scores
# else:
# processed_feat = np.array([])
#
# # 更新缓存
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
#
# # 合并分数
# if brand_id is not None:
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
#
# brand_feat_valid = (
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
# brand_idx_feature is not None and
# valid_sketch_idxs_feature # 有可用索引
# )
#
# if brand_feat_valid:
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
# brand_idx_feature, valid_sketch_idxs_feature
# ]
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
# )
# processed_brand_feat = raw_brand_feat_scores
#
# # 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
# if processed_feat.size == 0:
# processed_feat = np.zeros_like(processed_brand_feat)
#
# final_scores = processed_inter + 0.3 * (
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
# )
# else:
# # brand 信息不可用
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
# else:
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
#
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
#
# # 概率采样
# scores = np.array(final_scores)
#
# # 调整后的概率转换带温度控制的softmax
# def calibrated_softmax(scores, temperature=1.0):
# scores = scores / temperature
# scale = scores - max(scores)
# exps = np.exp(scale)
# return exps / np.sum(exps)
#
# probs = calibrated_softmax(scores, 0.09)
#
# chosen_indices = np.random.choice(
# len(valid_sketch_idxs),
# size=min(num_recommendations, len(valid_sketch_idxs)),
# p=probs,
# replace=False
# )
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
#
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
# return recommendations
# except Exception as e:
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e))
@router.on_event("startup")
async def startup_event():
# 初始加载
load_resources()
"""启动时初始化增量监听任务"""
try:
# 屏蔽 apscheduler 的 INFO 日志
logging.getLogger("apscheduler").setLevel(logging.WARNING)
# 确保 Milvus 集合已创建(若已存在则直接返回)
try:
create_collection()
except Exception as exc:
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
# 配置定时任务
scheduler = BackgroundScheduler()
scheduler.add_job(
load_resources,
trigger=CronTrigger(hour=0, minute=30),
name="每日资源刷新"
)
start_background_listener(scheduler)
scheduler.start()
logger.info("定时任务已启动")
logger.info("增量监听定时任务已启动")
except Exception as e:
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
def softmax(scores):
max_score = max(scores)
exp_scores = [math.exp(s - max_score) for s in scores]
sum_exp = sum(exp_scores)
return [s / sum_exp for s in exp_scores]
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
async def recommend(
user_id: int,
category: str,
style: Optional[str] = Query(
None,
description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
),
):
"""新版推荐接口Milvus + Redis 偏好向量)。"""
try:
results = get_new_recommendations(user_id, category, style)
path = results[0] if results else ""
return [path]
except Exception as e:
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
# def get_random_recommendations(category: str, num: int) -> List[str]:
# """根据预加载热度向量推荐(冷启动)"""
# try:
# heat_data = matrix_data.get("heat_data", {})
#
# if category not in heat_data:
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
#
# heat_dict = heat_data[category] # {url: score}
# urls = list(heat_dict.keys())
# scores = list(heat_dict.values())
#
# if not urls:
# raise ValueError("该类别下无热度记录,使用随机推荐")
#
# probs = softmax(scores)
# sample_size = min(num, len(urls))
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
#
# return sampled_urls
#
# except Exception as e:
# # 回退:完全随机推荐
# all_iids = list(matrix_data["iid_to_sketch"].keys())
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
# sample_size = min(num, len(category_iids))
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
def get_random_recommendations(category: str, num: int) -> List[str]:
"""全品类随机推荐"""
all_iids = list(matrix_data["iid_to_sketch"].keys())
# 优先从当前品类选择
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
# 确保不超出实际数量
sample_size = min(num, len(category_iids))
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
@router.get("/redis/user_pref")
async def get_all_user_preferences():
"""
@param user_id: 4
@param category: female_skirt
@param num_recommendations: 1
@return:
[
"aida-sys-image/images/female/skirt/903000017.jpg"
]
获取所有以 user_pref 为前缀的 Redis key 信息
"""
try:
logger.info(f"user_id:{user_id}-----category:{category}-----brand_id:{brand_id}-----brand_scale:{brand_scale}-----num_recommendations:{num_recommendations}")
start_time = time.time()
cache_key = (user_id, category)
# === 新增:用户存在性检查 ===
user_exists_inter = user_id in matrix_data["user_index_interaction"]
user_exists_feat = user_id in matrix_data["user_index_feature"]
from app.service.utils.redis_utils import Redis
from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX
# 任一矩阵不存在用户则返回随机推荐
if not (user_exists_inter and user_exists_feat):
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
return get_random_recommendations(category, num_recommendations)
# 扫描所有匹配 user_pref:* 的 key
pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*"
keys = Redis.scan_keys(pattern)
# 检查缓存
if cache_key in matrix_data["cached_scores"]:
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
else:
# 实时计算逻辑(同原代码)
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
# 直接返回所有 key 和原始 value
result = {}
for key in keys:
# 读取对应的值
value = Redis.read(key)
if value:
result[key] = value
category_iids = matrix_data["category_to_iids"].get(category, [])
valid_sketch_idxs_inter = [
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
if iid in category_iids
]
# 处理交互分数
raw_inter_scores = []
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
# 处理特征分数
valid_sketch_idxs_feature = [
idx for iid, idx in matrix_data["sketch_index_feature"].items()
if iid in category_iids
]
raw_feat_scores = []
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores
else:
processed_feat = np.array([])
# 更新缓存
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
# 合并分数
if brand_id is not None:
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
brand_feat_valid = (
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
brand_idx_feature is not None and
valid_sketch_idxs_feature # 有可用索引
)
if brand_feat_valid:
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
brand_idx_feature, valid_sketch_idxs_feature
]
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
)
processed_brand_feat = raw_brand_feat_scores
# 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
if processed_feat.size == 0:
processed_feat = np.zeros_like(processed_brand_feat)
final_scores = processed_inter + 0.3 * (
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
)
else:
# brand 信息不可用
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
else:
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
# 概率采样
scores = np.array(final_scores)
# 调整后的概率转换带温度控制的softmax
def calibrated_softmax(scores, temperature=1.0):
scores = scores / temperature
scale = scores - max(scores)
exps = np.exp(scale)
return exps / np.sum(exps)
probs = calibrated_softmax(scores, 0.09)
chosen_indices = np.random.choice(
len(valid_sketch_idxs),
size=min(num_recommendations, len(valid_sketch_idxs)),
p=probs,
replace=False
)
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}")
return recommendations
return result
except Exception as e:
logger.error(f"推荐失败: {str(e)}", exc_info=True)
logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -7,6 +7,7 @@ from app.api import api_design_pre_processing
from app.api import api_generate_image
from app.api import api_mannequins_edit
from app.api import api_pose_transform
from app.api import api_precompute
from app.api import api_prompt_generation
from app.api import api_recommendation
from app.api import api_test
@@ -21,6 +22,7 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")

View File

@@ -1,235 +0,0 @@
import os
import pika
from dotenv import load_dotenv
from pydantic import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings):
PROJECT_NAME: str = 'FASTAPI BASE'
SECRET_KEY: str = ''
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
CHROMADB_PATH = "./chromadb/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
CHROMADB_PATH = "/chromadb/"
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
# RABBITMQ_ENV = "-local" # 本地测试环境
if RABBITMQ_ENV == "-dev":
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
elif RABBITMQ_ENV == "-prod":
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
settings = Settings()
# minio 配置
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 配置
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
REDIS_DB = "2"
# rabbitmq config
RABBITMQ_PARAMS = {
"host": "18.167.251.121",
"port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/"
}
# milvus 配置
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置
DB_HOST = '18.167.251.121' # 数据库主机地址
# DB_PORT = int( 33006)
DB_PORT = 33008 # 数据库端口
DB_USERNAME = 'aida_con_python' # 数据库用户名
DB_PASSWORD = '123456' # 数据库密码
DB_NAME = 'aida' # 数据库库名
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# SR service config
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
# GenerateImage service config
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
# Generate Product service config
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
# GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}
# ollama config
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
# design batch
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN 预处理
IF_DEBUG_SHOW = False
# 优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
'hairstyle_front': 97,
'outwear_front': 20,
'tops_front': 19,
'dress_front': 18,
'blouse_front': 17,
'skirt_front': 16,
'trousers_front': 15,
'bottoms_front': 14,
'shoes_right': 1,
'shoes_left': 1,
'body': 0,
'bottoms_back': -14,
'trousers_back': -15,
'skirt_back': -16,
'blouse_back': -17,
'dress_back': -18,
'tops_back': -19,
'outwear_back': -20,
'hairstyle_back': -97,
'bag_back': -98,
'earring_back': -99,
}
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址

View File

@@ -36,7 +36,7 @@ class Settings(BaseSettings):
# --- mysql 配置信息 ---
MYSQL_HOST: str = Field(default='', description="")
MYSQL_PORT: str = Field(default='', description="")
MYSQL_PORT: int = Field(default=3306, description="")
MYSQL_USER: str = Field(default='', description="")
MYSQL_PASSWORD: str = Field(default='', description="")
MYSQL_DB: str = Field(default='', description="")
@@ -64,11 +64,19 @@ class Settings(BaseSettings):
# --- Design Callback Java 接口 ---
JAVA_STREAM_API_URL: str = Field(default='', description="")
# --- flux2 klein model url ---
FLUX2_GEN_IMG_MODEL_URL: str = Field(default='', description="")
# --- 服务器IP ---
A6000_SERVICE_HOST: str = Field(default='', description="")
B_4_X_4090_SERVICE_HOST: str = Field(default='', description="")
# --- 其他配置信息 以下均为Docker容器内配置---
LOGS_PATH: str = Field(default="/logs/", description="")
CATEGORY_PATH: str = Field(default="/app/service/attribute/config/descriptor/category/category_dis.csv", description="")
SEG_CACHE_PATH: str = Field(default="/seg_cache/", description="")
RECOMMEND_PATH_PREFIX: str = Field(default="/app/service/recommend/", description="")
SERVE_PORT: int = Field(default=2010, description="")
settings = Settings()
@@ -117,39 +125,41 @@ KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
# ollama 地址
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
OLLAMA_URL = f"http://{settings.A6000_SERVICE_HOST}:11434/api/embeddings"
"""Triton Server Config"""
# Design
DESIGN_MODEL_URL = '10.1.1.240:10000'
DESIGN_MODEL_URL = f'{settings.A6000_SERVICE_HOST}:10000'
DESIGN_MODEL_NAME = 'seg_knet'
# Seg Product
SEG_PRODUCT_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:30000'
# Generate Image
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_URL = f'{settings.A6000_SERVICE_HOST}:10061'
GI_MODEL_NAME = 'flux'
# Generate Single Logo
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10041'
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
# Generate Product (整套和单品)
GPI_MODEL_URL = '10.1.1.243:10051'
GPI_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10051'
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
# 以下停用中...*************
# 多视角生成
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10081'
GMV_MODEL_NAME = 'multi_view'
# 超分
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_TRITON_URL = f"{settings.A6000_SERVICE_HOST}:10031"
# 打光
GRI_MODEL_URL = '10.1.1.240:10051'
GRI_MODEL_URL = f'{settings.A6000_SERVICE_HOST}:10051'
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
# agent 图片生成
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
# 图转视频 triton版
PT_MODEL_URL = '10.1.1.243:10061'
PT_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10061'
# *************

View File

@@ -16,7 +16,7 @@ from fastapi.responses import JSONResponse
from app.api.api_route import router
from app.core.config import settings
from app.core.record_api_count import count_api_calls
# from app.core.record_api_count import count_api_calls
from app.schemas.response_template import ResponseModel
from logging_env import LOGGER_CONFIG_DICT
from dotenv import load_dotenv
@@ -48,7 +48,7 @@ def get_application() -> FastAPI:
allow_methods=["*"],
allow_headers=["*"],
)
application.middleware("http")(count_api_calls)
# application.middleware("http")(count_api_calls)
application.include_router(router=router)
return application

View File

@@ -1,4 +1,16 @@
from pydantic import BaseModel
from typing import List, Optional
from pydantic import BaseModel, Field
class SAMRequestModel(BaseModel):
bucket: str = Field(..., description="minio bucket name ")
object_name: str = Field(..., description="minio object name ")
image_path: str = Field(..., description="图片路径,必填字段")
type: str = Field(..., description="推理类型,必填字段")
points: Optional[List[List[float]]] | None = None
labels: Optional[List[int]] | None = None
box: Optional[List[int]] | None = None
class DesignModel(BaseModel):

View File

@@ -1,6 +1,6 @@
from typing import List
from typing import List, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
class GenerateMultiViewModel(BaseModel):
@@ -8,6 +8,17 @@ class GenerateMultiViewModel(BaseModel):
image_url: str
class GenerateImageFlux2KleinModel(BaseModel):
bucket_name: str = Field(..., description="OSS桶名不传则为None")
object_name: str = Field(..., description="OSS对象名文件路径不传则为None")
# input_image_paths: Optional[List[str]] = Field(default=[], description="输入图片路径列表")
width: Optional[int] = Field(default=1024, description="图片宽度默认512像素")
height: Optional[int] = Field(default=1024, description="图片高度默认512像素")
prompt: Optional[str] = Field(default="", description="文本提示词,用于模型推理等场景")
steps: Optional[int] = Field(default=4, description="推理步数,控制模型生成过程的迭代次数")
guidance: Optional[float] = Field(default=4.0, description="引导系数,调节提示词对生成结果的影响程度")
class GenerateImageModel(BaseModel):
tasks_id: str
prompt: str
@@ -24,6 +35,13 @@ class GenerateSingleLogoImageModel(BaseModel):
seed: str
class GenerateSloganImageModel(BaseModel):
num_point: int
tasks_id: str
prompt: str
image_url: str
class GenerateProductImageModel(BaseModel):
tasks_id: str
prompt: str
@@ -32,6 +50,13 @@ class GenerateProductImageModel(BaseModel):
product_type: str
class Flux2ToProductImgModel(BaseModel):
tasks_id: str
prompt: str
image_path: str
infer_step: int | None = None
class GenerateRelightImageModel(BaseModel):
tasks_id: str
prompt: str

View File

@@ -3,7 +3,6 @@
from pprint import pprint
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
@@ -12,6 +11,7 @@ from minio import Minio
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@@ -109,10 +109,9 @@ class AttributeRecognition:
@staticmethod
def preprocess(img):
img = mmcv.imread(img)
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -10,7 +10,6 @@
from minio import Minio
from skimage import transform
import cv2
import mmcv
import numpy as np
import pandas as pd
import tritonclient.http as httpclient
@@ -18,6 +17,7 @@ import torch
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import CategoryRecognitionModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@@ -39,11 +39,10 @@ class CategoryRecognition:
@staticmethod
def preprocess(img):
img = mmcv.imread(img)
# ori_shape = img.shape[:2]
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -1,7 +1,6 @@
import logging
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
@@ -9,11 +8,12 @@ import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import DESIGN_MODEL_URL
from app.core.config import DESIGN_MODEL_URL, SEG_PRODUCT_MODEL_URL
from app.core.config import settings
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@@ -29,7 +29,7 @@ class BrandDna:
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
self.seg_client = httpclient.InferenceServerClient(url=SEG_PRODUCT_MODEL_URL)
self.const = const
# self.const = local_debug_const
@@ -202,7 +202,7 @@ class BrandDna:
# 服装分割预处理
@staticmethod
def seg_product_preprocess(image):
img = mmcv.imread(image)
img = image
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -211,9 +211,9 @@ class BrandDna:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
@@ -227,11 +227,10 @@ class BrandDna:
# 类别检测模型预处理
@staticmethod
def category_preprocess(img):
img = mmcv.imread(img)
# ori_shape = img.shape[:2]
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -1,19 +1,10 @@
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
import uuid
import httpx
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.core.config import settings
@@ -26,14 +17,9 @@ class GenerateBrandInfo:
# user info init
self.user_id = request_data.user_id
self.category = "brand_logo"
# generate logo init
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.batch_size = 1
self.mode = 'txt2img'
# llm generate brand info init
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=settings.QWEN_API_KEY)
self.response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."),
@@ -63,38 +49,20 @@ class GenerateBrandInfo:
self.generate_logo_prompt = brand_data['brand_logo_prompt']
def generate_brand_logo(self):
prompts = [self.generate_logo_prompt] * self.batch_size
modes = [self.mode] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode]
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
image = result.as_numpy("generated_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
logo_url = self.upload_logo_image(image_result, generate_uuid())
self.result_data['brand_logo'] = logo_url
def upload_logo_image(self, image, object_name):
try:
_, img_byte_array = cv2.imencode('.jpg', image)
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
request_item = {
"bucket_name": "aida-users",
"object_name": f'{self.user_id}/{self.category}/{uuid.uuid4().hex}.png',
"prompt": self.generate_logo_prompt,
"height": 1024,
"width": 1024
}
with httpx.Client(timeout=120) as client:
resp = client.post(
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
json=request_item,
)
result = resp.json()
self.result_data['brand_logo'] = result.get("output_path", "")
if __name__ == '__main__':

View File

@@ -23,7 +23,7 @@ class ClothingSeg:
def __init__(self, request_data):
self.image_data = request_data.image_data
self.user_id = request_data.user_id
self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")
self.triton_client = grpcclient.InferenceServerClient(url=f"{settings.B_4_X_4090_SERVICE_HOST}:10071")
@RunTime
def get_result(self):
@@ -139,7 +139,7 @@ def get_bounding_box(mask):
if __name__ == "__main__":
test_data = ClothingSegModel(
user_id=89,
user_id="89",
image_data=[
# {
# "image_url": "test/clothing_seg/dress.jpg",

View File

@@ -13,7 +13,7 @@ from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import settings
from app.core.config import settings, PS_RABBITMQ_QUEUES
from app.schemas.comfyui_i2v import ComfyuiPose2VModel
from app.service.generate_image.utils.mq import publish_status
@@ -622,9 +622,9 @@ class ComfyUIServerPose2V:
# 推送消息
if not settings.DEBUG:
publish_status(json.dumps(self.pose_transform_data), settings.COMFYUI_SERVER_ADDRESS)
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
logger.info(
f" [x] Sent to {settings.COMFYUI_SERVER_ADDRESS} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
return "\n🎉 所有任务完成!"

View File

@@ -10,13 +10,13 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
from app.service.utils.image_normalize import my_imnormalize
"""
keypoint
@@ -25,13 +25,13 @@ from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
img_scale = (256, 256)
h, w = img.shape[:2]
img = cv2.resize(img, img_scale)
w_scale = img_scale[0] / w
h_scale = img_scale[1] / h
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, (w_scale, h_scale)
@@ -74,7 +74,7 @@ def keypoint_postprocess(output, scale_factor):
# KNet
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -83,9 +83,9 @@ def seg_preprocess(img_path):
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape

View File

@@ -6,10 +6,10 @@ import requests
from minio import Minio
from app.core.config import settings
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem, TopMergeItem, BottomMergeItem, OthersMergeItem
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
from app.service.design_fast.utils.progress import final_progress, update_progress
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority, merge
from app.service.utils.decorator import RunTime
id_lock = threading.Lock()
@@ -19,22 +19,46 @@ logger = logging.getLogger()
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def process_item(item, basic):
# 处理project中单个item
if item['type'] == "Body":
body_server = BodyItem(data=item, basic=basic, minio_client=minio_client)
item_data = body_server.process()
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
item_data = top_server.process()
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
elif item['type'].lower() in ['others']:
bottom_server = OthersItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
def process_item(item, basic, design_type):
# 1. 定义映射配置
# key 为 item_type 的小写value 为对应的处理类
DESIGN_MAP = {
'body': BodyItem,
'blouse': TopItem, 'outwear': TopItem,
'dress': TopItem, 'tops': TopItem,
'skirt': BottomItem, 'trousers': BottomItem,
'bottoms': BottomItem,
'others': OthersItem
}
MERGE_MAP = {
'body_merge': BodyItem,
'blouse_merge': TopMergeItem, 'outwear_merge': TopMergeItem,
'dress_merge': TopMergeItem, 'tops_merge': TopMergeItem,
'skirt_merge': BottomMergeItem, 'trousers_merge': BottomMergeItem,
'bottoms_merge': BottomMergeItem,
'others_merge': OthersMergeItem
}
# 2. 根据 design_type 选择映射表
mapping = MERGE_MAP if design_type == 'merge' else DESIGN_MAP
if design_type == 'merge':
item_type_key = f"{item['type'].lower()}_merge"
elif design_type == 'default':
item_type_key = item['type'].lower()
else:
raise NotImplementedError(f"Item type {item['type']} not implemented")
item_type_key = item['type'].lower()
handler_class = mapping.get(item_type_key)
if not handler_class:
raise NotImplementedError(f"Item type {item['type']} not implemented for design_type={design_type}")
# 4. 统一实例化并执行
# 注意:这里假设所有 Item 类构造函数签名一致
server = handler_class(data=item, basic=basic, minio_client=minio_client)
item_data = server.process()
return item_data
@@ -44,7 +68,7 @@ def process_layer(item, layers):
body_layer = organize_body(item)
layers.append(body_layer)
return item['body_image'].size
elif item['name'] == 'others':
elif item['name'] in ['others', 'others_merge']:
front_layer, back_layer = organize_others(item)
layers.append(front_layer)
layers.append(back_layer)
@@ -70,10 +94,11 @@ def design_generate(request_data):
nonlocal active_threads
basic = object['basic']
items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""}
design_type = basic.get('design_type', "default")
if basic['single_overall'] == "overall":
item_results = []
for item in object['items']:
item_results.append(process_item(item, basic))
item_results.append(process_item(item, basic, design_type))
layers = []
for item in item_results:
process_layer(item, layers)
@@ -93,12 +118,19 @@ def design_generate(request_data):
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
'transpose': lay.get('transpose', None),
'rotate': lay.get('rotate', None),
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
})
if basic.get('design_type') == 'default':
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
elif basic.get('design_type') == 'merge':
items_response['synthesis_url'] = merge(layers, new_size, basic)
else:
item_result = process_item(object['items'][0], basic)
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
else:
item_result = process_item(object['items'][0], basic, design_type)
items_response['layers'].append({
'image_category': f"{item_result['name']}_front",
'image_size': item_result['back_image'].size if item_result['back_image'] else None,
@@ -152,6 +184,7 @@ def design_generate_v2(request_data):
def process_object(object, callback_url):
basic = object['basic']
design_type = basic.get('design_type', "default")
items_response = {
'layers': [],
'objectSign': object['objectSign'] if 'objectSign' in object.keys() else "",
@@ -160,7 +193,7 @@ def design_generate_v2(request_data):
if basic['single_overall'] == "overall":
item_results = []
for item in object['items']:
item_results.append(process_item(item, basic))
item_results.append(process_item(item, basic, design_type))
layers = []
for item in item_results:
process_layer(item, layers)
@@ -185,7 +218,7 @@ def design_generate_v2(request_data):
})
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
else:
item_result = process_item(object['items'][0], basic)
item_result = process_item(object['items'][0], basic, design_type)
items_response['layers'].append({
'image_category': f"{item_result['name']}_front",
'image_size': item_result['back_image'].size if item_result['back_image'] else None,

View File

@@ -7,6 +7,7 @@ class BaseItem:
self.result['name'] = data['type'].lower()
self.result.pop("type")
self.result.update(basic)
self.result['design_type'] = basic.get('design_type', None)
class OthersItem(BaseItem):
@@ -14,10 +15,7 @@ class OthersItem(BaseItem):
super().__init__(data, basic)
self.Others_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
# ContourDetection(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
NoSegPrintPainting(minio_client),
PrintPainting(minio_client),
@@ -74,6 +72,65 @@ class BottomItem(BaseItem):
return self.result
"""merge"""
class OthersMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Others_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
# ContourDetection(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
# NoSegPrintPainting(minio_client),
# PrintPainting(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.Others_pipeline:
self.result = item(self.result)
return self.result
class TopMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.top_pipeline = [
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.top_pipeline:
self.result = item(self.result)
return self.result
class BottomMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.bottom_pipeline = [
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.bottom_pipeline:
self.result = item(self.result)
return self.result
class BodyItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)

View File

@@ -1,7 +1,7 @@
import logging
import numpy as np
from pymilvus import MilvusClient
# from pymilvus import MilvusClient
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
@@ -54,63 +54,64 @@ class KeyPoint:
"keypoint_vector": result.tolist()
}
]
try:
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logger.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
if site == "up":
# 需要的是up 即推理出来的是up 那么查询的就是down
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
else:
# 需要的是down 即推理出来的是down 那么查询的就是up
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": "all",
"keypoint_vector": result.tolist()
}
]
# try:
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
# client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
# client.close()
# except Exception as e:
# logger.info(f"save keypoint cache milvus error : {e}")
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
try:
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
)
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logger.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# @staticmethod
# def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
# if site == "up":
# # 需要的是up 即推理出来的是up 那么查询的就是down
# result = np.concatenate([infer_result.flatten(), search_result[-4:]])
# else:
# # 需要的是down 即推理出来的是down 那么查询的就是up
# result = np.concatenate([search_result[:20], infer_result.flatten()])
# data = [
# {"keypoint_id": keypoint_id,
# "keypoint_site": "all",
# "keypoint_vector": result.tolist()
# }
# ]
#
# try:
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
# client.upsert(
# collection_name=MILVUS_TABLE_KEYPOINT,
# data=data
# )
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# except Exception as e:
# logger.info(f"save keypoint cache milvus error : {e}")
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
keypoint_id = result['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
# 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != site:
# 需要的类型和查询到的不一致则更新类型为all
keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
except Exception as e:
logger.info(f"search keypoint cache milvus error {e}")
return False
# @RunTime
# def keypoint_cache(self, result, site):
# try:
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
# keypoint_id = result['image_id']
# res = client.query(
# collection_name=MILVUS_TABLE_KEYPOINT,
# # ids=[keypoint_id],
# filter=f"keypoint_id == {keypoint_id}",
# output_fields=['keypoint_vector', 'keypoint_site']
# )
# if len(res) == 0:
# # 没有结果 直接推理拿结果 并保存
# keypoint_infer_result, site = self.infer_keypoint_result(result)
# return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
# elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
# # 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
# elif res[0]["keypoint_site"] != site:
# # 需要的类型和查询到的不一致则更新类型为all
# keypoint_infer_result, site = self.infer_keypoint_result(result)
# return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
# except Exception as e:
# logger.info(f"search keypoint cache milvus error {e}")
# return False

View File

@@ -35,15 +35,9 @@ class LoadImage:
return cls.name
def __call__(self, result):
if result.get("merge_image_path"):
result['merge_image'], _ = self.read_image(result['merge_image_path'])
result['image'], result['pre_mask'] = self.read_image(result['path'])
# if 'extract_lines' in result.keys():
# if result['extract_lines']:
# result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
# else:
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
# else:
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY))
result['keypoint'] = self.get_keypoint(result['name'])
result['img_shape'] = result['image'].shape
@@ -61,21 +55,6 @@ class LoadImage:
mask = skeleton
result = np.ones_like(img) * 255
result[mask] = img[mask]
# 步骤2细化边缘可选让线条更干净
# kernel = np.ones((1, 1), np.uint8)
# clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
# thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
# mask = thinned > 0
# result = np.ones_like(img) * 255
# result[mask] = img[mask]
# 步骤3反转回 白底黑线
# lines = cv2.bitwise_not(thinned)
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Original_{path.replace('/', '-')}.png"), img)
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Line_{path.replace('/', '-')}.png"), result)
return result
def read_image(self, image_path):
@@ -96,19 +75,19 @@ class LoadImage:
@staticmethod
def get_keypoint(name):
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
if name in ['blouse', 'outwear', 'dress', 'tops', 'blouse_merge', 'outwear_merge', 'dress_merge', 'tops_merge']:
keypoint = 'shoulder'
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
elif name in ['trousers', 'skirt', 'bottoms', 'trousers_merge', 'skirt_merge', 'bottoms_merge']:
keypoint = 'waistband'
elif name == 'bag':
elif name in ['bag', 'bag_merge']:
keypoint = 'hand_point'
elif name == 'shoes':
elif name in ['shoes', 'shoes_merge']:
keypoint = 'toe'
elif name == 'hairstyle':
elif name in ['hairstyle', 'hairstyle_merge']:
keypoint = 'head_point'
elif name == 'earring':
elif name in ['earring', 'earring_merge']:
keypoint = 'ear_point'
elif name == 'others':
elif name in ['others', 'others_merge']:
keypoint = "others"
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "

View File

@@ -9,32 +9,27 @@ from app.service.utils.new_oss_client import oss_get_image
class NoSegPrintPainting:
def __init__(self, minio_client):
self.random_seed = random.randint(0, 1000)
self.minio_client = minio_client
def __call__(self, result):
single_print = result['print']['single']
# single_print = [result['print']['single']]
overall_print = result['print']['overall']
element_print = result['print']['element']
# element_print = result['print']['element'
single_print = None
element_print = None
result['single_image'] = None
result['print_image'] = None
if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize 到sketch大小
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
# 获取平铺 + 旋转 的overall print
painting_dict = self.painting_collection(painting_dict, overall_print)
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = self.printpaint(result, painting_dict, print_=True)
result['pattern_image'] = result['no_seg_sketch_overall']
# result['pattern_image'] = result['no_seg_sketch_overall']
if single_print['print_path_list']:
if single_print:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
@@ -74,7 +69,7 @@ class NoSegPrintPainting:
single_image = cv2.add(tmp1, tmp2)
result['no_seg_sketch_print'] = single_image
if element_print['element_path_list']:
if element_print:
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])):
@@ -151,7 +146,6 @@ class NoSegPrintPainting:
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['no_seg_sketch_print'] = cv2.add(tmp1, tmp2)
return result
@staticmethod
@@ -166,26 +160,23 @@ class NoSegPrintPainting:
print_background = img1_bg + img2_fg
return print_background
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
if print_trigger:
def painting_collection(self, painting_dict, print_dict):
print_ = self.get_print(print_dict)
painting_dict['Trigger'] = not is_single
painting_dict['location'] = print_['location']
single_mask_inv_print = self.get_mask_inv(print_['image'])
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
if not is_single:
# 如果print 模式为overall 且 有角度的话 组合的print为正方形方便裁剪
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
gap = print_dict.get('gap', [[0, 0]])[0]
painting_dict['tile_print'], painting_dict['mask_inv_print'] = tile_image(pattern=print_['image'],
mask=print_['mask'],
dim=dim_pattern,
gap_x=gap[0],
gap_y=gap[1],
canvas_h=painting_dict['dim_image_h'],
canvas_w=painting_dict['dim_image_w'],
location=painting_dict['location'],
angle=int(print_.get('print_angle_list', [0])[0]))
# painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
# painting_dict['mask_inv_print'] = self.get_mask_inv(painting_dict['tile_print'])
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
@@ -219,33 +210,32 @@ class NoSegPrintPainting:
@staticmethod
def printpaint(result, painting_dict, print_=False):
if print_ and painting_dict['Trigger']:
if print_:
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
else:
print_mask = result['mask']
img_fg = result['final_image']
if print_ and not painting_dict['Trigger']:
index_ = None
try:
index_ = len(painting_dict['location'])
except:
assert f'there must be parameter of location if choose IfSingle'
for i in range(index_):
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
change_region = img_fg[start_h: length_h, start_w: length_w, :]
# problem in change_mask
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
# if print_ and not painting_dict['Trigger']:
# index_ = None
# try:
# index_ = len(painting_dict['location'])
# except:
# assert f'there must be parameter of location if choose IfSingle'
#
# for i in range(index_):
# start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
#
# length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
# length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
#
# change_region = img_fg[start_h: length_h, start_w: length_w, :]
# # problem in change_mask
# change_mask = print_mask[start_h: length_h, start_w: length_w]
# # get real part into change mask
# _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
# cv2.bitwise_not(painting_dict['mask_inv_print'])
# img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)
@@ -267,18 +257,21 @@ class NoSegPrintPainting:
image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image.mode == "RGBA":
mask_pil = image.split()[3]
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
image = new_background
else:
mask_pil = Image.new('L', image.size, 255) # L=灰度图255=纯白
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
print_dict['mask'] = cv2.threshold(np.array(mask_pil), 127, 255, cv2.THRESH_BINARY)[1]
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
random.seed(self.random_seed)
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
@@ -420,3 +413,118 @@ class NoSegPrintPainting:
cropped_img = resized_img[start_y:start_y + target_height, :]
return cropped_img
def tile_image(pattern, mask, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
"""
按照指定的 X/Y 间距平铺印花,并支持旋转
【修改版】以被平铺图案的【中心】作为平铺基准点
:param location: [[center_y, center_x]] → 第一个图案中心的坐标
:param angle: 旋转角度 (度数, 逆时针)
"""
# 1. 确保输入是 RGBA
if pattern.shape[2] == 3:
pattern = cv2.cvtColor(pattern, cv2.COLOR_BGR2BGRA)
# 2. 缩放与旋转印花
resized_p = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
rotated_p = rotate_image(resized_p, angle)
p_h, p_w = rotated_p.shape[:2]
# 3. 创建透明单元格(图案放在单元格中心)
cell_h = p_h + gap_y
cell_w = p_w + gap_x
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
# 计算图案在单元格中的左上角位置(让图案居中)
start_y = (cell_h - p_h) // 2
start_x = (cell_w - p_w) // 2
unit_cell[start_y:start_y + p_h, start_x:start_x + p_w, :] = rotated_p
# 4. 执行平铺
tiles_y = (canvas_h // cell_h) + 3 # 多加一点余量更安全
tiles_x = (canvas_w // cell_w) + 3
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
# 5. 计算偏移(关键修改:以中心为基准)
center_y, center_x = location[0][0], location[0][1] # 第一个图案的中心位置
# 计算从哪个位置开始裁剪,才能让中心落在指定坐标
offset_y = int((center_y - (p_h // 2)) % cell_h)
offset_x = int((center_x - (p_w // 2)) % cell_w)
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
# 6. 创建纯白色背景并合成(保持你原来的风格)
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
tiled_bgr = tiled_layer[:, :, :3]
alpha_mask = tiled_layer[:, :, 3] / 255.0
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask])
tiled_print = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
# ====================== 处理 Mask ======================
# Mask 也同样居中处理
resized_mask = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
rotated_mask = rotate_image(resized_mask, angle) # 注意mask也需要旋转
unit_mask = np.zeros((cell_h, cell_w), dtype=np.uint8)
unit_mask[start_y:start_y + p_h, start_x:start_x + p_w] = rotated_mask
full_mask_tiled = np.tile(unit_mask, (tiles_y, tiles_x))
tiled_mask = full_mask_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
return tiled_print, cv2.bitwise_not(tiled_mask)
def rotate_image(image, angle):
"""
旋转图片并保持完整内容(自动扩大画布)
"""
if angle == 0:
return image
(h, w) = image.shape[:2]
(cX, cY) = (w // 2, h // 2)
# 获取旋转矩阵
M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
# 计算旋转后新边界的 sine 和 cosine
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
# 计算新的画布尺寸
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
# 调整旋转矩阵以考虑平移
M[0, 2] += (nW / 2) - cX
M[1, 2] += (nH / 2) - cY
# 执行旋转
return cv2.warpAffine(image, M, (nW, nH))
def crop_image(image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image

View File

@@ -9,14 +9,17 @@ from app.service.utils.new_oss_client import oss_get_image
class PrintPainting:
def __init__(self, minio_client):
self.random_seed = None
self.minio_client = minio_client
def __call__(self, result):
single_print = result['print']['single']
# single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
# element_print = result['print']['element']
# partial_path = result['print']['partial'] if 'partial' in result['print'] else None
single_print = None
element_print = None
partial_path = None
result['single_image'] = None
result['print_image'] = None
# TODO 给result['pattern_image'] resize 到resize_scale的大小
@@ -38,24 +41,15 @@ class PrintPainting:
if overall_print['print_path_list']:
overall_print['location'][0] = [x * y for x, y in zip(overall_print['location'][0], result['resize_scale'])]
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize 到sketch大小
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['print_image'] = result['pattern_image'].copy()
# 获取平铺 + 旋转 的overall print
painting_dict = self.painting_collection(painting_dict, overall_print)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
if single_print['print_path_list']:
if single_print:
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
sketch_resize_scale = result['resize_scale']
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
@@ -78,75 +72,6 @@ class PrintPainting:
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
# else:
# mask = self.get_mask_inv(image)
# mask = np.expand_dims(mask, axis=2)
# mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
# mask = cv2.bitwise_not(mask)
#
# mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# # 旋转后的坐标需要重新算
# rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
# rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
# x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
#
# image_x = print_background.shape[1] # 底图宽
# image_y = print_background.shape[0] # 底图高
# print_x = rotate_image.shape[1] #印花宽
# print_y = rotate_image.shape[0] #印花高
#
# # 有bug
# # if x + print_x > image_x:
# # rotate_image = rotate_image[:, :x + print_x - image_x]
# # rotate_mask = rotate_mask[:, :x + print_x - image_x]
# # #
# # if y + print_y > image_y:
# # rotate_image = rotate_image[:y + print_y - image_y]
# # rotate_mask = rotate_mask[:y + print_y - image_y]
#
# # 不能是并行
# # 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# # 先挪 再判断 最后裁剪
#
# # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
# if x <= 0: # 如果X轴偏移量小于0说明印花需要被裁剪至合适大小 或当X轴偏移量大于印花宽度时裁剪后的印花宽度为0
# rotate_image = rotate_image[:, abs(x):]
# rotate_mask = rotate_mask[:, abs(x):]
# start_x = x = 0
# else:
# start_x = x
#
# if y <= 0: # 如果X轴偏移量大于0说明印花需要被裁剪至合适大小 或当Y轴偏移量大于印花宽度时裁剪后的印花宽度为0
# rotate_image = rotate_image[abs(y):, :]
# rotate_mask = rotate_mask[abs(y):, :]
# start_y = y = 0
# else:
# start_y = y
#
# # ------------------
# # 如果print-size大于image-size 则需要裁剪print
#
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :image_x - x]
# rotate_mask = rotate_mask[:, :image_x - x]
#
# if y + print_y > image_y:
# rotate_image = rotate_image[:image_y - y, :]
# rotate_mask = rotate_mask[:image_y - y, :]
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
# mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
# print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
@@ -163,10 +88,9 @@ class PrintPainting:
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if element_print['element_path_list']:
if element_print:
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
sketch_resize_scale = result['resize_scale']
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])):
@@ -207,20 +131,6 @@ class PrintPainting:
print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0:
rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:]
@@ -235,9 +145,6 @@ class PrintPainting:
else:
start_y = y
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x]
@@ -246,11 +153,6 @@ class PrintPainting:
rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
@@ -298,12 +200,8 @@ class PrintPainting:
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO element 丢失信息
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
@@ -325,27 +223,21 @@ class PrintPainting:
print_background = img1_bg + img2_fg
return print_background
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
if print_trigger:
def painting_collection(self, painting_dict, print_dict):
print_ = self.get_print(print_dict)
painting_dict['Trigger'] = not is_single
painting_dict['location'] = print_['location']
single_mask_inv_print = self.get_mask_inv(print_['image'])
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
if not is_single:
self.random_seed = random.randint(0, 1000)
# 如果print 模式为overall 且 有角度的话 组合的print为正方形方便裁剪
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
gap = print_dict.get('gap', [[0, 0]])[0]
painting_dict['tile_print'], painting_dict['mask_inv_print'] = tile_image(pattern=print_['image'],
mask=print_['mask'],
dim=dim_pattern,
gap_x=gap[0],
gap_y=gap[1],
canvas_h=painting_dict['dim_image_h'],
canvas_w=painting_dict['dim_image_w'],
location=painting_dict['location'],
angle=int(print_.get('print_angle_list', [0])[0]))
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
@@ -374,51 +266,37 @@ class PrintPainting:
mask_inv = cv2.inRange(print_tile, lower, upper)
return mask_inv
else:
# bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
# bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
# bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
# bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
# lower = np.array([bg_L_low, bg_a_low, bg_b_low])
# upper = np.array([bg_L_high, bg_a_high, bg_b_high])
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
# mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
return mask_inv
@staticmethod
def printpaint(result, painting_dict, print_=False):
if print_ and painting_dict['Trigger']:
if print_:
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
else:
print_mask = result['mask']
img_fg = result['final_image']
if print_ and not painting_dict['Trigger']:
index_ = None
try:
index_ = len(painting_dict['location'])
except:
assert f'there must be parameter of location if choose IfSingle'
for i in range(index_):
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
change_region = img_fg[start_h: length_h, start_w: length_w, :]
# problem in change_mask
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
# if print_ and not painting_dict['Trigger']:
# index_ = None
# try:
# index_ = len(painting_dict['location'])
# except:
# assert f'there must be parameter of location if choose IfSingle'
#
# for i in range(index_):
# start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
#
# length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
# length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
#
# change_region = img_fg[start_h: length_h, start_w: length_w, :]
# # problem in change_mask
# change_mask = print_mask[start_h: length_h, start_w: length_w]
# # get real part into change mask
# _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
# cv2.bitwise_not(painting_dict['mask_inv_print'])
# img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)
@@ -440,21 +318,21 @@ class PrintPainting:
image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image.mode == "RGBA":
mask_pil = image.split()[3]
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
image = new_background
else:
mask_pil = Image.new('L', image.size, 255) # L=灰度图255=纯白
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
print_dict['mask'] = cv2.threshold(np.array(mask_pil), 127, 255, cv2.THRESH_BINARY)[1]
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
random.seed(self.random_seed)
# logging.info(f'overall print location : {location}')
# x_offset = random.randint(0, image.shape[0] - image_size_h)
# y_offset = random.randint(0, image.shape[1] - image_size_w)
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
@@ -596,3 +474,118 @@ class PrintPainting:
cropped_img = resized_img[start_y:start_y + target_height, :]
return cropped_img
def tile_image(pattern, mask, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
"""
按照指定的 X/Y 间距平铺印花,并支持旋转
【修改版】以被平铺图案的【中心】作为平铺基准点
:param location: [[center_y, center_x]] → 第一个图案中心的坐标
:param angle: 旋转角度 (度数, 逆时针)
"""
# 1. 确保输入是 RGBA
if pattern.shape[2] == 3:
pattern = cv2.cvtColor(pattern, cv2.COLOR_BGR2BGRA)
# 2. 缩放与旋转印花
resized_p = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
rotated_p = rotate_image(resized_p, angle)
p_h, p_w = rotated_p.shape[:2]
# 3. 创建透明单元格(图案放在单元格中心)
cell_h = p_h + gap_y
cell_w = p_w + gap_x
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
# 计算图案在单元格中的左上角位置(让图案居中)
start_y = (cell_h - p_h) // 2
start_x = (cell_w - p_w) // 2
unit_cell[start_y:start_y + p_h, start_x:start_x + p_w, :] = rotated_p
# 4. 执行平铺
tiles_y = (canvas_h // cell_h) + 3 # 多加一点余量更安全
tiles_x = (canvas_w // cell_w) + 3
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
# 5. 计算偏移(关键修改:以中心为基准)
center_y, center_x = location[0][0], location[0][1] # 第一个图案的中心位置
# 计算从哪个位置开始裁剪,才能让中心落在指定坐标
offset_y = int((center_y - (p_h // 2)) % cell_h)
offset_x = int((center_x - (p_w // 2)) % cell_w)
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
# 6. 创建纯白色背景并合成(保持你原来的风格)
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
tiled_bgr = tiled_layer[:, :, :3]
alpha_mask = tiled_layer[:, :, 3] / 255.0
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask])
tiled_print = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
# ====================== 处理 Mask ======================
# Mask 也同样居中处理
resized_mask = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
rotated_mask = rotate_image(resized_mask, angle) # 注意mask也需要旋转
unit_mask = np.zeros((cell_h, cell_w), dtype=np.uint8)
unit_mask[start_y:start_y + p_h, start_x:start_x + p_w] = rotated_mask
full_mask_tiled = np.tile(unit_mask, (tiles_y, tiles_x))
tiled_mask = full_mask_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
return tiled_print, cv2.bitwise_not(tiled_mask)
def rotate_image(image, angle):
"""
旋转图片并保持完整内容(自动扩大画布)
"""
if angle == 0:
return image
(h, w) = image.shape[:2]
(cX, cY) = (w // 2, h // 2)
# 获取旋转矩阵
M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
# 计算旋转后新边界的 sine 和 cosine
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
# 计算新的画布尺寸
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
# 调整旋转矩阵以考虑平移
M[0, 2] += (nW / 2) - cX
M[1, 2] += (nH / 2) - cY
# 执行旋转
return cv2.warpAffine(image, M, (nW, nH))
def crop_image(image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image

View File

@@ -34,22 +34,25 @@ class Segmentation:
result['mask'] = result['front_mask'] + result['back_mask']
else:
# preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果
if result.get("design_type", None) == "merge":
seg_result = get_seg_result(result['image'])
# submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 默认design 模式 - 过模型 缓存
# elif result.get("design_type", None) == "submit":
# 推理获得seg 结果
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
# seg_result = get_seg_result(result['image'])
# self.save_seg_result(seg_result, result['image_id'])
# 默认模式- 加载模型,找不到则过模型推理,推理后保存到本地
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
# 判断缓存和实际图片size是否相同
_ = False
if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果
seg_result = get_seg_result(result['image'])
if result['name'] == 'others':
seg_result = seg_result.clip(max=1)
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result

View File

@@ -4,6 +4,7 @@ import logging
import cv2
import numpy as np
from PIL import Image
from celery.bin.result import result
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
@@ -19,6 +20,52 @@ class Split(object):
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
if result.get('design_type', None) == 'merge':
ori_front_mask = result['front_mask'].copy()
ori_back_mask = result['back_mask'].copy()
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
front_mask = result['front_mask']
back_mask = result['back_mask']
else:
height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
result['merge_image'] = cv2.resize(result['merge_image'], (new_width, new_height), interpolation=cv2.INTER_AREA)
rgba_image = rgb_to_rgba(result['merge_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = ori_front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[ori_front_mask != 0] = [0, 0, 255]
mask_image[ori_back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask)
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
return result
else:
ori_front_mask = result['front_mask'].copy()
ori_back_mask = result['back_mask'].copy()
@@ -60,46 +107,9 @@ class Split(object):
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
# 前片部分 (红图部分)
# height, width = front_mask.shape
# mask_image = np.zeros((height, width, 3))
# mask_image[front_mask != 0] = [0, 0, 255]
# 切换为原始图片尺寸-------------------------------
height, width = ori_front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[ori_front_mask != 0] = [0, 0, 255]
# -----------------------------------------------
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
# result_back_image = np.zeros_like(rgba_image)
# back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
#
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# else:
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# result['back_image'] = None
# result["back_image_url"] = None
# # result["back_mask_url"] = None
# # result['back_mask_image'] = None
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
@@ -118,6 +128,14 @@ class Split(object):
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
result_pattern_overall_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_overall_image'], result['pattern_overall_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_overall_image_pil, f'{generate_uuid()}')
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
return result
else:
ori_front_mask, ori_back_mask = None, None
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
@@ -127,5 +145,6 @@ class Split(object):
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -10,12 +10,12 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
from app.service.utils.image_normalize import my_imnormalize
"""
keypoint
@@ -24,14 +24,14 @@ from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
img_scale = (256, 256)
h, w = img.shape[:2]
img = cv2.resize(img, img_scale)
w_scale = img_scale[0] / w
h_scale = img_scale[1] / h
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, (w_scale, h_scale)
@@ -78,7 +78,7 @@ def keypoint_postprocess(output, scale_factor):
# KNet
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -87,12 +87,12 @@ def seg_preprocess(img_path):
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
# 扩充25的白边
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape

View File

@@ -23,20 +23,23 @@ def organize_clothing(layer):
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
name=f'{layer["name"].lower()}_front',
image=layer["front_image"],
merge_image=layer["front_image"],
# mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get("mask_url", None),
sacle=layer['scale'],
clothes_keypoint=layer['clothes_keypoint'],
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
pattern_image=layer['pattern_image'],
pattern_image=layer.get('pattern_image', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
)
# 后片数据
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
@@ -44,16 +47,18 @@ def organize_clothing(layer):
image=layer["back_image"],
# mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get('mask_url', None),
sacle=layer['scale'],
clothes_keypoint=layer['clothes_keypoint'],
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
)
return front_layer, back_layer
@@ -76,17 +81,19 @@ def organize_others(layer):
image=layer["front_image"],
# mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get('mask_url', None),
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_image=layer['pattern_image'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
pattern_image=layer.get('pattern_image', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
)
# 后片数据
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
@@ -94,16 +101,18 @@ def organize_others(layer):
image=layer["back_image"],
# mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get('mask_url', None),
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
)
return front_layer, back_layer

View File

@@ -151,9 +151,11 @@ def synthesis(data, size, basic_info):
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
paste_img, position = transpose_rotate(layer, layer['image'])
test_image.paste(paste_img, position, paste_img)
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
mask_alpha.paste(paste_img.getchannel('A'), position, paste_img.getchannel('A'))
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
else:
@@ -185,6 +187,111 @@ def synthesis(data, size, basic_info):
logging.warning(f"synthesis runtime exception : {e}")
def merge(data, size, basic_info):
# out_of_bounds_control: 是否允许服装越界 True 允许 False 不允许 默认情况允许
out_of_bounds_control = basic_info.get('out_of_bounds_control', True)
# 创建底图
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
try:
all_mask_shape = (size[1], size[0])
body_mask = None
for d in data:
if d['name'] == 'body' or d['name'] == 'mannequin':
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值所以使用下标获取position
body_mask = np.array(transparent_image.split()[3])
# 根据新的坐标获取新的肩点
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
top_outer_mask = np.array(binary_body_mask)
bottom_outer_mask = np.array(binary_body_mask)
others_outer_mask = np.array(binary_body_mask)
top = True
bottom = True
others = True
i = len(data)
while i:
i -= 1
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
if out_of_bounds_control:
top = True
else:
top = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
top_outer_mask = background + top_outer_mask
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
# bottom = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
bottom_outer_mask = background + bottom_outer_mask
elif others and data[i]['name'] in ['others_front']:
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
others_outer_mask = background + others_outer_mask
pass
elif bottom is False and top is False:
break
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
all_mask = cv2.bitwise_or(all_mask, others_outer_mask)
for layer in data:
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
paste_img, position = transpose_rotate(layer, layer['image'])
test_image.paste(paste_img, position, paste_img)
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
mask_alpha.paste(paste_img.getchannel('A'), position, paste_img.getchannel('A'))
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
else:
base_image.paste(layer['merge_image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['merge_image'])
result_image = base_image
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
# oss upload
image_bytes = image_data.read()
bucket_name = "aida-results"
object_name = f'result_{generate_uuid()}.png'
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
except Exception as e:
logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
result_image = None
if front_image:
@@ -232,3 +339,36 @@ def update_base_size_priority(layers):
for info in layers:
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
return layers, (new_width, new_height)
def transpose_rotate(layer, image):
# transpose[0]是左右 transpose[1]是上下
transpose = layer.get('transpose', [1, 1]) # 默认为1, 1代表不镜像
rotate = layer.get('rotate', 0)
paste_x, paste_y = layer['adaptive_position'][1], layer['adaptive_position'][0]
original_w = image.width
original_h = image.height
# transpose左右是1 上下是-1
if transpose[0] != 1:
# 左右
image = image.transpose(0)
if transpose[1] != 1:
# 上下
image = image.transpose(1)
if rotate:
image = image.rotate(-rotate, expand=True)
# 4. 计算粘贴位置以保持视觉中心一致
# 原本 (15, 36) 是 288*288 的左上角,我们计算其中心点
target_center_x = paste_x + original_w // 2
target_center_y = paste_y + original_h // 2
# 获取旋转后图像的新尺寸
new_w, new_h = image.size
# 计算新的左上角坐标,使得旋转后的图像中心依然在原定的中心位置
paste_x = target_center_x - new_w // 2
paste_y = target_center_y - new_h // 2
return image, (paste_x, paste_y)

View File

@@ -7,7 +7,7 @@ import numpy as np
import torch
import tritonclient.grpc as grpcclient
from minio import Minio
from pymilvus import MilvusClient
# from pymilvus import MilvusClient
from urllib3.exceptions import ResponseError
from app.core.config import settings, SR_MODEL_NAME, SR_TRITON_URL, MILVUS_TABLE_KEYPOINT, KEYPOINT_RESULT_TABLE_FIELD_SET
@@ -58,7 +58,21 @@ class DesignPreprocessing:
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4: # 如果是四通道 mask
image = image[:, :, :3]
# 分离RGB和Alpha通道
bgr = image[:, :, :3]
alpha = image[:, :, 3]
# 创建白色背景(也可改为其他颜色,如(255,255,255)就是白色)
background_color = (255, 255, 255)
background = np.full_like(bgr, background_color)
# 将Alpha通道转换为掩码0=透明255=不透明)
alpha_mask = alpha / 255.0 # 归一化到0-1
alpha_mask = np.expand_dims(alpha_mask, axis=-1) # 扩展维度,方便广播计算
# 混合背景和原图:透明区域显示背景色,不透明区域显示原图
image = (bgr * alpha_mask + background * (1 - alpha_mask)).astype(np.uint8)
# 此时image已经是3通道RGB无需再执行image = image[:, :, :3]
obj["image_obj"] = image
return image_list
@@ -174,8 +188,9 @@ class DesignPreprocessing:
scale = 0.4
if waist_width / scale >= image_width:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(255, 255, 255))
img_rgba = cv2.cvtColor(ret, cv2.COLOR_RGB2RGBA)
image_bytes = cv2.imencode(".png", img_rgba)[1].tobytes()
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
@@ -261,14 +276,15 @@ class DesignPreprocessing:
def keypoint_cache(self, sketch):
try:
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
keypoint_id = sketch['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
# res = client.query(
# collection_name=MILVUS_TABLE_KEYPOINT,
# # ids=[keypoint_id],
# filter=f"keypoint_id == {keypoint_id}",
# output_fields=['keypoint_vector', 'keypoint_site']
# )
res = []
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result = self.infer_keypoint_result(sketch)

View File

@@ -11,7 +11,6 @@ import logging
import uuid
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
@@ -21,6 +20,7 @@ from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import settings, FAST_GI_MODEL_URL, GI_MODEL_URL, DESIGN_MODEL_URL, FAST_GI_MODEL_NAME, GI_MODEL_NAME
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_upload_image
logger = logging.getLogger()
@@ -86,10 +86,9 @@ class AgentToolGenerateImage:
@staticmethod
def preprocess(img):
img = mmcv.imread(img)
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(
img = my_imnormalize(
img,
mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
to_rgb=True)

View File

@@ -189,10 +189,10 @@ if __name__ == '__main__':
tasks_id="123-89",
prompt="a single item of sketch of dress, 4k, white background",
image_url="aida-collection-element/89/Sketchboard/95f20cdc-e059-435c-b8b1-d04cc9e80c3d.png",
mode='img2img',
mode='txt2img',
category="sketch",
gender="Female",
version="fast"
version="hight"
)
server = GenerateImage(rd)
print(server.get_result())

View File

@@ -2,23 +2,23 @@ import logging
import time
import cv2
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
from app.core.config import settings, DESIGN_MODEL_URL, DESIGN_MODEL_NAME
from app.service.generate_image.utils.upload_sd_image import upload_stain_png_sd, upload_face_png_sd
from app.service.utils.image_normalize import my_imnormalize
logger = logging.getLogger()
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
ori_shape = img.shape[:2]
img_scale = ori_shape
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
@@ -242,10 +242,9 @@ def stain_detection(image, user_id, category, tasks_id, spot_size=100):
def generate_category_recognition(image, gender):
def preprocess(img):
img = mmcv.imread(img)
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -1,7 +1,6 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
@@ -10,6 +9,7 @@ from minio import Minio
from app.core.config import settings
from app.core.config import DESIGN_MODEL_URL
from app.schemas.image2sketch import Image2SketchModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
@@ -67,7 +67,7 @@ class LineArtService:
@staticmethod
def line_art_preprocess(image):
img = mmcv.imread(image)
img = image
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -76,9 +76,9 @@ class LineArtService:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape

View File

@@ -90,7 +90,7 @@ def get_response(messages):
def get_translation_from_llama3(text):
start_time = time.time()
url = "http://10.1.1.240:11434/api/generate"
url = f"http://{settings.A6000_SERVICE_HOST}:12434/api/generate"
# url = "http://10.1.1.240:1143/api/generate"
# prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
@@ -103,8 +103,8 @@ def get_translation_from_llama3(text):
# 创建请求的负载 translator是自定义的翻译模型
payload = {
"model": "translator",
"prompt": f"[{text}]",
"model": "AiDA-translator:latest",
"prompt": text,
"stream": False
}
# 将负载转换为 JSON 格式
@@ -148,7 +148,7 @@ def get_translation_from_llama3(text):
def get_prompt_from_image(image_path, text):
start_time = time.time()
# url = "http://localhost:11434/api/generate"
url = "http://10.1.1.243:11434/api/generate"
url = f"http://{settings.B_4_X_4090_SERVICE_HOST}:11434/api/generate"
image_base64 = minio_util.minio_url_to_base64(image_path.img)
# image_base64 = minio_url_to_base64(image_path)
@@ -180,7 +180,7 @@ def get_prompt_from_image(image_path, text):
def main():
"""Main function"""
text = get_translation_from_llama3("[火焰]")
text = get_translation_from_llama3("火焰")
print(text)

View File

@@ -1,241 +1,240 @@
# 预加载资源
import logging
import time
from collections import defaultdict
import os
import json
import numpy as np
from app.core.config import settings
from app.core.mysql_config import DB_CONFIG
logger = logging.getLogger()
import pymysql
from concurrent.futures import ThreadPoolExecutor
HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
matrix_data = {
"interaction_matrix": None,
"feature_matrix": None,
"user_index_interaction": None,
"sketch_index_interaction": None,
"user_index_feature": None,
"sketch_index_feature": None,
"iid_to_sketch": None,
"category_to_iids": None,
"cached_scores": {},
"cached_valid_idxs": {},
"category_sketch_idxs_inter": None,
"category_sketch_idxs_feature": None,
"user_inter_full": dict(),
"user_feat_full": dict(),
"brand_feature_matrix": None,
"brand_index_map": None,
"heat_data": {},
}
def load_resources():
"""加载所有矩阵和映射关系,并触发预缓存"""
try:
start_time = time.time()
# 清空缓存
matrix_data["cached_scores"].clear()
matrix_data["cached_valid_idxs"].clear()
# 加载数据
sketch_to_iid = np.load(f'{settings.RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
matrix_data["interaction_matrix"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
matrix_data["user_index_interaction"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
matrix_data["sketch_index_interaction"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
allow_pickle=True).item()
matrix_data["feature_matrix"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
brand_feature_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
if os.path.exists(brand_feature_path):
matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
else:
logger.warning("brand_feature_matrix 文件不存在,使用空数组")
matrix_data["brand_feature_matrix"] = np.array([])
# brand_index_map
brand_index_path = f"{settings.RECOMMEND_PATH_PREFIX}brand_index_map.npy"
if os.path.exists(brand_index_path):
matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
else:
logger.warning("brand_index_map 文件不存在,使用空字典")
matrix_data["brand_index_map"] = {}
matrix_data["user_index_feature"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
matrix_data["sketch_index_feature"] = np.load(f"{settings.RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
category_to_iid_map = np.load(f"{settings.RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
matrix_data["category_to_iids"] = defaultdict(list)
for iid, cat in category_to_iid_map.items():
matrix_data["category_to_iids"][cat].append(iid)
logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}")
# 触发预缓存
precache_user_category()
if os.path.exists(HEAT_VECTOR_FILE):
with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
heat_json = json.load(f)
matrix_data["heat_data"] = heat_json.get("data", {})
logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
else:
matrix_data["heat_data"] = {}
except Exception as e:
logger.error(f"资源加载失败: {str(e)}")
raise RuntimeError("初始化失败")
def precache_user_category():
"""优化后的用户分类预缓存(添加耗时统计)"""
if not all([
matrix_data["interaction_matrix"] is not None,
matrix_data["feature_matrix"] is not None,
matrix_data["user_index_interaction"] is not None
]):
logger.warning("资源未加载完成,跳过预缓存")
return
start_time = time.perf_counter()
time_stats = {
"get_all_user_categories": 0,
"process_user_category": 0,
"thread_execution": 0,
"cache_update": 0,
"total": 0,
}
# 统计用户类别获取时间
t1 = time.perf_counter()
user_categories = get_all_user_categories()
time_stats["get_all_user_categories"] = time.perf_counter() - t1
precached_count = 0
def process_user_category(user_id, categories):
"""单用户类别缓存计算(统计耗时)"""
local_cache = {}
local_valid_idxs = {}
time.perf_counter()
for category in categories:
cache_key = (user_id, category)
if cache_key in matrix_data["cached_scores"]:
continue
try:
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
# 统计获取类别 IID 耗时
t_iid = time.perf_counter()
category_iids = matrix_data["category_to_iids"].get(category, [])
valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
time_stats["process_user_category"] += time.perf_counter() - t_iid
# 统计矩阵计算耗时
t_matrix = time.perf_counter()
processed_inter = np.zeros(len(valid_sketch_idxs_inter))
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
processed_feat = np.zeros(len(valid_sketch_idxs_feature))
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores * 0.3
time_stats["process_user_category"] += time.perf_counter() - t_matrix
if len(processed_inter) == len(processed_feat):
local_cache[cache_key] = (processed_inter, processed_feat)
local_valid_idxs[cache_key] = valid_sketch_idxs_inter
except Exception as e:
logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
return local_cache, local_valid_idxs
# 统计线程执行时间
t2 = time.perf_counter()
with ThreadPoolExecutor(max_workers=8) as executor:
futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
for future in futures:
try:
t_cache = time.perf_counter()
cache_part, valid_idxs_part = future.result()
matrix_data["cached_scores"].update(cache_part)
matrix_data["cached_valid_idxs"].update(valid_idxs_part)
time_stats["cache_update"] += time.perf_counter() - t_cache
precached_count += len(cache_part)
except Exception as e:
logger.error(f"线程执行错误: {str(e)}")
time_stats["thread_execution"] = time.perf_counter() - t2
time_stats["total"] = time.perf_counter() - start_time
# 输出统计信息
logger.info(f"""
预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
- 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
- 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
- 线程任务执行: {time_stats["thread_execution"]:.2f}s
- 更新缓存数据: {time_stats["cache_update"]:.2f}s
- 总耗时: {time_stats["total"]:.2f}s
""")
def get_all_user_categories():
"""获取所有用户及其对应的分类"""
conn = None
try:
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
query = """
SELECT DISTINCT account_id, path
FROM user_preference_log_prediction \
"""
cursor.execute(query)
results = cursor.fetchall()
user_categories = defaultdict(set)
for account_id, path in results:
category = get_category_from_path(path)
user_categories[account_id].add(category)
return dict(user_categories)
except Exception as e:
logger.error(f"数据库查询失败: {str(e)}")
return {}
finally:
if conn:
conn.close()
def get_category_from_path(path: str) -> str:
"""从路径解析类别"""
try:
parts = path.split('/')
if len(parts) >= 4:
return f"{parts[2]}_{parts[3]}"
return "unknown"
except:
return "unknown"
# # 预加载资源
# import logging
# import time
# from collections import defaultdict
# import os
# import json
# import numpy as np
#
# from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
#
# logger = logging.getLogger()
# import pymysql
# from concurrent.futures import ThreadPoolExecutor
#
# HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
#
# matrix_data = {
# "interaction_matrix": None,
# "feature_matrix": None,
# "user_index_interaction": None,
# "sketch_index_interaction": None,
# "user_index_feature": None,
# "sketch_index_feature": None,
# "iid_to_sketch": None,
# "category_to_iids": None,
# "cached_scores": {},
# "cached_valid_idxs": {},
# "category_sketch_idxs_inter": None,
# "category_sketch_idxs_feature": None,
# "user_inter_full": dict(),
# "user_feat_full": dict(),
# "brand_feature_matrix": None,
# "brand_index_map": None,
# "heat_data": {},
# }
#
#
# def load_resources():
# """加载所有矩阵和映射关系,并触发预缓存"""
# try:
# start_time = time.time()
#
# # 清空缓存
# matrix_data["cached_scores"].clear()
# matrix_data["cached_valid_idxs"].clear()
#
# # 加载数据
# sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
# matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
#
# matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
# matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
# matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
# allow_pickle=True).item()
#
# matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
#
# brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
# if os.path.exists(brand_feature_path):
# matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
# else:
# logger.warning("brand_feature_matrix 文件不存在,使用空数组")
# matrix_data["brand_feature_matrix"] = np.array([])
#
# # brand_index_map
# brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
# if os.path.exists(brand_index_path):
# matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
# else:
# logger.warning("brand_index_map 文件不存在,使用空字典")
# matrix_data["brand_index_map"] = {}
#
# matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
#
# matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
#
# category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
# matrix_data["category_to_iids"] = defaultdict(list)
# for iid, cat in category_to_iid_map.items():
# matrix_data["category_to_iids"][cat].append(iid)
#
# logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
#
# # 触发预缓存
# precache_user_category()
#
# if os.path.exists(HEAT_VECTOR_FILE):
# with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
# heat_json = json.load(f)
# matrix_data["heat_data"] = heat_json.get("data", {})
# logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
# else:
# matrix_data["heat_data"] = {}
#
# except Exception as e:
# logger.error(f"资源加载失败: {str(e)}")
# raise RuntimeError("初始化失败")
#
#
# def precache_user_category():
# """优化后的用户分类预缓存(添加耗时统计)"""
# if not all([
# matrix_data["interaction_matrix"] is not None,
# matrix_data["feature_matrix"] is not None,
# matrix_data["user_index_interaction"] is not None
# ]):
# logger.warning("资源未加载完成,跳过预缓存")
# return
#
# start_time = time.perf_counter()
# time_stats = {
# "get_all_user_categories": 0,
# "process_user_category": 0,
# "thread_execution": 0,
# "cache_update": 0,
# "total": 0,
# }
#
# # 统计用户类别获取时间
# t1 = time.perf_counter()
# user_categories = get_all_user_categories()
# time_stats["get_all_user_categories"] = time.perf_counter() - t1
#
# precached_count = 0
#
# def process_user_category(user_id, categories):
# """单用户类别缓存计算(统计耗时)"""
# local_cache = {}
# local_valid_idxs = {}
# t_start = time.perf_counter()
#
# for category in categories:
# cache_key = (user_id, category)
# if cache_key in matrix_data["cached_scores"]:
# continue
#
# try:
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
#
# # 统计获取类别 IID 耗时
# t_iid = time.perf_counter()
# category_iids = matrix_data["category_to_iids"].get(category, [])
# valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
# for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
# valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
# for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
# time_stats["process_user_category"] += time.perf_counter() - t_iid
#
# # 统计矩阵计算耗时
# t_matrix = time.perf_counter()
# processed_inter = np.zeros(len(valid_sketch_idxs_inter))
# if user_idx_inter is not None and valid_sketch_idxs_inter:
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
# processed_inter = raw_inter_scores * 0.7
#
# processed_feat = np.zeros(len(valid_sketch_idxs_feature))
# if user_idx_feature is not None and valid_sketch_idxs_feature:
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
# processed_feat = raw_feat_scores * 0.3
# time_stats["process_user_category"] += time.perf_counter() - t_matrix
#
# if len(processed_inter) == len(processed_feat):
# local_cache[cache_key] = (processed_inter, processed_feat)
# local_valid_idxs[cache_key] = valid_sketch_idxs_inter
#
# except Exception as e:
# logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
#
# return local_cache, local_valid_idxs
#
# # 统计线程执行时间
# t2 = time.perf_counter()
# with ThreadPoolExecutor(max_workers=8) as executor:
# futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
# for future in futures:
# try:
# t_cache = time.perf_counter()
# cache_part, valid_idxs_part = future.result()
# matrix_data["cached_scores"].update(cache_part)
# matrix_data["cached_valid_idxs"].update(valid_idxs_part)
# time_stats["cache_update"] += time.perf_counter() - t_cache
# precached_count += len(cache_part)
# except Exception as e:
# logger.error(f"线程执行错误: {str(e)}")
# time_stats["thread_execution"] = time.perf_counter() - t2
#
# time_stats["total"] = time.perf_counter() - start_time
#
# # 输出统计信息
# logger.info(f"""
# 预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
# - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
# - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
# - 线程任务执行: {time_stats["thread_execution"]:.2f}s
# - 更新缓存数据: {time_stats["cache_update"]:.2f}s
# - 总耗时: {time_stats["total"]:.2f}s
# """)
#
#
# def get_all_user_categories():
# """获取所有用户及其对应的分类"""
# conn = None
# try:
# conn = pymysql.connect(**DB_CONFIG)
# cursor = conn.cursor()
#
# query = """
# SELECT DISTINCT account_id, path
# FROM user_preference_log_prediction
# """
# cursor.execute(query)
# results = cursor.fetchall()
#
# user_categories = defaultdict(set)
# for account_id, path in results:
# category = get_category_from_path(path)
# user_categories[account_id].add(category)
#
# return dict(user_categories)
#
# except Exception as e:
# logger.error(f"数据库查询失败: {str(e)}")
# return {}
# finally:
# if conn:
# conn.close()
#
#
# def get_category_from_path(path: str) -> str:
# """从路径解析类别"""
# try:
# parts = path.split('/')
# if len(parts) >= 4:
# return f"{parts[2]}_{parts[3]}"
# return "unknown"
# except:
# return "unknown"

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,67 @@
"""
推荐系统配置
"""
import os
from app.core.config import settings
# Milvus 集合名称
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"
# Redis key 前缀
REDIS_KEY_USER_PREF_PREFIX = "user_pref"
# 推荐系统配置参数
RECOMMENDATION_CONFIG = {
# 时间衰减半衰期(用于计算时间衰减权重)
# 值越小,最近的行为权重越大
"K_half": 10,
# 探索与利用的比例 (0.0-1.0)
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
# - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准
# - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8
"explore_ratio": 0.5,
# 向量检索返回的候选数量
# 值越大,候选池越大,但计算成本也越高
# 建议范围: 100-1000
"topk": 200,
# Style 加分系数(同 style 的候选进行加分)
# 值越大,匹配 style 的候选被选中的概率越大
# 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05
"style_bonus": 0.2,
# Softmax 抽样的温度参数
# - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低
# - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高
# - 温度=1.0 为标准 Softmax
# - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0
"softmax_temperature": 0.07,
# 监听间隔(秒)
"listen_interval_sec": 30,
# 批量处理大小
"batch_size": 1000,
# Redis 过期时间30天
"redis_expire_seconds": 2592000,
# 向量维度
"vector_dim": 2048,
}
# 数据库表名
TABLE_USER_PREFERENCE_LOG = "user_preference"
TABLE_SYS_FILE = "t_sys_file"
# MySQL 连接配置(用于推荐系统)
MYSQL_CONFIG = {
"host": settings.MYSQL_HOST,
"port": settings.MYSQL_PORT,
"user": settings.MYSQL_USER,
"password": settings.MYSQL_PASSWORD,
"database": settings.MYSQL_DB,
"charset": "utf8mb4"
}

View File

@@ -0,0 +1,331 @@
"""
独立脚本:从 t_sys_file 导入系统图向量到 Milvus
可以单独运行,不依赖整个项目启动
使用方法:
python -m app.service.recommendation_system.import_sys_sketch_to_milvus
python app/service/recommendation_system/import_sys_sketch_to_milvus.py
"""
import sys
import os
import logging
import argparse
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))
import numpy as np
import pymysql
from tqdm import tqdm
from app.service.recommendation_system.config import (
MYSQL_CONFIG, TABLE_SYS_FILE,
RECOMMENDATION_CONFIG, MILVUS_COLLECTION_SKETCH_VECTORS
)
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector
from app.service.recommendation_system.milvus_client import create_collection, insert_vectors
# 配置日志
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler(),
logging.FileHandler('import_sys_sketch.log', encoding='utf-8')
]
)
logger = logging.getLogger(__name__)
def get_sys_file_records(conn, limit=None, offset=0):
"""
从 t_sys_file 表获取系统图记录
Args:
conn: 数据库连接
limit: 限制数量None 表示不限制)
offset: 偏移量
Returns:
记录列表,每个元素为 (id, url, style, level3_type, level2_type, deprecated)
"""
cursor = conn.cursor()
query = f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE level1_type = 'Images'
AND style IS NOT NULL
AND style != ''
AND deprecated != 1
ORDER BY id
"""
if limit:
query += f" LIMIT {limit} OFFSET {offset}"
cursor.execute(query)
records = cursor.fetchall()
cursor.close()
return records
def get_total_count(conn):
"""获取总记录数"""
cursor = conn.cursor()
cursor.execute(f"""
SELECT COUNT(*)
FROM {TABLE_SYS_FILE}
WHERE level1_type = 'Images'
AND style IS NOT NULL
AND style != ''
AND deprecated != 1
""")
count = cursor.fetchone()[0]
cursor.close()
return count
def process_and_insert_batch(records, batch_size=1000, retry_times=3):
"""
处理并批量插入向量
Args:
records: 记录列表
batch_size: 批量大小
retry_times: 失败重试次数
Returns:
(成功数量, 失败数量)
"""
success_count = 0
failed_count = 0
failed_records = []
batch_data = []
# 使用 tqdm 显示进度
with tqdm(total=len(records), desc="处理记录", unit="") as pbar:
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records):
try:
# 计算 category
category = f"{level3_type.lower()}_{level2_type.lower()}"
# 提取特征向量
feature_vector = extract_feature_vector(url)
# 归一化,便于 IP≈cosine 度量
feature_vector = normalize_vector(feature_vector)
# 检查向量是否有效
if np.all(feature_vector == 0):
logger.warning(f"向量提取失败,跳过: {url} (id={sys_file_id})")
failed_count += 1
failed_records.append((sys_file_id, url))
pbar.update(1)
continue
# 准备数据
data_item = {
"path": url,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": deprecated if deprecated else 0,
"feature_vector": feature_vector.tolist()
}
batch_data.append(data_item)
# 批量写入
if len(batch_data) >= batch_size:
try:
insert_vectors(batch_data)
success_count += len(batch_data)
batch_data = []
logger.info(f"已成功插入 {success_count} 条记录")
except Exception as e:
logger.error(f"批量写入失败: {e}")
failed_count += len(batch_data)
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
batch_data = []
pbar.update(1)
except Exception as e:
logger.error(f"处理记录失败 [id={sys_file_id}, url={url}]: {e}")
failed_count += 1
failed_records.append((sys_file_id, url))
pbar.update(1)
# 写入剩余数据
if batch_data:
try:
insert_vectors(batch_data)
success_count += len(batch_data)
logger.info(f"写入剩余 {len(batch_data)} 条记录")
except Exception as e:
logger.error(f"写入剩余数据失败: {e}")
failed_count += len(batch_data)
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
# 重试失败记录
if failed_records and retry_times > 0:
logger.info(f"开始重试 {len(failed_records)} 条失败记录,最多重试 {retry_times} 次...")
for retry in range(retry_times):
if not failed_records:
break
retry_failed = []
with tqdm(total=len(failed_records), desc=f"重试第 {retry + 1}", unit="") as pbar:
for sys_file_id, url in failed_records:
try:
# 重新查询记录信息
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE id = %s
""", (sys_file_id,))
record = cursor.fetchone()
cursor.close()
conn.close()
if not record:
retry_failed.append((sys_file_id, url))
pbar.update(1)
continue
sys_file_id, url, style, level3_type, level2_type, deprecated = record
category = f"{level3_type.lower()}_{level2_type.lower()}"
feature_vector = extract_feature_vector(url)
feature_vector = normalize_vector(feature_vector)
if np.all(feature_vector == 0):
retry_failed.append((sys_file_id, url))
pbar.update(1)
continue
data_item = {
"path": url,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": deprecated if deprecated else 0,
"feature_vector": feature_vector.tolist()
}
insert_vectors([data_item])
success_count += 1
failed_count -= 1
pbar.update(1)
except Exception as e:
logger.error(f"重试失败 [id={sys_file_id}, url={url}]: {e}")
retry_failed.append((sys_file_id, url))
pbar.update(1)
failed_records = retry_failed
if failed_records:
logger.warning(f"{retry + 1} 次重试后仍有 {len(failed_records)} 条记录失败")
return success_count, failed_count, failed_records
def main():
"""主函数"""
parser = argparse.ArgumentParser(description='从 t_sys_file 导入系统图向量到 Milvus')
parser.add_argument('--batch-size', type=int, default=1000, help='批量处理大小默认1000')
parser.add_argument('--retry-times', type=int, default=3, help='失败重试次数默认3')
parser.add_argument('--limit', type=int, default=None, help='限制处理数量(用于测试,默认:不限制)')
parser.add_argument('--offset', type=int, default=0, help='起始偏移量默认0')
parser.add_argument('--skip-create-collection', action='store_true', help='跳过创建集合(如果集合已存在)')
args = parser.parse_args()
logger.info("=" * 60)
logger.info("开始从 t_sys_file 导入系统图向量到 Milvus")
logger.info("=" * 60)
logger.info(f"配置参数:")
logger.info(f" - 批量大小: {args.batch_size}")
logger.info(f" - 重试次数: {args.retry_times}")
logger.info(f" - 限制数量: {args.limit if args.limit else '不限制'}")
logger.info(f" - 起始偏移: {args.offset}")
logger.info("=" * 60)
# 1. 创建 Milvus 集合
if not args.skip_create_collection:
logger.info("创建 Milvus 集合...")
try:
create_collection()
logger.info("Milvus 集合创建成功(或已存在)")
except Exception as e:
logger.error(f"创建 Milvus 集合失败: {e}")
return
else:
logger.info("跳过创建集合")
# 2. 连接数据库
logger.info("连接数据库...")
try:
conn = pymysql.connect(**MYSQL_CONFIG)
logger.info("数据库连接成功")
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return
try:
# 3. 获取总记录数
total_count = get_total_count(conn)
logger.info(f"找到 {total_count} 条系统图记录")
if total_count == 0:
logger.warning("没有找到系统图数据")
return
# 4. 获取记录
logger.info("获取记录...")
records = get_sys_file_records(conn, limit=args.limit, offset=args.offset)
logger.info(f"获取到 {len(records)} 条记录")
if not records:
logger.warning("没有获取到记录")
return
# 5. 处理并插入
logger.info("开始处理记录...")
success_count, failed_count, failed_records = process_and_insert_batch(
records,
batch_size=args.batch_size,
retry_times=args.retry_times
)
# 6. 输出结果
logger.info("=" * 60)
logger.info("导入完成!")
logger.info(f" - 成功: {success_count}")
logger.info(f" - 失败: {failed_count}")
if failed_records:
logger.warning(f" - 失败记录列表前10条:")
for sys_file_id, url in failed_records[:10]:
logger.warning(f" ID={sys_file_id}, URL={url}")
if len(failed_records) > 10:
logger.warning(f" ... 还有 {len(failed_records) - 10} 条失败记录")
logger.info("=" * 60)
except Exception as e:
logger.error(f"处理过程中发生错误: {e}", exc_info=True)
finally:
conn.close()
logger.info("数据库连接已关闭")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,347 @@
"""
增量监听模块
实时监听 user_preference 表的新增记录,更新用户偏好向量
"""
import logging
import math
import pymysql
import numpy as np
from typing import List, Dict, Set, Tuple, Optional
from datetime import datetime
from collections import defaultdict
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.schedulers.blocking import BlockingScheduler
from app.service.recommendation_system.config import (
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
)
from app.service.recommendation_system.vector_utils import extract_feature_vector, compute_weighted_average, normalize_vector
from app.service.recommendation_system.milvus_client import query_vectors_by_paths, insert_vectors
from app.service.utils.redis_utils import Redis
import json
logger = logging.getLogger(__name__)
class IncrementalListener:
"""增量监听器"""
def __init__(self):
self.last_process_time = None
self.processed_combinations: Set[Tuple[int, str]] = set() # 已处理的 (account_id, category) 组合
self.listen_interval = RECOMMENDATION_CONFIG["listen_interval_sec"]
def get_new_like_records(self) -> List[Tuple]:
"""
获取新增点赞记录
Returns:
记录列表,每个元素为 (id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id)
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
if self.last_process_time is None:
# 第一次运行查询最近30分钟的数据
cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
ORDER BY data_time
""")
else:
# 基于上次处理时间查询
cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > %s
ORDER BY data_time
""", (self.last_process_time,))
records = cursor.fetchall()
return records
except Exception as e:
logger.error(f"获取新增点赞记录失败: {e}", exc_info=True)
return []
finally:
if conn:
conn.close()
def process_new_records(self, records: List[Tuple]):
"""
处理新增记录
Args:
records: 记录列表
"""
if not records:
return
# 按用户+类别分组
user_category_records = defaultdict(list)
for record in records:
account_id = record[1]
category = record[3]
if category: # 只处理有类别的记录
user_category_records[(account_id, category)].append(record)
# 去重:只处理一次每个 (account_id, category) 组合
to_process = []
for (account_id, category), recs in user_category_records.items():
if (account_id, category) not in self.processed_combinations:
to_process.append((account_id, category, recs))
self.processed_combinations.add((account_id, category))
logger.info(f"需要处理 {len(to_process)} 个用户-类别组合")
# 处理每个组合
for account_id, category, recs in to_process:
try:
self.update_user_preference_vector(account_id, category)
except Exception as e:
logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
# 更新最后处理时间
if records:
self.last_process_time = records[-1][5] # data_time
# 重置去重集合,确保下次周期不会跳过同一用户-类别
self.processed_combinations.clear()
def update_user_preference_vector(self, account_id: int, category: str):
"""
更新用户偏好向量
Args:
account_id: 用户ID
category: 类别
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 1. 获取该用户该类别的所有点赞记录
cursor.execute(f"""
SELECT path, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s
ORDER BY data_time DESC
""", (account_id, category))
like_records = cursor.fetchall()
if not like_records:
return
# 2. 批量查询点赞次数
paths = [r[0] for r in like_records]
placeholders = ','.join(['%s'] * len(paths))
cursor.execute(f"""
SELECT path, COUNT(*) as like_count
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
GROUP BY path
""", (account_id, category) + tuple(paths))
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
# 3. 批量获取向量
vectors_dict = query_vectors_by_paths(paths)
# 处理查询不到的 path新用户图或异常情况
missing_paths = [p for p in paths if p not in vectors_dict]
if missing_paths:
logger.info(f"用户 {account_id} 类别 {category}{len(missing_paths)} 个 path 需要实时计算向量")
self._compute_and_insert_missing_vectors(missing_paths, conn)
# 重新查询
vectors_dict = query_vectors_by_paths(paths)
# 4. 计算权重并加权平均
vectors = []
weights = []
K_half = RECOMMENDATION_CONFIG["K_half"]
for k, (path, data_time) in enumerate(like_records, 1):
if path not in vectors_dict:
continue
vector_data = vectors_dict[path]
feature_vector = np.array(vector_data["feature_vector"])
# 时间衰减权重
d_k = 0.5 ** (k / K_half)
# 点赞次数权重
like_count = like_counts.get(path, 1)
p_i = 1 + math.log(1 + like_count)
# 综合权重
w_i = d_k * p_i
vectors.append(feature_vector)
weights.append(w_i)
if not vectors:
logger.warning(f"用户 {account_id} 类别 {category} 没有有效向量")
return
# 5. 计算加权平均并做 L2 归一化IP≈cosine
preference_vector = compute_weighted_average(vectors, weights)
preference_vector = normalize_vector(preference_vector)
# 6. 写入 Redis
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
vector_json = json.dumps(preference_vector.tolist())
Redis.write(
key=key,
value=vector_json,
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
)
logger.debug(f"用户偏好向量更新成功 [user={account_id}, category={category}]")
except Exception as e:
logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
raise
finally:
if conn:
conn.close()
def _compute_and_insert_missing_vectors(self, paths: List[str], conn: pymysql.connections.Connection):
"""
计算并插入缺失的向量
Args:
paths: 缺失的 path 列表
conn: 数据库连接
"""
cursor = conn.cursor()
data_to_insert = []
for path in paths:
try:
# 判断数据来源(查询 t_sys_file 表)
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE url = %s
LIMIT 1
""", (path,))
sys_file = cursor.fetchone()
# 提取特征向量
feature_vector = extract_feature_vector(path)
if np.all(feature_vector == 0):
logger.warning(f"向量提取失败,跳过: {path}")
continue
if sys_file:
# 系统图
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
category = f"{level3_type.lower()}_{level2_type.lower()}"
data_item = {
"path": path,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": deprecated if deprecated else 0,
"feature_vector": feature_vector.tolist()
}
else:
# 用户图
# 从 user_preference 获取 category如果有
cursor.execute(f"""
SELECT category
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE path = %s AND category IS NOT NULL
LIMIT 1
""", (path,))
category_result = cursor.fetchone()
category = category_result[0] if category_result else None
data_item = {
"path": path,
"sys_file_id": None,
"style": None,
"category": category,
"is_system_sketch": 0,
"deprecated": 0,
"feature_vector": feature_vector.tolist()
}
data_to_insert.append(data_item)
except Exception as e:
logger.error(f"处理缺失向量失败 [{path}]: {e}")
# 批量插入
if data_to_insert:
try:
insert_vectors(data_to_insert)
logger.info(f"成功插入 {len(data_to_insert)} 个缺失向量")
except Exception as e:
logger.error(f"插入缺失向量失败: {e}")
def process_once(self):
"""单次轮询任务,供调度器调用"""
try:
records = self.get_new_like_records()
if records:
logger.info(f"发现 {len(records)} 条新增记录")
self.process_new_records(records)
else:
logger.debug("没有新增记录")
except Exception as e:
logger.error(f"监听轮询异常: {e}", exc_info=True)
def start_background_listener(scheduler: BackgroundScheduler):
"""将增量监听任务注册到后台调度器"""
# 降低 apscheduler 的日志级别,避免大量刷屏
logging.getLogger('apscheduler.executors.default').setLevel(logging.WARNING)
logging.getLogger('apscheduler.scheduler').setLevel(logging.WARNING)
listener = IncrementalListener()
scheduler.add_job(
listener.process_once,
"interval",
seconds=listener.listen_interval,
max_instances=1,
coalesce=True,
id="recommendation_incremental_listener",
replace_existing=True,
)
logger.info("增量监听任务已注册到调度器")
def start_blocking_listener():
"""以阻塞方式启动调度器(用于独立脚本运行)"""
listener = IncrementalListener()
scheduler = BlockingScheduler()
scheduler.add_job(
listener.process_once,
"interval",
seconds=listener.listen_interval,
max_instances=1,
coalesce=True,
id="recommendation_incremental_listener",
replace_existing=True,
)
logger.info("增量监听调度器已启动BlockingScheduler")
scheduler.start()
if __name__ == "__main__":
start_blocking_listener()

View File

@@ -0,0 +1,332 @@
"""
Milvus 客户端封装
"""
import logging
from typing import List, Dict, Optional, Any
import numpy as np
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
from app.core.config import settings
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
logger = logging.getLogger(__name__)
# Milvus 客户端(单例)
_milvus_client = None
def get_milvus_client() -> MilvusClient:
"""获取 Milvus 客户端(单例模式)"""
global _milvus_client
if _milvus_client is None:
try:
_milvus_client = MilvusClient(
uri=settings.MILVUS_URL,
token=settings.MILVUS_TOKEN,
db_name="",
)
logger.info("Milvus 客户端连接成功")
except Exception as e:
logger.error(f"Milvus 客户端连接失败: {e}")
raise
return _milvus_client
def create_collection():
"""
创建 Milvus 集合 sketch_vectors
集合结构:
- path (PK, varchar(512)) - 主键MinIO 逻辑 URL
- sys_file_id (int64, 可为NULL) - 系统文件ID
- style (varchar(50), 可为NULL) - 风格样式
- category (varchar(100), 可为NULL) - 类别
- is_system_sketch (int8, 默认 1) - 标记字段1-系统图0-用户图
- deprecated (int8, 默认 0) - 是否废弃
- feature_vector (FloatVector(2048)) - 2048维特征向量
"""
client = get_milvus_client()
# 检查集合是否已存在
collections = client.list_collections()
if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
return
try:
# 解析 Milvus URL
# 处理 http://host.docker.internal:19530 格式
url_clean = settings.MILVUS_URL.replace("http://", "").replace("https://", "")
if ":" in url_clean:
host, port_str = url_clean.split(":", 1)
port = int(port_str)
else:
host = url_clean
port = 19530
# 使用传统 API 创建集合(更可靠)
# 连接到 Milvus如果未连接
try:
connections.connect(
alias=settings.MILVUS_ALIAS,
host=host,
port=port,
token=settings.MILVUS_TOKEN if settings.MILVUS_TOKEN else None
)
logger.info(f"已连接到 Milvus: {host}:{port}")
except Exception as conn_e:
# 如果连接已存在,忽略错误
if "already exists" in str(conn_e).lower() or "Connection already exists" in str(conn_e):
logger.info("Milvus 连接已存在")
else:
logger.warning(f"连接 Milvus 时出现警告: {conn_e}")
# 定义字段
fields = [
FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
FieldSchema(name="sys_file_id", dtype=DataType.INT64),
FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
FieldSchema(name="deprecated", dtype=DataType.INT8),
FieldSchema(
name="feature_vector",
dtype=DataType.FLOAT_VECTOR,
dim=RECOMMENDATION_CONFIG["vector_dim"]
)
]
# 创建 schema
schema = CollectionSchema(
fields=fields,
description="Sketch vectors collection for recommendation system"
)
# 创建集合
collection = Collection(
name=MILVUS_COLLECTION_SKETCH_VECTORS,
schema=schema,
using=settings.MILVUS_ALIAS
)
# 创建索引
# 注意:使用 IP内积作为度量类型与搜索时保持一致
# 如果向量已归一化IP 等价于 COSINE
index_params = {
"metric_type": "IP", # 内积Inner Product
"index_type": "IVF_FLAT",
"params": {"nlist": 1024}
}
collection.create_index(
field_name="feature_vector",
index_params=index_params
)
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
except Exception as e:
logger.error(f"创建集合失败: {e}", exc_info=True)
raise
def insert_vectors(data: List[Dict[str, Any]]):
"""
批量插入向量到 Milvus
Args:
data: 数据列表,每个元素包含:
- path: str
- sys_file_id: int (可选)
- style: str (可选)
- category: str (可选)
- is_system_sketch: int (默认 1)
- deprecated: int (默认 0)
- feature_vector: List[float] (2048维)
"""
if not data:
return
client = get_milvus_client()
try:
client.insert(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=data
)
logger.info(f"成功插入 {len(data)} 条向量数据")
except Exception as e:
logger.error(f"插入向量失败: {e}", exc_info=True)
raise
def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
"""
根据 path 列表批量查询向量
Args:
paths: path 列表
Returns:
{path: {feature_vector: [...], ...}} 字典
"""
if not paths:
return {}
client = get_milvus_client()
try:
# 构建查询表达式
# 使用 filter 参数而不是 expr根据 pymilvus MilvusClient API
# 对于字符串列表,使用单引号包裹每个值
path_list = ", ".join([f"'{p}'" for p in paths])
filter_expr = f"path in [{path_list}]"
results = client.query(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr,
output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"]
)
# 转换为字典
result_dict = {}
for r in results:
result_dict[r["path"]] = r
return result_dict
except Exception as e:
logger.error(f"查询向量失败: {e}", exc_info=True)
return {}
def search_similar_vectors(
query_vector: np.ndarray,
category: str,
topk: int = 500,
style: Optional[str] = None,
style_boost_ratio: float = 0.2
) -> List[Dict]:
"""
向量相似度检索
Args:
query_vector: 查询向量2048维
category: 类别过滤
topk: 返回数量
style: 风格过滤(可选)- 当提供时会给对应style的结果加分
style_boost_ratio: 风格加分比例默认0.1即10%
Returns:
检索结果列表,每个元素包含 path, score, style, category 等字段
"""
client = get_milvus_client()
try:
# 如果没有指定style使用原始逻辑
if not style:
filter_expr = f"category == '{category}' && deprecated == 0"
results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
else:
# 有style参数时使用两阶段搜索策略
# 第一阶段搜索匹配style的向量使用boosted query vector
filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
boosted_query = query_vector * (1 + style_boost_ratio)
results_style = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[boosted_query.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_style,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第二阶段搜索其他style的向量
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
results_others = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_others,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 合并结果
results = []
if results_style and len(results_style) > 0:
results.extend(results_style[0])
if results_others and len(results_others) > 0:
results.extend(results_others[0])
# 转换为单个结果列表格式
results = [results] if results else []
# 格式化结果
formatted_results = []
if results and len(results) > 0:
for hit in results[0]:
formatted_results.append({
"path": hit.get("entity", {}).get("path", ""),
"score": hit.get("distance", 0.0),
"style": hit.get("entity", {}).get("style", ""),
"category": hit.get("entity", {}).get("category", ""),
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
})
# 按分数排序并返回topk
formatted_results.sort(key=lambda x: x["score"], reverse=True)
return formatted_results[:topk]
except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True)
return []
def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
"""
随机查询候选(用于探索分支)
Args:
category: 类别
style: 风格(可选)
limit: 返回数量
Returns:
候选列表
"""
client = get_milvus_client()
try:
# 构建过滤表达式
filter_expr = f"category == '{category}' && deprecated == 0"
if style:
filter_expr += f" && style == '{style}'"
# 查询所有符合条件的记录
results = client.query(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr,
output_fields=["path", "style", "category"],
limit=10000
)
# 随机选择
if len(results) > limit:
import random
results = random.sample(results, limit)
return results
except Exception as e:
logger.error(f"随机查询候选失败: {e}", exc_info=True)
return []

View File

@@ -0,0 +1,557 @@
"""
预计算模块
包含数据库表结构优化、Milvus集合创建、系统图向量预计算、初始用户偏好向量生成
"""
import logging
import math
import pymysql
import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
from app.service.recommendation_system.config import (
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
)
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average
from app.service.recommendation_system.milvus_client import (
create_collection, insert_vectors, query_vectors_by_paths
)
from app.service.utils.redis_utils import Redis
import json
logger = logging.getLogger(__name__)
def optimize_database_table():
"""
优化 user_preference 表结构
添加冗余字段和索引
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 1. 添加冗余字段
logger.info("添加冗余字段...")
alter_sqls = [
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别lower(level3_type + \"_\" + level2_type)'",
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'",
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图1-是0-用户图)'",
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'",
]
for sql in alter_sqls:
try:
cursor.execute(sql)
logger.info(f"执行成功: {sql[:50]}...")
except Exception as e:
if "Duplicate column name" in str(e):
logger.info(f"字段已存在,跳过: {sql[:50]}...")
else:
logger.warning(f"执行失败: {sql[:50]}... 错误: {e}")
# 2. 创建索引MySQL 不支持 IF NOT EXISTS需要先检查
logger.info("创建索引...")
index_definitions = [
("idx_account_category_time", ["account_id", "category", "data_time"]),
("idx_account_path", ["account_id", "path"]),
]
for index_name, columns in index_definitions:
try:
# 检查索引是否已存在
cursor.execute(f"""
SELECT COUNT(*)
FROM information_schema.statistics
WHERE table_schema = DATABASE()
AND table_name = '{TABLE_USER_PREFERENCE_LOG}'
AND index_name = '{index_name}'
""")
exists = cursor.fetchone()[0] > 0
if exists:
logger.info(f"索引已存在,跳过: {index_name}")
else:
# 创建索引
columns_str = ', '.join(columns)
create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})"
cursor.execute(create_sql)
logger.info(f"索引创建成功: {index_name}")
except Exception as e:
logger.warning(f"索引创建失败: {index_name} 错误: {e}")
conn.commit()
logger.info("数据库表结构优化完成")
except Exception as e:
logger.error(f"数据库表结构优化失败: {e}", exc_info=True)
if conn:
conn.rollback()
finally:
if conn:
conn.close()
def migrate_historical_data(batch_size: int = 1000):
"""
历史数据迁移:批量更新冗余字段
Args:
batch_size: 每批处理数量
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 查询需要更新的记录数
cursor.execute(f"""
SELECT COUNT(*)
FROM {TABLE_USER_PREFERENCE_LOG} u
WHERE u.category IS NULL
""")
total_count = cursor.fetchone()[0]
logger.info(f"需要迁移的记录数: {total_count}")
if total_count == 0:
logger.info("无需迁移数据")
return
# 分批处理
offset = 0
processed = 0
while offset < total_count:
# 查询一批记录
cursor.execute(f"""
SELECT u.id, u.path
FROM {TABLE_USER_PREFERENCE_LOG} u
WHERE u.category IS NULL
LIMIT {batch_size} OFFSET {offset}
""")
records = cursor.fetchall()
if not records:
break
# 批量更新
for record_id, path in records:
# 查询 t_sys_file 表
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE url = %s
LIMIT 1
""", (path,))
sys_file = cursor.fetchone()
if sys_file:
# 系统图
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
category = f"{level3_type.lower()}_{level2_type.lower()}"
cursor.execute(f"""
UPDATE {TABLE_USER_PREFERENCE_LOG}
SET category = %s,
style = %s,
is_system_sketch = 1,
sys_file_id = %s
WHERE id = %s
""", (category, style, sys_file_id, record_id))
else:
# 用户图
cursor.execute(f"""
UPDATE {TABLE_USER_PREFERENCE_LOG}
SET is_system_sketch = 0,
category = NULL,
style = NULL,
sys_file_id = NULL
WHERE id = %s
""", (record_id,))
conn.commit()
processed += len(records)
offset += batch_size
logger.info(f"已迁移 {processed}/{total_count} 条记录")
logger.info("历史数据迁移完成")
except Exception as e:
logger.error(f"历史数据迁移失败: {e}", exc_info=True)
if conn:
conn.rollback()
finally:
if conn:
conn.close()
def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3):
"""
系统图向量预计算与导入
Args:
batch_size: 每批处理数量
retry_times: 失败重试次数
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 1. 数据筛选
logger.info("查询系统图数据...")
cursor.execute(f"""
SELECT id, url, style, level3_type, level2_type, deprecated
FROM {TABLE_SYS_FILE}
WHERE level1_type = 'Images'
AND style IS NOT NULL
AND style != ''
AND deprecated != 1
""")
records = cursor.fetchall()
logger.info(f"找到 {len(records)} 条系统图记录")
if not records:
logger.warning("没有找到系统图数据")
return
# 2. 批量处理
failed_records = []
batch_data = []
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1):
try:
# 计算 category
category = f"{level3_type.lower()}_{level2_type.lower()}"
# 提取特征向量
feature_vector = extract_feature_vector(url)
# 检查向量是否有效
if np.all(feature_vector == 0):
logger.warning(f"向量提取失败,跳过: {url}")
failed_records.append((sys_file_id, url))
continue
# 准备数据
data_item = {
"path": url,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": deprecated if deprecated else 0,
"feature_vector": feature_vector.tolist()
}
batch_data.append(data_item)
# 批量写入
if len(batch_data) >= batch_size:
try:
insert_vectors(batch_data)
batch_data = []
logger.info(f"已处理 {idx}/{len(records)} 条记录")
except Exception as e:
logger.error(f"批量写入失败: {e}")
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
batch_data = []
except Exception as e:
logger.error(f"处理记录失败 [{url}]: {e}")
failed_records.append((sys_file_id, url))
# 写入剩余数据
if batch_data:
try:
insert_vectors(batch_data)
except Exception as e:
logger.error(f"写入剩余数据失败: {e}")
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
# 3. 重试失败记录
if failed_records and retry_times > 0:
logger.info(f"重试 {len(failed_records)} 条失败记录...")
for retry in range(retry_times):
retry_failed = []
for sys_file_id, url in failed_records:
try:
category = f"{level3_type.lower()}_{level2_type.lower()}"
feature_vector = extract_feature_vector(url)
if not np.all(feature_vector == 0):
data_item = {
"path": url,
"sys_file_id": sys_file_id,
"style": style,
"category": category,
"is_system_sketch": 1,
"deprecated": 0,
"feature_vector": feature_vector.tolist()
}
insert_vectors([data_item])
else:
retry_failed.append((sys_file_id, url))
except Exception as e:
logger.error(f"重试失败 [{url}]: {e}")
retry_failed.append((sys_file_id, url))
failed_records = retry_failed
if not failed_records:
break
if failed_records:
logger.warning(f"仍有 {len(failed_records)} 条记录处理失败")
logger.info("系统图向量预计算完成")
except Exception as e:
logger.error(f"系统图向量预计算失败: {e}", exc_info=True)
finally:
if conn:
conn.close()
def compute_user_preference_vector(
account_id: int,
category: str,
conn: Optional[pymysql.connections.Connection] = None,
max_date: Optional[datetime] = None
) -> Optional[np.ndarray]:
"""
计算用户偏好向量
Args:
account_id: 用户ID
category: 类别
conn: 数据库连接(可选)
max_date: 最大日期(可选,用于评估时只使用训练集数据)
Returns:
用户偏好向量2048维失败返回 None
"""
from datetime import datetime
should_close = False
if conn is None:
conn = pymysql.connect(**MYSQL_CONFIG)
should_close = True
try:
cursor = conn.cursor()
# 1. 获取点赞记录如果指定了max_date只查询该日期之前的数据
if max_date:
cursor.execute(f"""
SELECT path, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND style is not null
AND data_time < %s
ORDER BY data_time DESC
""", (account_id, category, max_date))
else:
cursor.execute(f"""
SELECT path, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND style is not null
ORDER BY data_time DESC
""", (account_id, category))
like_records = cursor.fetchall()
if not like_records:
return None
# 2. 批量查询点赞次数如果指定了max_date只统计该日期之前的点赞
paths = [r[0] for r in like_records]
if not paths:
return None
placeholders = ','.join(['%s'] * len(paths))
if max_date:
cursor.execute(f"""
SELECT path, COUNT(*) as like_count
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
AND data_time < %s
GROUP BY path
""", (account_id, category) + tuple(paths) + (max_date,))
else:
cursor.execute(f"""
SELECT path, COUNT(*) as like_count
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
GROUP BY path
""", (account_id, category) + tuple(paths))
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
# 3. 批量获取向量
vectors_dict = query_vectors_by_paths(paths)
# 处理查询不到的 path用户图或异常情况
missing_paths = [p for p in paths if p not in vectors_dict]
if missing_paths:
logger.info(f"用户 {account_id} 类别 {category}{len(missing_paths)} 个 path 需要实时计算向量")
# 目前未有非系统图向量,跳过
# 这里可以实时计算并写入 Milvus但为了简化先跳过
# 实际实现中应该调用 vector_utils.extract_feature_vector 并写入 Milvus
# 4. 计算权重并加权平均
vectors = []
weights = []
K_half = RECOMMENDATION_CONFIG["K_half"]
for k, (path, data_time) in enumerate(like_records, 1):
if path not in vectors_dict:
continue
vector_data = vectors_dict[path]
feature_vector = np.array(vector_data["feature_vector"])
# 时间衰减权重
d_k = 0.5 ** (k / K_half)
# 点赞次数权重
like_count = like_counts.get(path, 1)
p_i = 1 + math.log(1 + like_count)
# 综合权重
w_i = d_k * p_i
# w_i = p_i
vectors.append(feature_vector)
weights.append(w_i)
if not vectors:
return None
# 5. 计算加权平均并做 L2 归一化IP≈cosine
preference_vector = compute_weighted_average(vectors, weights)
preference_vector = normalize_vector(preference_vector)
return preference_vector
except Exception as e:
logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
return None
finally:
if should_close and conn:
conn.close()
def generate_initial_user_preference_vectors(batch_size: int = 100):
"""
初始用户偏好向量生成
Args:
batch_size: 每批处理用户数
"""
conn = None
try:
conn = pymysql.connect(**MYSQL_CONFIG)
cursor = conn.cursor()
# 1. 扫描历史数据
logger.info("扫描用户和类别组合...")
cursor.execute(f"""
SELECT DISTINCT account_id, category
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE category IS NOT NULL
AND style IS NOT NULL
""")
user_categories = cursor.fetchall()
logger.info(f"找到 {len(user_categories)} 个用户-类别组合")
if not user_categories:
logger.warning("没有找到用户-类别组合")
return
# 2. 批量处理
processed = 0
failed = 0
for account_id, category in user_categories:
try:
# 计算偏好向量
preference_vector = compute_user_preference_vector(account_id, category, conn)
if preference_vector is not None:
# 写入 Redis
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
# 序列化向量(使用 JSON
vector_json = json.dumps(preference_vector.tolist())
Redis.write(
key=key,
value=vector_json,
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
)
processed += 1
else:
failed += 1
if (processed + failed) % batch_size == 0:
logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}")
except Exception as e:
logger.error(f"处理失败 [user={account_id}, category={category}]: {e}")
failed += 1
logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}")
except Exception as e:
logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True)
finally:
if conn:
conn.close()
def run_precompute():
"""
运行所有预计算任务
"""
logger.info("=" * 50)
logger.info("开始预计算任务")
logger.info("=" * 50)
# 1. 优化数据库表结构
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
# # 2. 创建 Milvus 集合
# logger.info("\n[2/5] 创建 Milvus 集合...")
# create_collection()
# 3. 历史数据迁移
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
# # 4. 系统图向量预计算
# logger.info("\n[4/5] 系统图向量预计算...")
# precompute_system_sketch_vectors()
# 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...")
generate_initial_user_preference_vectors()
logger.info("=" * 50)
logger.info("预计算任务完成")
logger.info("=" * 50)
if __name__ == "__main__":
# # 1. 优化数据库表结构
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
#
# # 3. 历史数据迁移
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
# 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...")
generate_initial_user_preference_vectors()

View File

@@ -0,0 +1,214 @@
"""
推荐接口实现
实现探索/利用分支、向量检索、Softmax抽样等功能
"""
import logging
import math
import random
import numpy as np
from typing import List, Dict, Optional
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
from app.service.recommendation_system.milvus_client import search_similar_vectors, query_random_candidates
from app.service.recommendation_system.precompute import compute_user_preference_vector
from app.service.recommendation_system.vector_utils import normalize_vector
from app.service.utils.redis_utils import Redis
import json
logger = logging.getLogger(__name__)
def get_user_preference_vector(user_id: int, category: str) -> Optional[np.ndarray]:
"""
获取用户偏好向量
Args:
user_id: 用户ID
category: 类别
Returns:
用户偏好向量2048维失败返回 None
"""
# 1. 从 Redis 获取
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{user_id}:{category}"
vector_json = Redis.read(key)
if vector_json:
try:
vector_list = json.loads(vector_json)
return np.array(vector_list, dtype=np.float32)
except Exception as e:
logger.warning(f"解析 Redis 向量失败 [user={user_id}, category={category}]: {e}")
# 2. 如果不存在,实时计算
logger.info(f"Redis 中不存在用户偏好向量,实时计算 [user={user_id}, category={category}]")
preference_vector = compute_user_preference_vector(user_id, category)
if preference_vector is not None:
# 写入 Redis
vector_json = json.dumps(preference_vector.tolist())
Redis.write(
key=key,
value=vector_json,
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
)
return preference_vector
def explore_branch(category: str, style: Optional[str] = None) -> List[str]:
"""
探索分支(随机推荐)
Args:
category: 类别
style: 风格(可选)
Returns:
推荐结果列表,每个元素包含 path, style, category 等字段
"""
# 查询候选(随机池)
pool_size = 10 # 固定查询10个然后随机选择
candidates = query_random_candidates(category, style, limit=pool_size)
if not candidates:
logger.warning(f"探索分支:类别 {category} 没有候选数据")
return []
# 随机选择
if len(candidates) > 1:
import random
candidates = random.sample(candidates, 1)
# 格式化返回结果
return [candidate.get("path", "") for candidate in candidates[:1]]
def exploit_branch(
user_id: int,
category: str,
style: Optional[str] = None
) -> List[str]:
"""
利用分支(基于向量相似度推荐)
Args:
user_id: 用户ID
category: 类别
num_recommendations: 返回数量
style: 风格(可选,用于加分)
Returns:
推荐结果列表,每个元素包含 path, style, category, similarity, sample_score 等字段
"""
# 1. 获取用户偏好向量
embedding = get_user_preference_vector(user_id, category)
if embedding is None:
logger.warning(f"利用分支:无法获取用户偏好向量,回退到探索分支 [user={user_id}, category={category}]")
return explore_branch(category, style)
# 2. Milvus 相似度检索(内积 IP
topk = RECOMMENDATION_CONFIG["topk"]
results = search_similar_vectors(embedding, category, topk)
if not results:
logger.warning(f"利用分支:向量检索无结果,回退到探索分支 [user={user_id}, category={category}]")
return explore_branch(category, style)
# 3. Style 加分(可选,需传入 style 参数)
style_bonus = RECOMMENDATION_CONFIG["style_bonus"]
if style:
for result in results:
similarity = result["score"]
if result.get("style") == style:
# 加分:相似度 * (1 + style_bonus)
similarity = similarity * (1 + style_bonus)
result["final_score"] = similarity
else:
for result in results:
result["final_score"] = result["score"]
# 4. Softmax 抽样
scores = [r["final_score"] for r in results]
probabilities = softmax_with_temperature(scores, RECOMMENDATION_CONFIG["softmax_temperature"])
# 根据概率抽样
if not results:
return []
selected_index = np.random.choice(len(results), size=1, p=probabilities, replace=False)
selected_results = [results[int(selected_index[0])]]
# 5. 返回结果
return [result.get("path", "") for result in selected_results]
def softmax_with_temperature(scores: List[float], temperature: float = 1.0) -> List[float]:
"""
Softmax 函数(带温度参数)
Args:
scores: 分数列表
temperature: 温度参数
Returns:
概率列表
"""
if not scores:
return []
# 除以温度
scaled_scores = [s / temperature for s in scores]
# 减去最大值(数值稳定性)
max_score = max(scaled_scores)
exp_scores = [math.exp(s - max_score) for s in scaled_scores]
# 归一化
sum_exp = sum(exp_scores)
if sum_exp == 0:
# 如果所有分数都是负无穷或非常小,返回均匀分布
return [1.0 / len(scores)] * len(scores)
probabilities = [exp_s / sum_exp for exp_s in exp_scores]
return probabilities
def get_recommendations(
user_id: int,
category: str,
style: Optional[str] = None
) -> List[str]:
"""
获取推荐结果(主函数)
Args:
user_id: 用户ID
category: 类别(如 female_skirt
num_recommendations: 返回推荐数量(默认 1
style: 风格(可选):若传入,则在利用分支对同 style 的候选进行加分
Returns:
推荐结果列表,每个元素包含 path 等字段
"""
try:
# 1. 读取配置参数
explore_ratio = RECOMMENDATION_CONFIG["explore_ratio"]
# 2. 探索/利用决策
r = random.random() # 生成随机数 (0-1)
if r < explore_ratio:
logger.debug(f"探索分支 [user={user_id}, category={category}]")
return explore_branch(category, style)
logger.debug(f"利用分支 [user={user_id}, category={category}]")
return exploit_branch(user_id, category, style)
except Exception as e:
logger.error(f"获取推荐结果失败 [user={user_id}, category={category}]: {e}", exc_info=True)
# 容错:回退到探索分支
return explore_branch(category, style)

View File

@@ -0,0 +1,189 @@
"""
向量计算工具类
包含 ResNet50 特征提取、向量归一化等功能
"""
import io
import logging
import numpy as np
import torch
from torchvision import models, transforms
from PIL import Image
from minio import Minio
from app.core.config import settings
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG
logger = logging.getLogger(__name__)
# 图像预处理与ResNet训练时的预处理一致
transform = transforms.Compose([
transforms.Resize((224, 224)), # ResNet 要求 224x224 的输入
transforms.ToTensor(), # 转换为 Tensor
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化
])
# 加载预训练的 ResNet50 模型(去掉最后全连接层)
_resnet_model = None
def get_resnet_model():
"""获取 ResNet50 模型(单例模式)"""
global _resnet_model
if _resnet_model is None:
logger.info("加载 ResNet50 模型...")
_resnet_model = models.resnet50(pretrained=True)
modules = list(_resnet_model.children())[:-1] # 移除最后的全连接层
_resnet_model = torch.nn.Sequential(*modules)
_resnet_model.eval() # 设置为评估模式
logger.info("ResNet50 模型加载完成")
return _resnet_model
# MinIO 客户端(单例)
_minio_client = None
def get_minio_client():
"""获取 MinIO 客户端(单例模式)"""
global _minio_client
if _minio_client is None:
_minio_client = Minio(
settings.MINIO_URL,
access_key=settings.MINIO_ACCESS,
secret_key=settings.MINIO_SECRET,
secure=settings.MINIO_SECURE
)
return _minio_client
def get_image_from_minio(path: str) -> Image.Image:
"""
从 MinIO 获取图片
Args:
path: MinIO 逻辑 URL格式如 "bucket_name/object_name"
Returns:
PIL Image 对象,失败返回 None
"""
try:
# 分割路径,获取桶名和文件路径
path_parts = path.split('/', 1)
if len(path_parts) != 2:
logger.error(f"路径格式错误: {path}")
return None
bucket_name, file_name = path_parts
minio_client = get_minio_client()
# 获取文件
obj = minio_client.get_object(bucket_name, file_name)
img_data = obj.read() # 读取图像数据
img = Image.open(io.BytesIO(img_data)) # 将数据转为图像对象
return img
except Exception as e:
logger.error(f"从 MinIO 获取图片失败 [{path}]: {e}")
return None
def extract_feature_vector(path: str) -> np.ndarray:
"""
使用 ResNet50 提取图片特征向量2048维
Args:
path: MinIO 逻辑 URL
Returns:
2048维特征向量numpy array失败返回零向量
"""
try:
# 从 MinIO 获取图像
img = get_image_from_minio(path)
if img is None:
logger.warning(f"无法获取图片,返回零向量: {path}")
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
# 预处理
# 部分 MinIO 图片可能是 RGBA/CMYK转换成 RGB 以匹配 3 通道标准化参数
if img.mode != "RGB":
try:
img = img.convert("RGB")
except Exception:
logger.warning(f"无法转换图片为RGB返回零向量: {path}")
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
img_tensor = transform(img).unsqueeze(0) # 扩展维度以适应批量处理
# 提取特征
resnet_model = get_resnet_model()
with torch.no_grad(): # 在不需要计算梯度的情况下进行推断
feature_vector = resnet_model(img_tensor) # 获取 ResNet 的输出
feature_vector = feature_vector.squeeze().cpu().numpy() # 转换为 NumPy 数组并去掉 batch 维度
# 确保是 2048 维
if feature_vector.ndim > 1:
feature_vector = feature_vector.flatten()
# 确保维度正确
if len(feature_vector) != RECOMMENDATION_CONFIG["vector_dim"]:
logger.warning(f"向量维度不正确: {len(feature_vector)}, 期望: {RECOMMENDATION_CONFIG['vector_dim']}")
# 如果维度不对,尝试调整
if len(feature_vector) > RECOMMENDATION_CONFIG["vector_dim"]:
feature_vector = feature_vector[:RECOMMENDATION_CONFIG["vector_dim"]]
else:
padded = np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
padded[:len(feature_vector)] = feature_vector
feature_vector = padded
return feature_vector.astype(np.float32)
except Exception as e:
logger.error(f"提取特征向量失败 [{path}]: {e}", exc_info=True)
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
def normalize_vector(vector: np.ndarray) -> np.ndarray:
"""
L2 归一化向量
Args:
vector: 输入向量
Returns:
归一化后的向量
"""
norm = np.linalg.norm(vector)
if norm == 0:
return vector
return vector / norm
def compute_weighted_average(vectors: list, weights: list) -> np.ndarray:
"""
计算加权平均向量
Args:
vectors: 向量列表
weights: 权重列表
Returns:
加权平均向量(不做归一化,模长为加权平均后的尺度)
"""
if not vectors or not weights:
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
# 确保所有向量都是 numpy array
vectors = [np.array(v) for v in vectors]
weights = np.array(weights)
# 计算加权和
weighted_sum = np.zeros_like(vectors[0])
for v, w in zip(vectors, weights):
weighted_sum += v * w
# 返回加权平均(除以权重和,不做 L2 归一化,模长不会随条数线性暴涨)
weight_total = weights.sum()
if weight_total == 0:
return weighted_sum
return weighted_sum / weight_total

View File

@@ -0,0 +1,27 @@
import cv2
import numpy as np
def my_imnormalize(img, mean, std, to_rgb=True):
"""Inplace normalize an image with mean and std.
Args:
img (ndarray): Image to be normalized.
mean (ndarray): The mean to be used for normalize.
std (ndarray): The std to be used for normalize.
to_rgb (bool): Whether to convert to rgb.
Returns:
ndarray: The normalized image.
"""
# cv2 inplace normalization does not accept uint8
img = img.copy().astype(np.float32)
assert img.dtype != np.uint8
mean = np.float64(mean.reshape(1, -1))
stdinv = 1 / np.float64(std.reshape(1, -1))
if to_rgb:
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
cv2.subtract(img, mean, img) # inplace
cv2.multiply(img, stdinv, img) # inplace
return img

View File

@@ -81,7 +81,7 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png"
url = "lanecarford/lc_stylist_agent_outfit_items/141/ee25ec85-d504-4b42-9a18-db6682fe9e3b-6.jpg"
url = "aida-results/result_a7adcbd8-ef8d-11f0-8c92-0966ede33ab5.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "2"

View File

@@ -91,6 +91,21 @@ class Redis(object):
r = cls._get_r()
r.expire(name, expire_in_seconds)
@classmethod
def scan_keys(cls, pattern="*"):
"""
扫描匹配模式的key
"""
r = cls._get_r()
keys = []
cursor = 0
while True:
cursor, partial_keys = r.scan(cursor, match=pattern, count=1000)
keys.extend(partial_keys)
if cursor == 0:
break
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
if __name__ == '__main__':
redis_client = Redis()

View File

@@ -1,13 +1,20 @@
services:
aida_server:
container_name: "AiDA_${SERVE_ENV}_Server"
build:
context: .
dockerfile: Dockerfile
working_dir: /app
volumes:
- ./app:/app/app
- ./.env_prod:/app/.env
- ./.env:/app/.env
- /etc/localtime:/etc/localtime:ro
- ./seg_cache:/seg_cache
ports:
- "10200:80"
- "${SERVE_PORT}:80"
networks:
- aida_app_net
networks:
aida_app_net:
external: true
name: aida_app_net

View File

@@ -1,10 +1,15 @@
import os
from app.core.config import settings
LOGGER_CONFIG_DICT = {
'version': 1,
'disable_existing_loggers': False,
'formatters': {
'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'}
'simple': {
'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S' # 补充日期格式,日志更易读
}
},
'handlers': {
'console': {
@@ -17,7 +22,7 @@ LOGGER_CONFIG_DICT = {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'INFO',
'formatter': 'simple',
'filename': f'{settings.LOGS_PATH}info.log',
'filename': os.path.join(settings.LOGS_PATH, 'info.log'),
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
@@ -26,7 +31,7 @@ LOGGER_CONFIG_DICT = {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'ERROR',
'formatter': 'simple',
'filename': f'{settings.LOGS_PATH}error.log',
'filename': os.path.join(settings.LOGS_PATH, 'error.log'),
'maxBytes': 10485760,
'backupCount': 20,
'encoding': 'utf8',
@@ -35,7 +40,7 @@ LOGGER_CONFIG_DICT = {
'class': 'logging.handlers.RotatingFileHandler',
'level': 'DEBUG',
'formatter': 'simple',
'filename': f'{settings.LOGS_PATH}debug.log',
'filename': os.path.join(settings.LOGS_PATH, 'debug.log'),
'maxBytes': 10485760,
'backupCount': 50,
'encoding': 'utf8',
@@ -45,7 +50,7 @@ LOGGER_CONFIG_DICT = {
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
},
'root': {
'level': 'INFO',
'level': 'DEBUG',
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
},
}

View File

@@ -23,8 +23,8 @@ dependencies = [
"load-dotenv>=0.1.0",
"loguru>=0.7.3",
"minio>=7.2.20",
"mmcv>=2.2.0",
"moviepy==1.0.3",
"np>=1.0.2",
"numpy<2",
"ollama>=0.6.1",
"opencv-python>=4.11.0.86",

Binary file not shown.

Binary file not shown.

88
uv.lock generated
View File

@@ -8,15 +8,6 @@ resolution-markers = [
"(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')",
]
[[package]]
name = "addict"
version = "2.4.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/85/ef/fd7649da8af11d93979831e8f1f8097e85e82d5bfeabc8c68b39175d8e75/addict-2.4.0.tar.gz", hash = "sha256:b3b2210e0e067a281f5646c8c5db92e99b7231ea8b0eb5f74dbdf9e259d4e494", size = 9186, upload-time = "2020-11-21T16:21:31.416Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/6a/00/b08f23b7d7e1e14ce01419a467b583edbb93c6cdb8654e54a9cc579cd61f/addict-2.4.0-py3-none-any.whl", hash = "sha256:249bb56bbfd3cdc2a004ea0ff4c2b6ddc84d53bc2194761636eb314d5cfa5dfc", size = 3832, upload-time = "2020-11-21T16:21:29.588Z" },
]
[[package]]
name = "agentaction"
version = "0.1.7"
@@ -1671,43 +1662,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3e/9a/b697530a882588a84db616580f2ba5d1d515c815e11c30d219145afeec87/minio-7.2.20-py3-none-any.whl", hash = "sha256:eb33dd2fb80e04c3726a76b13241c6be3c4c46f8d81e1d58e757786f6501897e", size = 93751, upload-time = "2025-11-27T00:37:13.993Z" },
]
[[package]]
name = "mmcv"
version = "2.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "addict" },
{ name = "mmengine" },
{ name = "numpy" },
{ name = "opencv-python" },
{ name = "packaging" },
{ name = "pillow" },
{ name = "pyyaml" },
{ name = "regex", marker = "sys_platform == 'win32'" },
{ name = "yapf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e9/a2/57a733e7e84985a8a0e3101dfb8170fc9db92435c16afad253069ae3f9df/mmcv-2.2.0.tar.gz", hash = "sha256:ac479247e808d8802f89eadf04d4118de86bdfe81361ec5aed0cc1bf731c67c9", size = 479121, upload-time = "2024-04-24T14:24:28.064Z" }
[[package]]
name = "mmengine"
version = "0.10.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "addict" },
{ name = "matplotlib" },
{ name = "numpy" },
{ name = "opencv-python" },
{ name = "pyyaml" },
{ name = "regex", marker = "sys_platform == 'win32'" },
{ name = "rich" },
{ name = "termcolor" },
{ name = "yapf" },
]
sdist = { url = "https://files.pythonhosted.org/packages/17/14/959360bbd8374e23fc1b720906999add16a3ac071a501636db12c5861ff5/mmengine-0.10.7.tar.gz", hash = "sha256:d20ffcc31127567e53dceff132612a87f0081de06cbb7ab2bdb7439125a69225", size = 378090, upload-time = "2025-03-04T12:23:09.568Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/98/8e/f98332248aad102511bea4ae19c0ddacd2f0a994f3ca4c82b7a369e0af8b/mmengine-0.10.7-py3-none-any.whl", hash = "sha256:262ac976a925562f78cd5fd14dd1bc9b680ed0aa81f0d85b723ef782f99c54ee", size = 452720, upload-time = "2025-03-04T12:23:06.339Z" },
]
[[package]]
name = "mmh3"
version = "5.2.0"
@@ -1801,6 +1755,12 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
]
[[package]]
name = "np"
version = "1.0.2"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/40/7d/749666e5a9976dcbc4d16d487bbe571efc6bbf4cdf3f4620c0ccc52b57ef/np-1.0.2.tar.gz", hash = "sha256:781265283f3823663ad8fb48741aae62abcf4c78bc19f908f8aa7c1d3eb132f8", size = 7419, upload-time = "2017-10-05T11:26:00.956Z" }
[[package]]
name = "numpy"
version = "1.26.4"
@@ -2269,15 +2229,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/96/aaa61ce33cc98421fb6088af2a03be4157b1e7e0e87087c888e2370a7f45/pillow-12.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:7dfb439562f234f7d57b1ac6bc8fe7f838a4bd49c79230e0f6a1da93e82f1fad", size = 2436012, upload-time = "2025-10-15T18:22:23.621Z" },
]
[[package]]
name = "platformdirs"
version = "4.5.1"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/cf/86/0248f086a84f01b37aaec0fa567b397df1a119f73c16f6c7a9aac73ea309/platformdirs-4.5.1.tar.gz", hash = "sha256:61d5cdcc6065745cdd94f0f878977f8de9437be93de97c1c12f853c9c0cdcbda", size = 21715, upload-time = "2025-12-05T13:52:58.638Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/cb/28/3bfe2fa5a7b9c46fe7e13c97bda14c895fb10fa2ebf1d0abb90e0cea7ee1/platformdirs-4.5.1-py3-none-any.whl", hash = "sha256:d03afa3963c806a9bed9d5125c8f4cb2fdaf74a55ab60e5d59b3fde758104d31", size = 18731, upload-time = "2025-12-05T13:52:56.823Z" },
]
[[package]]
name = "posthog"
version = "5.4.0"
@@ -2746,17 +2697,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/2c/58/ca301544e1fa93ed4f80d724bf5b194f6e4b945841c5bfd555878eea9fcb/referencing-0.37.0-py3-none-any.whl", hash = "sha256:381329a9f99628c9069361716891d34ad94af76e461dcb0335825aecc7692231", size = 26766, upload-time = "2025-10-13T15:30:47.625Z" },
]
[[package]]
name = "regex"
version = "2025.11.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/cc/a9/546676f25e573a4cf00fe8e119b78a37b6a8fe2dc95cda877b30889c9c45/regex-2025.11.3.tar.gz", hash = "sha256:1fedc720f9bb2494ce31a58a1631f9c82df6a09b49c19517ea5cc280b4541e01", size = 414669, upload-time = "2025-11-03T21:34:22.089Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/59/9b/7c29be7903c318488983e7d97abcf8ebd3830e4c956c4c540005fcfb0462/regex-2025.11.3-cp312-cp312-win32.whl", hash = "sha256:3839967cf4dc4b985e1570fd8d91078f0c519f30491c60f9ac42a8db039be204", size = 266194, upload-time = "2025-11-03T21:31:51.53Z" },
{ url = "https://files.pythonhosted.org/packages/1a/67/3b92df89f179d7c367be654ab5626ae311cb28f7d5c237b6bb976cd5fbbb/regex-2025.11.3-cp312-cp312-win_amd64.whl", hash = "sha256:e721d1b46e25c481dc5ded6f4b3f66c897c58d2e8cfdf77bbced84339108b0b9", size = 277069, upload-time = "2025-11-03T21:31:53.151Z" },
{ url = "https://files.pythonhosted.org/packages/d7/55/85ba4c066fe5094d35b249c3ce8df0ba623cfd35afb22d6764f23a52a1c5/regex-2025.11.3-cp312-cp312-win_arm64.whl", hash = "sha256:64350685ff08b1d3a6fff33f45a9ca183dc1d58bbfe4981604e70ec9801bbc26", size = 270330, upload-time = "2025-11-03T21:31:54.514Z" },
]
[[package]]
name = "requests"
version = "2.32.5"
@@ -3224,8 +3164,8 @@ dependencies = [
{ name = "load-dotenv" },
{ name = "loguru" },
{ name = "minio" },
{ name = "mmcv" },
{ name = "moviepy" },
{ name = "np" },
{ name = "numpy" },
{ name = "ollama" },
{ name = "opencv-python" },
@@ -3275,8 +3215,8 @@ requires-dist = [
{ name = "load-dotenv", specifier = ">=0.1.0" },
{ name = "loguru", specifier = ">=0.7.3" },
{ name = "minio", specifier = ">=7.2.20" },
{ name = "mmcv", specifier = ">=2.2.0" },
{ name = "moviepy", specifier = "==1.0.3" },
{ name = "np", specifier = ">=1.0.2" },
{ name = "numpy", specifier = "<2" },
{ name = "ollama", specifier = ">=0.6.1" },
{ name = "opencv-python", specifier = ">=4.11.0.86" },
@@ -3605,18 +3545,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/54/85/6ec269b0952ec7e36ba019125982cf11d91256a778c7c3f98a4c5043d283/xxhash-3.6.0-cp312-cp312-win_arm64.whl", hash = "sha256:eae5c13f3bc455a3bbb68bdc513912dc7356de7e2280363ea235f71f54064829", size = 27876, upload-time = "2025-10-02T14:34:54.371Z" },
]
[[package]]
name = "yapf"
version = "0.43.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "platformdirs" },
]
sdist = { url = "https://files.pythonhosted.org/packages/23/97/b6f296d1e9cc1ec25c7604178b48532fa5901f721bcf1b8d8148b13e5588/yapf-0.43.0.tar.gz", hash = "sha256:00d3aa24bfedff9420b2e0d5d9f5ab6d9d4268e72afbf59bb3fa542781d5218e", size = 254907, upload-time = "2024-11-14T00:11:41.584Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/37/81/6acd6601f61e31cfb8729d3da6d5df966f80f374b78eff83760714487338/yapf-0.43.0-py3-none-any.whl", hash = "sha256:224faffbc39c428cb095818cf6ef5511fdab6f7430a10783fdfb292ccf2852ca", size = 256158, upload-time = "2024-11-14T00:11:39.37Z" },
]
[[package]]
name = "yarl"
version = "1.22.0"