138 Commits

Author SHA1 Message Date
zcr
35e791b4e2 更新图形生成工具,优化返回格式并添加新功能 2026-06-15 17:10:04 +08:00
zcr
b14ccab723 修改生产部署 2026-06-15 15:15:03 +08:00
zcr
ddad5c9d2b 修改生产部署 2026-06-15 15:09:56 +08:00
zcr
76d8eb6cdc 修改生产部署 2026-06-15 15:01:22 +08:00
zcr
dbbaa7503c aida agent (基础版)搭建完成 2026-06-15 14:48:17 +08:00
zcr
b602c47fc9 Merge remote-tracking branch 'origin/develop' into research 2026-06-04 16:33:17 +08:00
zcr
ea1b017f75 Merge branch 'master' into develop 2026-06-04 11:42:46 +08:00
zcr
de349a6d20 代码结构化 优化 2026-06-04 11:35:59 +08:00
zcr
e724caec81 1 2026-05-29 15:41:23 +08:00
zcr
96dee8b376 1 2026-05-29 15:40:33 +08:00
zcr
1575321be5 删除不再使用的请求数据和配置文件 2026-05-29 15:38:16 +08:00
zcr
02f78853b9 111 2026-05-29 15:35:03 +08:00
zcr
893f5e87b4 3D 打板部署
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-28 17:17:29 +08:00
zcr
c73bfa7e2a 3D 打板部署
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-28 17:03:04 +08:00
zcr
ad4db736de 新增nacos 配置 测试
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-24 10:17:42 +08:00
zcr
cfbd9e47ac 新增nacos 配置 测试
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-23 17:10:22 +08:00
zcr
cc2404831d Merge branch 'develop' 2026-04-17 14:15:57 +08:00
zcr
6892361050 修复design印花部分 overall 模式印花平铺起始从印花图片中心开始
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-15 17:36:29 +08:00
zcr
f0b73d5fc1 修复design印花部分 mask_inv_print 提取错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-15 17:23:00 +08:00
zcr
ea7522a45d Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/utils/synthesis_item.py
#	app/service/prompt_generation/chatgpt_for_translation.py
2026-04-14 10:18:04 +08:00
zcr
7543d6b346 feat: 更新flux2 klein 的输出示例 ; fix:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-04-14 10:16:30 +08:00
zcr
3ca4003e30 feat: 更新flux2 klein 的输出示例 ; fix: 2026-03-30 17:22:14 +08:00
zcr
59e8a88a01 feat: 更新flux2 klein 的输出示例 ; fix: 2026-03-30 17:14:18 +08:00
zcr
ac6f74438d feat: 更新分割模型参数 ; fix: 2026-03-27 15:10:16 +08:00
zcr
e0d519bfb3 feat: 更新分割模型参数 ; fix: 2026-03-27 15:10:16 +08:00
zcr
3414f2c1aa feat: 更新分割模型参数 ; fix: 2026-03-27 14:59:27 +08:00
zcr
160bf1a6b1 feat: 更新分割模型参数 ; fix: 2026-03-27 14:56:32 +08:00
zcr
79eb3fb859 feat: flux2 增加状态码 ; fix: 2026-03-25 23:17:03 +08:00
zcr
4395d67288 feat: flux2 增加状态码 ; fix: 2026-03-25 23:17:02 +08:00
zcr
674514ec11 feat: brand dna logo生成替换flux2klein ; fix: 2026-03-25 23:16:53 +08:00
zcr
e9ca1d301b feat: 新增flux2klein作为moodboard的localbase 模型 ; fix: 2026-03-25 23:16:48 +08:00
zcr
a4d55fdb14 feat: flux2 增加状态码 ; fix: 2026-03-25 10:29:03 +08:00
zcr
7f2f79d029 feat: flux2 增加状态码 ; fix: 2026-03-24 14:35:39 +08:00
zcr
6d9e96305b feat: brand dna logo生成替换flux2klein ; fix: 2026-03-23 11:21:50 +08:00
zcr
d93c50ce2b feat: 新增flux2klein作为moodboard的localbase 模型 ; fix: 2026-03-23 10:46:16 +08:00
zcr
316c2fef67 feat:
fix: 删除计数中间件
2026-03-13 11:22:57 +08:00
zcr
e25f49a776 feat:
fix: 删除计数中间件
2026-03-13 11:22:12 +08:00
zcr
33b4dd4a7f feat:
fix: 翻译 模型ip更换
2026-03-05 15:20:40 +08:00
zcr
ac8ca4dd46 feat:
fix: 翻译 模型ip更换
2026-03-05 15:17:47 +08:00
zcr
db88d9b813 feat:
fix: 翻译 模型ip更换
2026-03-05 15:13:40 +08:00
zcr
6ea9837f83 feat:
fix: sam 模型ip更换
2026-03-05 15:06:27 +08:00
zcr
7e48420ba7 feat:
fix: sam 模型ip更换
2026-03-05 15:06:19 +08:00
zcr
13002eefda feat:
fix:  others 旋转功能修复
2026-03-05 14:02:01 +08:00
zcr
09e25f423e feat:
fix:  others 旋转功能修复
2026-03-05 14:01:29 +08:00
zcr
bcc82ba065 feat:
fix:  替换项目中所有mmcv的依赖
2026-02-27 15:45:28 +08:00
zcr
dcc88adfc0 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  替换项目中所有mmcv的依赖
2026-02-27 15:26:07 +08:00
zcr
e4fd7b2fb9 feat:
fix:  替换项目中所有mmcv的依赖
2026-02-10 13:04:14 +08:00
ba93d33a17 更新 .gitea/workflows/prod_build_scheduled.yaml 2026-02-10 11:44:59 +08:00
292da1de2b 更新 .gitea/workflows/prod_build_manual.yaml 2026-02-10 11:44:48 +08:00
c6ebfae942 更新 .gitea/workflows/ltx_develop_build_manual.yaml 2026-02-10 11:44:36 +08:00
4dd8416911 更新 .gitea/workflows/develop_build_scheduled.yaml 2026-02-10 11:44:27 +08:00
zcr
bafcb68028 feat:
fix:  替换项目中所有mmcv的依赖
2026-02-10 11:34:09 +08:00
zcr
c03b7e263e feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  替换项目中所有mmcv的依赖
2026-02-10 11:17:31 +08:00
zcr
200414e5ad feat: 停用flux2 img2product 复用sdxl img2product
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-02-09 17:33:07 +08:00
23a6a30cc4 更新 .gitea/workflows/develop_build_manual.yaml 2026-02-09 15:28:47 +08:00
4d0688afd5 删除 .gitea/workflows/develop_build_commit.yaml 2026-02-09 15:28:22 +08:00
zcr
9a00fce0eb feat: 印花逻辑修改 默认不处理除overall以外所有印花类型
fix:
2026-02-03 16:44:01 +08:00
zcr
4656eeee91 feat: 印花逻辑修改 默认不处理除overall以外所有印花类型
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-02-03 16:43:33 +08:00
zcr
f017d7e212 feat:
fix: 修复sketch类型为others时 跳过 上印花 导致的尺寸与分割尺寸不一致问题, 修复others分割出后片的问题
2026-02-03 16:23:05 +08:00
zcr
fe25f5878b feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 修复sketch类型为others时 跳过 上印花 导致的尺寸与分割尺寸不一致问题, 修复others分割出后片的问题
2026-02-03 16:22:47 +08:00
zcr
c1b80c58f1 feat:
fix: 队列名修复
2026-02-02 15:37:31 +08:00
zcr
2cc17a1210 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 队列名修复
2026-02-02 15:37:01 +08:00
zcr
be92d48abb feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 回溯镜像旋转逻辑
2026-01-30 15:45:57 +08:00
zcr
57be559cf2 feat:
fix:  修复类别为other时出现的pipeline item缺失
2026-01-29 16:26:16 +08:00
zcr
f8382f280f feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  修复类别为other时出现的pipeline item缺失
2026-01-29 16:25:43 +08:00
zcr
c24862507f feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  slogan 服务迁移
2026-01-28 15:37:03 +08:00
zcr
d5452098f3 feat:
fix: 移除打印
2026-01-27 13:53:21 +08:00
zcr
315e298ba8 feat:
fix:
2026-01-27 13:53:17 +08:00
zcr
ec26c8b507 feat:
fix:  印花overall 角度异常
2026-01-27 13:53:04 +08:00
zcr
e02ca351b6 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:  印花overall 角度异常
2026-01-27 13:42:34 +08:00
zcr
c987f498bc feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-27 11:28:36 +08:00
zcr
3aa8dfa0f4 feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 移除打印
2026-01-27 10:12:23 +08:00
zcr
265f4de50e feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 更新端口
2026-01-26 16:32:30 +08:00
zcr
a996a1853d feat:
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix: 更新端口
2026-01-26 16:11:10 +08:00
zcr
1cbd019ffd feat: 更新翻译模型
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:56:42 +08:00
zcr
e2a49e2f3a feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:26:15 +08:00
zcr
66037c94e6 feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:23:49 +08:00
zcr
754e8d7735 feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:21:51 +08:00
zcr
cdaeb6daac feat: 新增to product img flux2 版,停用sdxl版
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
fix:
2026-01-26 15:19:28 +08:00
zcr
863d9287dc fix: 参数对齐
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
(cherry picked from commit ddef6af1cf)
2026-01-26 14:56:49 +08:00
zcr
ddef6af1cf fix: 参数对齐 2026-01-26 14:49:57 +08:00
zcr
fdffb1e724 Merge branch 'develop' 2026-01-24 22:05:57 +08:00
zcr
ecf10611c2 fix: merge 模式下 镜像和旋转功能与前端对其
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 14:43:10 +08:00
zcr
f78809b22a fix: merge 模式下 镜像和旋转功能与前端对其
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 14:03:35 +08:00
63a2f5e007 更新 .gitea/workflows/prod_build_scheduled.yaml 2026-01-24 03:20:19 +08:00
aeb67f366a 更新 .gitea/workflows/prod_build_manual.yaml 2026-01-24 03:20:03 +08:00
zcr
c244e313ae Merge branch 'develop' 2026-01-24 03:14:24 +08:00
zcr
15934085e0 fix: 修复design merge 模式 ,旋转sketch位置计算错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 03:03:26 +08:00
zcr
40b41d02a4 fix: 修复design merge 模式 ,旋转sketch位置计算错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 03:01:34 +08:00
1a1fd46f81 添加 .gitea/workflows/prod_build_manual.yaml 2026-01-24 02:55:56 +08:00
zcr
dcd8e26f0f fix: 修复design merge 模式 ,旋转sketch位置计算错误 2026-01-24 02:46:52 +08:00
zcr
fd94a3b4f0 Merge branch 'develop' 2026-01-24 02:44:51 +08:00
zcr
682c589238 fix: 修复design merge 模式 ,旋转sketch位置计算错误
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-24 02:44:13 +08:00
3f9309235a 2026.01.23 生产部署
All checks were successful
定时 AiDA python prod 分支构建部署 / scheduled_deploy (push) Successful in 2m20s
2026-01-23 21:06:19 +08:00
zcr
a578aa4fc5 暂时移除design 缓存
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-23 18:09:21 +08:00
zcr
ebd665b241 Merge branch 'develop'
# Conflicts:
#	app/core/config.py
2026-01-23 17:37:49 +08:00
zcr
ec649152e3 移除keypoint 缓存
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-23 17:34:51 +08:00
833e1bc924 暂时停用定时部署 2026-01-23 10:41:43 +08:00
zcr
7ed5911336 服务迁移测试
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-22 13:41:47 +08:00
zcr
b09538e294 feat: 新增design模式 merge,回参增加mask
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-15 14:13:56 +08:00
zcr
313863a6a7 fix: design 预处理 读取四通道图片背景变黑问题
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 15:36:28 +08:00
zcr
9ca1a2ba1f fix: design 单品未传design_type
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 14:58:31 +08:00
litianxiang
fb46a9521d Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 13:57:28 +08:00
litianxiang
b90688f835 更改增量更新日志级别 2026-01-13 13:57:15 +08:00
zcr
7e30779aec feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:43:30 +08:00
zcr
f7294f5966 feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:32:18 +08:00
zcr
0ac5a4e0a8 Merge remote-tracking branch 'origin/develop' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 16:18:15 +08:00
zcr
40b57b749c feat: 新增design模式 merge,前端CV python 合成 2026-01-12 16:18:04 +08:00
litianxiang
b8a538a8a1 fix:增量更新向量问题修改
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:59:06 +08:00
litianxiang
29b4f43a27 debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:34:56 +08:00
litianxiang
69dc20207d debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:03:58 +08:00
litianxiang
18979af604 debug:推荐接口返回redis值
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:01:26 +08:00
litianxiang
74406f9be4 推荐接口更新向量接口注册
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:59:01 +08:00
litianxiang
df99e3ac76 新增查看redis内容接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:51:37 +08:00
litianxiang
19346c2eb7 Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 09:51:52 +08:00
litianxiang
2af9cbfe78 fix:推荐接口 2026-01-12 09:49:07 +08:00
2b7e4013ee 更新 .gitea/workflows/develop_build_manual.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 2m14s
2026-01-08 10:23:45 +08:00
5b2bb3ce7c 添加 .gitea/workflows/ltx_develop_build_manual.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 2m15s
2025-12-29 11:01:49 +08:00
6739a92d28 2025.12.19 生产部署
All checks were successful
定时 AiDA python prod 分支构建部署 / scheduled_deploy (push) Successful in 22s
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 28s
2025-12-19 17:47:54 +08:00
f23b99d326 更新 .gitea/workflows/prod_build_scheduled.yaml 2025-12-19 17:46:35 +08:00
10d41cd32f 新增生产部署actions文件 2025-12-19 17:46:03 +08:00
zcr
bb7b85bfb8 Merge branch 'develop'
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 15s
# Conflicts:
#	.gitea/workflows/develop_build_scheduled.yaml
#	app/service/design_fast/design_generate.py
2025-12-16 14:37:05 +08:00
6ecb6be59c 更新 .gitea/workflows/develop_build_scheduled.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 17s
2025-11-28 17:42:22 +08:00
64285cd5f3 更新 .gitea/workflows/develop_build_scheduled.yaml 2025-11-28 17:41:22 +08:00
fe6a5fb029 更新 .gitea/workflows/develop_build_scheduled.yaml
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 13s
2025-11-28 17:31:29 +08:00
5217847d49 上传文件至「.gitea/workflows」
All checks were successful
定时 AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 14s
2025-11-28 17:23:07 +08:00
0a9fc51310 更新 .gitea/workflows/develop_build_commit.yaml 2025-11-28 17:19:47 +08:00
cf052f9632 上传文件至「.gitea/workflows」 2025-11-28 17:18:52 +08:00
19a8ea9a93 更新 .gitea/workflows/develop_build_commit.yaml 2025-11-28 17:11:35 +08:00
09ff2f1ab7 更新 .gitea/workflows/develop_build_commit.yaml 2025-11-28 17:10:04 +08:00
109a23197a 上传文件至「.gitea/workflows」 2025-11-28 17:02:38 +08:00
zhh
2135a180be feat(新功能):
fix(修复bug):  删除java端debug callback api url
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-21 23:17:53 +08:00
zhh
09032c0564 Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/design_generate.py
2025-11-21 23:01:34 +08:00
zhh
167faa10c8 feat(新功能): fix(修复bug): 取消design-v2 java端测试接口(重构): test(增加测试): 2025-09-27 00:02:23 +08:00
zhh
0a048bf37f Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/design_generate.py
2025-09-26 23:31:42 +08:00
zhh
05045dda76 feat(新功能): fix(修复bug): : refactor(重构): test(增加测试): 徐佩design测试 2025-09-23 11:38:33 +08:00
zhh
30f9a99df2 Merge branch 'develop'
# Conflicts:
#	app/service/design_fast/pipeline/split.py
2025-09-22 17:56:18 +08:00
zhh
3932b8359a feat(新功能):
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):  mask 使用原尺寸测试
2025-09-17 16:43:26 +08:00
83 changed files with 5715 additions and 3373 deletions

View File

@@ -1,2 +1,6 @@
seg_cache
test
test
.venv
__pycache__/
*.pyc
.git/

View File

@@ -7,7 +7,7 @@ jobs:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Dev
steps:
- name: 1.检出代码
@@ -35,6 +35,4 @@ jobs:
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d --build --remove-orphans 2>&1
docker image prune -f 2>&1
docker-compose up -d 2>&1

View File

@@ -1,15 +1,15 @@
name: 定时 AiDA python develop 分支构建部署
on:
# 使用 schedule 触发器,遵循标准的 Cron 格式 (分钟 小时-8 日期 月份 星期)
schedule:
- cron: '30 9 * * *'
# schedule:
# - cron: '30 9 * * *'
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Dev
steps:
- name: 1.检出代码

View File

@@ -1,23 +1,19 @@
name: git commit AiDA python develop 分支构建部署
name: 手动 AiDA python develop 分支构建部署
on:
workflow_dispatch:
push:
branches:
- develop
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
if: "contains(github.event.head_commit.message, '[run build]')"
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Dev
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'develop'
ref: 'dev-ltx'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
@@ -28,7 +24,7 @@ jobs:
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
- name: 3.重启docker-compose
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}

View File

@@ -0,0 +1,40 @@
name: 定时 AiDA python prod 分支构建部署
on:
workflow_dispatch:
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Prod
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'master'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
script: |
# 进入项目目录
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d 2>&1
docker image prune -f 2>&1

View File

@@ -0,0 +1,42 @@
name: 定时 AiDA python prod 分支构建部署
on:
# 使用 schedule 触发器,遵循标准的 Cron 格式 (分钟 小时-8 日期 月份 星期)
schedule:
- cron: '07 13 23 1 *'
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/AiDA_Workspace/Python_Server_Workspace/Prod
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'master'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
script: |
# 进入项目目录
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d 2>&1
docker image prune -f 2>&1

3
.gitignore vendored
View File

@@ -150,4 +150,5 @@ app/logs/*
*.avi
*.json
*.env*
config.backup.py
config.backup.py
*.pckl

View File

@@ -19,4 +19,4 @@ RUN apt-get update && apt-get install -y \
RUN uv sync --frozen --no-cache
# Run the application.
CMD ["/app/.venv/bin/fastapi", "run", "app/main.py", "--port", "80", "--host", "0.0.0.0"]
CMD ["/app/.venv/bin/uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80", "--workers", "4"]

View File

@@ -20,7 +20,6 @@
$ conda activate trinity_client_aida
$ pip install -r requirements.txt
$ conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia -y
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
1. 启动服务器

View File

@@ -4,6 +4,7 @@ import logging
import requests
from fastapi import APIRouter, HTTPException, BackgroundTasks
from app.core.config import settings
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel, SAMRequestModel
from app.schemas.response_template import ResponseModel
from app.service.design_fast.design_generate import design_generate, design_generate_v2
@@ -27,6 +28,15 @@ def design(request_data: DesignModel):
- **mask_url** 非空"mask_url" -> 区域透明
- **transpose** 镜像模式 ,:"top_bottom""left_right"
- **rotate** 45,
- ** design 参数变更:
design detail 请求参数中 basic -> preview_submit 替换为design_type 可选参数 default ,merge (移除preview和submit)
design_type 参数说明:
defuault模式下 请求参数不变
merge模式下 items -> 每个item需要新增 merge_image_path , merge_image_path为前端处理 print color等操作后的单件结果图
**
- 创建一个具有以下参数的请求体:
示例参数:
```json
@@ -61,7 +71,7 @@ def design(request_data: DesignModel):
]
},
"layer_order": true,
"preview_submit": "preview",
"design_type": "preview",
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
@@ -385,7 +395,11 @@ async def seg_anything(request_data: SAMRequestModel):
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
### 参数说明:
- **bucket**: minio bucket name
- **object_name**: minio object name
- **image_path**: 图片在服务器或云端的相对路径。
- **type**: 推理类型
- **box**: 框选矩形点位信息
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
- 1: **前景点** (代表想要分割出的区域)
@@ -393,16 +407,29 @@ async def seg_anything(request_data: SAMRequestModel):
### 请求体示例:
```json
point
{
"bucket": "test",
"object_name": "7068-400a-ac94-c01647fa5f6f.png",
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"point",
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
"labels": [1, 1, 0, 1]
}
box
{
"bucket": "test",
"object_name": "7068-400a-ac94-c01647fa5f6f.png",
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"box",
"box": [350, 286, 544, 520]
}
```
"""
try:
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = requests.post("http://10.1.1.240:10075/predict", json=request_data.dict())
data = requests.post(f"http://{settings.B_4_X_4090_SERVICE_HOST}:10075/predict", json=request_data.dict())
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
return ResponseModel(data=json.loads(data.content))
except Exception as e:

View File

@@ -0,0 +1,47 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from app.schemas.fashion_agent import FashionAgentRequest
from app.service.fashion_agent.service import FashionAgentService
router = APIRouter()
logger = logging.getLogger()
@router.post("/agent/stream")
async def fashion_agent_stream(request_item: FashionAgentRequest):
"""
服装设计 Agent(流式输出)
- **message**: 用户输入的消息(必填)
- **user_id**: 用户ID,默认 "agent"
- **enable_thinking**: 是否开启思考模式
- **call_print**: 是否直接调用 print 生成印花
- **print_need_prompt_generation**: print 是否需要 LLM 生成 prompt
- **call_logo**: 是否直接调用 logo 生成装饰图案
- **call_sketch**: 是否直接调用 sketch 生成草图
- **sketch_need_prompt_generation**: sketch 是否需要 LLM 生成 prompt
- **call_design**: 是否直接调用 design 生成设计系列
- **design_request_data**: design 请求参数(objects, process_id, requestId, callback_url)
- **call_trending**: 是否直接调用 trending 趋势分析
- **call_explor**: 是否直接调用 explorer 灵感探索
- **provider**: 图片源 (pexels/unsplash),默认 unsplash
返回 SSE 事件流:
- **tools** 事件:工具调用的 started/finished 状态
- **custom** 事件:design 工具的逐个生成结果
- **Done** 事件:流结束标记
"""
try:
logger.info(f"fashion_agent stream request: {json.dumps(request_item.model_dump(), indent=4, ensure_ascii=False)}")
service = FashionAgentService()
return StreamingResponse(
service.run_stream(request_item),
media_type="text/event-stream",
)
except Exception as e:
logger.warning(f"fashion_agent stream exception: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,9 +1,12 @@
import json
import logging
import httpx
import requests
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel
from app.core.config import settings
from app.schemas.generate_image import GenerateImageModel, GenerateProductImageModel, GenerateSingleLogoImageModel, GenerateRelightImageModel, GenerateMultiViewModel, BatchGenerateProductImageModel, BatchGenerateRelightImageModel, AgentTollGenerateImageModel, Flux2ToProductImgModel, GenerateSloganImageModel, GenerateImageFlux2KleinModel
from app.schemas.pose_transform import BatchPoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.generate_batch_image.service import start_product_batch_generate, start_relight_batch_generate, start_pose_transform_batch_generate
@@ -20,6 +23,61 @@ logger = logging.getLogger()
'''generate image'''
# flux2 klein
@router.post("/generate_image_flux2_klein")
async def generate_image_flux2_klein(request_item: GenerateImageFlux2KleinModel):
"""
创建一个具有以下参数的请求体:
- **bucket_name**: OSS桶名 (必填)
- **object_name**: OSS对象名文件路径(必填)
- **width**: 图片宽度默认1024像素 (非必填,1024)
- **height**: 图片高度默认1024像素 (非必填,默认1024)
- **prompt**: 文本提示词,用于模型推理等场景 (非必填,默认"")
- **steps**: 推理步数,控制模型生成过程的迭代次数 (非必填,默认4)
- **guidance**: 引导系数,调节提示词对生成结果的影响程度 (非必填,默认 4.0 )
### 示例参数:
```
{
"bucket_name": "aida-users",
"object_name": "89/moodboard/5fdc698c-cb9b-4b36-afa9ce4-1-89.png",
"prompt": "a single item of sketch of dress, 4k, white background"
}
```
### 输出示例:
```
{
"code": 200,
"msg": "OK!",
"data": {
"output_path": "aida-users/89/moodboard/5fdc698c-cb9b-4b36-afa9ce4-1-89.png"
}
}
```
"""
try:
logger.info(f"generate_image_flux2_gen_img request: {json.dumps(request_item.model_dump(), indent=4)}")
async with httpx.AsyncClient(timeout=120) as client:
resp = await client.post(
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
json=request_item.model_dump(),
)
if resp.status_code == 200:
result = resp.json()
logger.info(f"flux2_gen_img response: {json.dumps(result, indent=4)}")
return ResponseModel(data=result)
else:
error = resp.json()
logger.info(f"flux2_gen_img response: {json.dumps(error, indent=4)}")
return ResponseModel(data=error, msg="ERROR!", code=500)
except Exception as e:
logger.warning(f"generate_image_flux2_gen_img Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
# sdxl
@router.post("/generate_image")
def generate_image(request_item: GenerateImageModel, background_tasks: BackgroundTasks):
"""
@@ -154,6 +212,62 @@ def generate_single_logo_image(tasks_id: str):
return ResponseModel(data=data['data'])
"""slogan """
@router.post("/generate_slogan")
async def generate_slogan(request_data: GenerateSloganImageModel):
"""
### 请求体示例:
```json
{
"num_point": 16,
"image_url": "aida-slogan/6886785f-0aac-4052-b6fd-7ae20a841d8d.png",
"prompt": "123",
"tasks_id": "string-89"
}
```
"""
try:
logger.info(f"generate_slogan request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = requests.post(f"http://{settings.A6000_SERVICE_HOST}:10020/api/slogan", json=request_data.dict())
logger.info(f"generate_slogan response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
return ResponseModel(data=json.loads(data.content))
except Exception as e:
logger.warning(f"generate_slogan Run Exception @@@@@@:{e}")
"""product image flux2.0"""
# @router.post("/img_to_product")
# async def img_to_product(request_data: Flux2ToProductImgModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_path**: 被生成图片的S3或minio url地址
# - **infer_step**: 推理步数
#
# ### 请求体示例:
# ```json
# point
# {
# "prompt": "Create realistic studio photo with real people model standing and wearing this garment, in white studio, Keep original model if present, or generate appropriate model, Standing pose, facing camera.",
# "image_path":"aida-results/result_38151e0a-f83b-11f0-89f6-0242ac130002.png",
# "infer_step":4,
# "tasks_id":"123456-123"
# }
# ```
# """
# try:
# logger.info(f"img_to_product request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
# data = requests.post(f"http://{settings.A6000_SERVICE_HOST}:10090/api/v1/to_product", json=request_data.dict())
# logger.info(f"img_to_product response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
# return ResponseModel(data=json.loads(data.content))
# except Exception as e:
# logger.warning(f"img_to_product Run Exception @@@@@@:{e}")
'''product image'''
@@ -178,7 +292,7 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
}
"""
try:
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:

View File

@@ -137,10 +137,13 @@ router = APIRouter()
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e))
# @router.on_event("startup")
@router.on_event("startup")
async def startup_event():
"""启动时初始化增量监听任务"""
try:
# 屏蔽 apscheduler 的 INFO 日志
logging.getLogger("apscheduler").setLevel(logging.WARNING)
# 确保 Milvus 集合已创建(若已存在则直接返回)
try:
create_collection()
@@ -172,4 +175,32 @@ async def recommend(
return [path]
except Exception as e:
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/redis/user_pref")
async def get_all_user_preferences():
"""
获取所有以 user_pref 为前缀的 Redis key 信息
"""
try:
from app.service.utils.redis_utils import Redis
from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX
# 扫描所有匹配 user_pref:* 的 key
pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*"
keys = Redis.scan_keys(pattern)
# 直接返回所有 key 和原始 value
result = {}
for key in keys:
# 读取对应的值
value = Redis.read(key)
if value:
result[key] = value
return result
except Exception as e:
logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -4,12 +4,15 @@ from app.api import api_brand_dna
from app.api import api_clothing_seg
from app.api import api_design
from app.api import api_design_pre_processing
from app.api import api_fashion_agent
from app.api import api_generate_image
from app.api import api_mannequins_edit
from app.api import api_pose_transform
from app.api import api_precompute
from app.api import api_prompt_generation
from app.api import api_recommendation
from app.api import api_test
from app.api import api_sketch_to_garment
router = APIRouter()
@@ -21,9 +24,12 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
router.include_router(api_sketch_to_garment.router, tags=['sketch_to_garment'], prefix="/api")
router.include_router(api_fashion_agent.router, tags=['fashion_agent'], prefix="/api")
"""停用"""
# from app.api import api_chat_robot

View File

@@ -0,0 +1,104 @@
import json
import logging
from fastapi import APIRouter, HTTPException
from app.schemas.response_template import ResponseModel
from app.schemas.sketch_to_garment_schemas import SketchToGarmentModel
from app.service.sketch2garment.server import submit_sketch_to_garment_task
logger = logging.getLogger()
router = APIRouter()
@router.post("/sketch_to_garment")
def sketch_to_garment_api(request_item: SketchToGarmentModel):
"""
### 接口说明:
将图片转换为3D模型异步处理。接口接收请求后立即返回任务ID后台通过 Celery 处理,处理完成后结果会通过 RabbitMQ 发送。
### 参数说明:
- **input_image_path**: 输入图片路径
- **bucket_name**: bucket name
- **user_id**: 用户id
- **callback_url**: 回调url
- **task_id**: 任务id
- **model**: 转换模式 文本和图片 ,默认只有图片
### 请求体示例:
**单张图片模式:**
```json
{
"input_image_path": "test/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png",
"bucket_name": "test",
"user_id": "string-456",
"callback_url": "http://18.167.251.121:10015/api/image/webhook/img-to-3d",
"task_id": "string12",
"model": "picture"
}
```
### 输出示例:
```json
{
"code": 200,
"msg": "OK!",
"data": {
"state": "success",
"task_id": "string12",
"message": "任务已成功提交,正在后台处理..."
}
}
```
### 错误输出
参考文档: https://platform.tripo3d.ai/docs/error-handling
```json
{
"code": 500,
"message": "You dont have enough credit to create this task",
"data": {
"status": "fail",
"task_id": "123",
"message": "You dont have enough credit to create this task",
"error": str(e)
}
}
```
回调请求参数例子:
```json
{
"task_id": "string12",
"status": "success",
"result": {
"pattern": "test/string-456/pattern_making/now_string-456_pattern.png",
"texture": "test/string-456/pattern_making/now_string-456_texture.png",
"glb": "test/string-456/pattern_making/now_string-456_sim.glb",
"texture_fabric": "test/string-456/pattern_making/now_string-456_texture_fabric.png"
}
}
```
"""
try:
logger.info(f"sketch_to_garment request item is : @@@@@@:{json.dumps(request_item.model_dump(), indent=4)}")
result = submit_sketch_to_garment_task(
task_id=request_item.task_id,
callback_url=request_item.callback_url,
bucket_name=request_item.bucket_name,
input_image_path=request_item.input_image_path,
user_id=request_item.user_id,
model=request_item.model
)
result = {
"state": "success",
"task_id": request_item.task_id,
"message": "任务已成功提交,正在后台处理...",
}
state_code = 200
return ResponseModel(data=result, code=state_code)
except Exception as e:
logger.warning(f"super_resolution Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -1,235 +0,0 @@
import os
import pika
from dotenv import load_dotenv
from pydantic import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings):
PROJECT_NAME: str = 'FASTAPI BASE'
SECRET_KEY: str = ''
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
CHROMADB_PATH = "./chromadb/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
CHROMADB_PATH = "/chromadb/"
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
# RABBITMQ_ENV = "-local" # 本地测试环境
if RABBITMQ_ENV == "-dev":
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
elif RABBITMQ_ENV == "-prod":
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
settings = Settings()
# minio 配置
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 配置
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
REDIS_DB = "2"
# rabbitmq config
RABBITMQ_PARAMS = {
"host": "18.167.251.121",
"port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/"
}
# milvus 配置
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置
DB_HOST = '18.167.251.121' # 数据库主机地址
# DB_PORT = int( 33006)
DB_PORT = 33008 # 数据库端口
DB_USERNAME = 'aida_con_python' # 数据库用户名
DB_PASSWORD = '123456' # 数据库密码
DB_NAME = 'aida' # 数据库库名
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# SR service config
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
# GenerateImage service config
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
# Generate Product service config
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
# GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}
# ollama config
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
# design batch
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN 预处理
IF_DEBUG_SHOW = False
# 优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
'hairstyle_front': 97,
'outwear_front': 20,
'tops_front': 19,
'dress_front': 18,
'blouse_front': 17,
'skirt_front': 16,
'trousers_front': 15,
'bottoms_front': 14,
'shoes_right': 1,
'shoes_left': 1,
'body': 0,
'bottoms_back': -14,
'trousers_back': -15,
'skirt_back': -16,
'blouse_back': -17,
'dress_back': -18,
'tops_back': -19,
'outwear_back': -20,
'hairstyle_back': -97,
'bag_back': -98,
'earring_back': -99,
}
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址

View File

@@ -1,74 +1,99 @@
import logging
from typing import Dict, Any
import yaml
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from v2.nacos import ClientConfigBuilder, GRPCConfig, NacosConfigService, ConfigParam, NacosNamingService, RegisterInstanceParam, DeregisterInstanceParam
logger = logging.getLogger(__name__)
# ====================== Nacos 配置 ======================
NACOS_SERVER_ADDRESSES = "18.167.251.121:28848"
NACOS_NAMESPACE = "zcr"
NACOS_USERNAME = "nacos"
NACOS_PASSWORD = "Aidlab123123!"
NACOS_GROUP = "LOCAL"
NACOS_DATA_ID = "aida.python"
SERVICE_NAME = "fastapi-service" # ←←← 必须修改!建议格式:项目名-环境,例如 ai-image-service-dev
class Settings(BaseSettings):
"""
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
"""
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
# extra='ignore' # 忽略环境变量中多余的键
)
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") # 忽略环境变量中多余的键
# --- 服务端口配置信息 ---
PORT: int = Field(default=8001, description="")
# --- 服务环境 配置信息 ---
SERVE_ENV: str = Field(default='', description="")
SERVE_ENV: str = Field(default="", description="")
# --- 开发状态 配置信息 ---
DEBUG: bool = Field(default=False, description="")
# --- 千问api 配置信息 ---
QWEN_API_KEY: str = Field(default="", description="")
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS: str = Field(default='', description="")
COMFYUI_SERVER_ADDRESS: str = Field(default="", description="")
# --- minio 配置信息 ---
MINIO_URL: str = Field(default='', description="")
MINIO_ACCESS: str = Field(default='', description="")
MINIO_SECRET: str = Field(default='', description="")
MINIO_URL: str = Field(default="", description="")
MINIO_ACCESS: str = Field(default="", description="")
MINIO_SECRET: str = Field(default="", description="")
MINIO_SECURE: bool = Field(default=True, description="")
# --- redis 配置信息 ---
REDIS_HOST: str = Field(default='', description="")
REDIS_PORT: str = Field(default='', description="")
REDIS_HOST: str = Field(default="", description="")
REDIS_PORT: str = Field(default="", description="")
REDIS_DB: int = Field(default=0, description="")
# --- mysql 配置信息 ---
MYSQL_HOST: str = Field(default='', description="")
MYSQL_PORT: int = Field(default='', description="")
MYSQL_USER: str = Field(default='', description="")
MYSQL_PASSWORD: str = Field(default='', description="")
MYSQL_DB: str = Field(default='', description="")
MYSQL_CHARSET: str = Field(default='utf8mb4', description="")
MYSQL_HOST: str = Field(default="", description="")
MYSQL_PORT: int = Field(default=3306, description="")
MYSQL_USER: str = Field(default="", description="")
MYSQL_PASSWORD: str = Field(default="", description="")
MYSQL_DB: str = Field(default="", description="")
MYSQL_CHARSET: str = Field(default="utf8mb4", description="")
# --- rabbit-mq 配置信息 ---
MQ_HOST: str = Field(default='', description="")
MQ_PORT: str = Field(default='', description="")
MQ_USERNAME: str = Field(default='', description="")
MQ_PASSWORD: str = Field(default='', description="")
MQ_VIRTUAL_HOST: str = Field(default='/', description="")
MQ_ENV: str = Field(default='', description="")
MQ_HOST: str = Field(default="", description="")
MQ_PORT: str = Field(default="", description="")
MQ_USERNAME: str = Field(default="", description="")
MQ_PASSWORD: str = Field(default="", description="")
MQ_VIRTUAL_HOST: str = Field(default="/", description="")
MQ_ENV: str = Field(default="", description="")
# --- milvus 配置信息 ---
MILVUS_URL: str = Field(default='', description="")
MILVUS_TOKEN: str = Field(default='', description="")
MILVUS_ALIAS: str = Field(default='', description="")
MILVUS_URL: str = Field(default="", description="")
MILVUS_TOKEN: str = Field(default="", description="")
MILVUS_ALIAS: str = Field(default="", description="")
# --- ollama 配置信息 ---
CHROMADB_PATH: str = Field(default='', description="")
CHROMADB_PATH: str = Field(default="", description="")
# --- ollama 配置信息 ---
OLLAMA_URL: str = Field(default='', description="")
OLLAMA_URL: str = Field(default="", description="")
# --- Design Callback Java 接口 ---
JAVA_STREAM_API_URL: str = Field(default='', description="")
JAVA_STREAM_API_URL: str = Field(default="", description="")
# --- flux2 klein model url ---
FLUX2_GEN_IMG_MODEL_URL: str = Field(default="", description="")
# --- 服务器IP ---
A6000_SERVICE_HOST: str = Field(default="", description="")
B_4_X_4090_SERVICE_HOST: str = Field(default="", description="")
# --- sketch to garment 模型url ---
SKETCH_TO_GARMENT_URL: str = Field(default='', description="")
# --- 其他配置信息 以下均为Docker容器内配置---
LOGS_PATH: str = Field(default="/logs/", description="")
CATEGORY_PATH: str = Field(default="/app/service/attribute/config/descriptor/category/category_dis.csv", description="")
SEG_CACHE_PATH: str = Field(default="/seg_cache/", description="")
RECOMMEND_PATH_PREFIX: str = Field(default="/app/service/recommend/", description="")
SERVE_PORT: int = Field(default=2010, description="")
sketch_to_garment_url: str = Field(default="", description="")
settings = Settings()
@@ -83,73 +108,88 @@ TABLE_CATEGORIES = {
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
"male_outwear": "male/outwear",
}
# Design前后排优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
'hairstyle_front': 97,
'outwear_front': 20,
'tops_front': 19,
'dress_front': 18,
'blouse_front': 17,
'skirt_front': 16,
'trousers_front': 15,
'bottoms_front': 14,
'shoes_right': 1,
'shoes_left': 1,
'body': 0,
'bottoms_back': -14,
'trousers_back': -15,
'skirt_back': -16,
'blouse_back': -17,
'dress_back': -18,
'tops_back': -19,
'outwear_back': -20,
'hairstyle_back': -97,
'bag_back': -98,
'earring_back': -99,
"earring_front": 99,
"bag_front": 98,
"hairstyle_front": 97,
"outwear_front": 20,
"tops_front": 19,
"dress_front": 18,
"blouse_front": 17,
"skirt_front": 16,
"trousers_front": 15,
"bottoms_front": 14,
"shoes_right": 1,
"shoes_left": 1,
"body": 0,
"bottoms_back": -14,
"trousers_back": -15,
"skirt_back": -16,
"blouse_back": -17,
"dress_back": -18,
"tops_back": -19,
"outwear_back": -20,
"hairstyle_back": -97,
"bag_back": -98,
"earring_back": -99,
}
# Design 关键点字段
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
KEYPOINT_RESULT_TABLE_FIELD_SET = (
"neckline_left",
"neckline_right",
"shoulder_left",
"shoulder_right",
"armpit_left",
"armpit_right",
"cuff_left_in",
"cuff_left_out",
"cuff_right_in",
"cuff_right_out",
"waistband_left",
"waistband_right",
)
# milvus配置信息
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
# ollama 地址
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
OLLAMA_URL = f"http://{settings.A6000_SERVICE_HOST}:11434/api/embeddings"
"""Triton Server Config"""
# Design
DESIGN_MODEL_URL = '10.1.1.240:10000'
DESIGN_MODEL_NAME = 'seg_knet'
DESIGN_MODEL_URL = f"{settings.A6000_SERVICE_HOST}:10000"
DESIGN_MODEL_NAME = "seg_knet"
# Seg Product
SEG_PRODUCT_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:30000"
# Generate Image
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GI_MODEL_URL = f"{settings.A6000_SERVICE_HOST}:10061"
GI_MODEL_NAME = "flux"
# Generate Single Logo
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GSL_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10041"
GSL_MODEL_NAME = "stable_diffusion_xl_transparent"
# Generate Product (整套和单品)
GPI_MODEL_URL = '10.1.1.243:10051'
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10051"
GPI_MODEL_NAME_OVERALL = "diffusion_ensemble_all"
GPI_MODEL_NAME_SINGLE = "stable_diffusion_1_5_cnet"
# 以下停用中...*************
# 多视角生成
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10081"
GMV_MODEL_NAME = "multi_view"
# 超分
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_TRITON_URL = f"{settings.A6000_SERVICE_HOST}:10031"
# 打光
GRI_MODEL_URL = '10.1.1.240:10051'
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = f"{settings.A6000_SERVICE_HOST}:10051"
GRI_MODEL_NAME_OVERALL = "diffusion_relight_ensemble"
GRI_MODEL_NAME_SINGLE = "stable_diffusion_1_5_relight"
# agent 图片生成
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
FAST_GI_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10011"
FAST_GI_MODEL_NAME = "stable_diffusion_xl"
# 图转视频 triton版
PT_MODEL_URL = '10.1.1.243:10061'
PT_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10061"
# *************

View File

@@ -0,0 +1,343 @@
import logging
import socket
from typing import Dict, Any
import yaml
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from v2.nacos import ClientConfigBuilder, GRPCConfig, NacosConfigService, ConfigParam, NacosNamingService, RegisterInstanceParam, DeregisterInstanceParam
logger = logging.getLogger(__name__)
# ====================== Nacos 配置 ======================
NACOS_SERVER_ADDRESSES = "18.167.251.121:28848"
NACOS_NAMESPACE = "zcr"
NACOS_USERNAME = "nacos"
NACOS_PASSWORD = "Aidlab123123!"
NACOS_GROUP = "LOCAL"
NACOS_DATA_ID = "aida.python"
SERVICE_NAME = "fastapi-service" # ←←← 必须修改!建议格式:项目名-环境,例如 ai-image-service-dev
class Settings(BaseSettings):
"""
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
"""
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
# extra='ignore' # 忽略环境变量中多余的键
)
# --- 服务端口配置信息 ---
PORT: int = Field(default=8001, description="")
# --- 服务环境 配置信息 ---
SERVE_ENV: str = Field(default='', description="")
# --- 开发状态 配置信息 ---
DEBUG: bool = Field(default=False, description="")
# --- 千问api 配置信息 ---
QWEN_API_KEY: str = Field(default="", description="")
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS: str = Field(default='', description="")
# --- minio 配置信息 ---
MINIO_URL: str = Field(default='', description="")
MINIO_ACCESS: str = Field(default='', description="")
MINIO_SECRET: str = Field(default='', description="")
MINIO_SECURE: bool = Field(default=True, description="")
# --- redis 配置信息 ---
REDIS_HOST: str = Field(default='', description="")
REDIS_PORT: str = Field(default='', description="")
REDIS_DB: int = Field(default=0, description="")
# --- mysql 配置信息 ---
MYSQL_HOST: str = Field(default='', description="")
MYSQL_PORT: int = Field(default=3306, description="")
MYSQL_USER: str = Field(default='', description="")
MYSQL_PASSWORD: str = Field(default='', description="")
MYSQL_DB: str = Field(default='', description="")
MYSQL_CHARSET: str = Field(default='utf8mb4', description="")
# --- rabbit-mq 配置信息 ---
MQ_HOST: str = Field(default='', description="")
MQ_PORT: str = Field(default='', description="")
MQ_USERNAME: str = Field(default='', description="")
MQ_PASSWORD: str = Field(default='', description="")
MQ_VIRTUAL_HOST: str = Field(default='/', description="")
MQ_ENV: str = Field(default='', description="")
# --- milvus 配置信息 ---
MILVUS_URL: str = Field(default='', description="")
MILVUS_TOKEN: str = Field(default='', description="")
MILVUS_ALIAS: str = Field(default='', description="")
# --- ollama 配置信息 ---
CHROMADB_PATH: str = Field(default='', description="")
# --- ollama 配置信息 ---
OLLAMA_URL: str = Field(default='', description="")
# --- Design Callback Java 接口 ---
JAVA_STREAM_API_URL: str = Field(default='', description="")
# --- flux2 klein model url ---
FLUX2_GEN_IMG_MODEL_URL: str = Field(default='', description="")
# --- 服务器IP ---
A6000_SERVICE_HOST: str = Field(default='', description="")
B_4_X_4090_SERVICE_HOST: str = Field(default='', description="")
# --- 其他配置信息 以下均为Docker容器内配置---
LOGS_PATH: str = Field(default="/logs/", description="")
CATEGORY_PATH: str = Field(default="/app/service/attribute/config/descriptor/category/category_dis.csv", description="")
SEG_CACHE_PATH: str = Field(default="/seg_cache/", description="")
RECOMMEND_PATH_PREFIX: str = Field(default="/app/service/recommend/", description="")
SERVE_PORT: int = Field(default=2010, description="")
settings = Settings()
# ====================== Nacos 配置管理 ======================
client_config = (ClientConfigBuilder()
.server_address(NACOS_SERVER_ADDRESSES)
.username(NACOS_USERNAME)
.password(NACOS_PASSWORD)
.namespace_id(NACOS_NAMESPACE)
.log_level('INFO')
.grpc_config(GRPCConfig(grpc_timeout=5000))
.build())
# ====================== Nacos 配置管理 ======================
nacos_config_data: Dict[str, Any] = {}
nacos_config_client = None
async def load_nacos_config() -> None:
"""初始化 Nacos 配置并监听变化"""
global nacos_config_data, settings
try:
client = await NacosConfigService.create_config_service(client_config)
# 1. 第一次获取配置
content = await client.get_config(ConfigParam(
data_id=NACOS_DATA_ID,
group=NACOS_GROUP
))
if content:
loaded = yaml.safe_load(content) or {}
nacos_config_data = loaded
# 用 Nacos 配置覆盖 settings
for key, value in loaded.items():
if hasattr(settings, key):
setattr(settings, key, value)
logger.info(f"✅ Nacos 配置加载成功: {NACOS_DATA_ID} | 覆盖字段数量: {len(loaded)}")
else:
logger.warning("Nacos 返回配置为空,使用 .env + 默认值")
# 2. 注册动态监听器(配置变更自动刷新)
async def listener(tenant: str, data_id: str, group: str, content: str):
global nacos_config_data, settings
try:
new_config = yaml.safe_load(content) if content else {}
nacos_config_data = new_config
# 实时覆盖 settings
for key, value in new_config.items():
if hasattr(settings, key):
old_val = getattr(settings, key)
setattr(settings, key, value)
if old_val != value:
logger.info(f"🔄 配置更新 → {key}: {old_val}{value}")
logger.info(f"【Nacos 动态更新】{NACOS_DATA_ID}")
except Exception as e:
logger.error(f"Nacos 配置解析失败: {e}")
await client.add_listener(NACOS_DATA_ID, NACOS_GROUP, listener)
logger.info("✅ Nacos 配置监听器已注册(支持热更新)")
await register_service_to_nacos()
except Exception as e:
logger.error(f"❌ Nacos 初始化失败: {e},将仅使用 .env 配置")
async def register_service_to_nacos():
"""启动时把服务注册到 Nacos"""
global nacos_config_client
nacos_config_client = await NacosConfigService.create_config_service(client_config)
if not nacos_config_client: # 如果配置客户端都没连上,就不注册
logger.warning("Nacos 配置客户端未初始化,跳过服务注册")
return
try:
nacos_naming_client = await NacosNamingService.create_naming_service(client_config)
# 获取服务 IP生产环境建议通过环境变量传入避免 Docker/K8s 内获取错误)
host_ip = socket.gethostbyname(socket.gethostname())
if not host_ip or host_ip.startswith('127.'):
host_ip = "127.0.0.1" # 本地测试用
param = RegisterInstanceParam(
service_name="aida.python",
group_name=NACOS_GROUP,
ip=host_ip,
port=settings.PORT, # 使用你 settings 中的 PORT
cluster_name="DEFAULT",
weight=1.0,
metadata={
"version": "1.0.0",
"env": settings.SERVE_ENV,
"framework": "fastapi",
"debug": str(settings.DEBUG),
},
enabled=True,
healthy=True,
ephemeral=True, # 临时实例,推荐生产使用
)
await nacos_naming_client.register_instance(request=param)
logger.info(f"✅ 服务已成功注册到 Nacos{SERVICE_NAME} | {host_ip}:{settings.PORT} | env={settings.SERVE_ENV}")
except Exception as e:
logger.error(f"❌ 服务注册到 Nacos 失败: {e}")
async def deregister_service_from_nacos():
"""服务关闭时优雅注销(防止 Nacos 长时间显示不健康实例)"""
try:
nacos_naming_client = await NacosNamingService.create_naming_service(client_config)
host_ip = socket.gethostbyname(socket.gethostname()) or "127.0.0.1"
param = DeregisterInstanceParam(
service_name=SERVICE_NAME,
group_name=NACOS_GROUP,
ip=host_ip,
port=settings.PORT,
cluster_name='c1',
ephemeral=True,
)
await nacos_naming_client.deregister_instance(request=param)
logger.info(f"✅ 服务已从 Nacos 注销 → {SERVICE_NAME}")
except Exception as e:
logger.warning(f"服务注销时出现异常(通常可忽略): {e}")
# 提供给 FastAPI 的依赖
def get_settings() -> Settings:
return settings
"""Design 服务"""
# 推荐服装类别映射
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# Design前后排优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
'hairstyle_front': 97,
'outwear_front': 20,
'tops_front': 19,
'dress_front': 18,
'blouse_front': 17,
'skirt_front': 16,
'trousers_front': 15,
'bottoms_front': 14,
'shoes_right': 1,
'shoes_left': 1,
'body': 0,
'bottoms_back': -14,
'trousers_back': -15,
'skirt_back': -16,
'blouse_back': -17,
'dress_back': -18,
'tops_back': -19,
'outwear_back': -20,
'hairstyle_back': -97,
'bag_back': -98,
'earring_back': -99,
}
# Design 关键点字段
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# milvus配置信息
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
# ollama 地址
OLLAMA_URL = f"http://{settings.A6000_SERVICE_HOST}:11434/api/embeddings"
"""Triton Server Config"""
# Design
DESIGN_MODEL_URL = f'{settings.A6000_SERVICE_HOST}:10000'
DESIGN_MODEL_NAME = 'seg_knet'
# Seg Product
SEG_PRODUCT_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:30000'
# Generate Image
GI_MODEL_URL = f'{settings.A6000_SERVICE_HOST}:10061'
GI_MODEL_NAME = 'flux'
# Generate Single Logo
GSL_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10041'
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
# Generate Product (整套和单品)
GPI_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10051'
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
# 以下停用中...*************
# 多视角生成
GMV_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10081'
GMV_MODEL_NAME = 'multi_view'
# 超分
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = f"{settings.A6000_SERVICE_HOST}:10031"
# 打光
GRI_MODEL_URL = f'{settings.A6000_SERVICE_HOST}:10051'
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
# agent 图片生成
FAST_GI_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
# 图转视频 triton版
PT_MODEL_URL = f'{settings.B_4_X_4090_SERVICE_HOST}:10061'
# *************
"""MQ 队列信息"""
# 生成图片 moodboard printboard sketchboard
GI_RABBITMQ_QUEUES = f"GenerateImage-{settings.SERVE_ENV}"
# 生成slogan
SLOGAN_RABBITMQ_QUEUES = f"Slogan-{settings.SERVE_ENV}"
# 转产品图
GPI_RABBITMQ_QUEUES = f"ToProductImage-{settings.SERVE_ENV}"
# 产品图转视频
PS_RABBITMQ_QUEUES = f"PoseTransform-{settings.SERVE_ENV}"
# 以下停用中...*************
# 产品图打光
GRI_RABBITMQ_QUEUES = f"Relight-{settings.SERVE_ENV}"
# 超分
SR_RABBITMQ_QUEUES = f"SuperResolution-{settings.SERVE_ENV}"
# 生成多视图
GMV_RABBITMQ_QUEUES = f"GenerateMultiView-{settings.SERVE_ENV}"
# 批量转产品图
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage-{settings.SERVE_ENV}"
# 批量打光
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight-{settings.SERVE_ENV}"
# 批量图片转视频
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform-{settings.SERVE_ENV}"
# 批量design
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch-{settings.SERVE_ENV}"
# *************

108
app/core/nacos_config.py Normal file
View File

@@ -0,0 +1,108 @@
import logging
import yaml
import nacos
from app.core.config import settings
logger = logging.getLogger("nacos")
# client config
NACOS_SERVER_ADDRESSES = "18.167.251.121:28848"
NACOS_NAMESPACE = "zcr"
NACOS_USERNAME = "nacos"
NACOS_PASSWORD = "Aidlab123123!"
# nacos config info
NACOS_CONFIG_GROUP = "LOCAL"
NACOS_CONFIG_DATA_ID = "aida.python"
# nacos server config
NACOS_SERVICE_NAME = "AiDA-DEV" # ←←← 必须修改!建议格式:项目名-环境,例如 ai-image-service-dev
NACOS_SERVICE_IP = "127.0.0.1"
NACOS_SERVICE_PORT = 8445
# nacos client
client = nacos.NacosClient(
server_addresses=NACOS_SERVER_ADDRESSES,
namespace=NACOS_NAMESPACE,
username=NACOS_USERNAME,
password=NACOS_PASSWORD
)
def listener_config_callback(args):
data_id = args['data_id']
namespace = args['namespace']
group = args['group']
content = args['content']
logger.info("【Nacos】配置")
try:
logger.info(f"【Nacos】 动态更新 : data_id : {data_id} | namespace : {namespace} | group_name: {group}")
new_config = yaml.safe_load(content) if content else {}
for key, value in new_config.items():
if hasattr(settings, key):
old_val = getattr(settings, key)
setattr(settings, key, value)
if old_val != value:
logger.info(f"🔄 配置更新 → {key}: {old_val}{value}")
except Exception as e:
logger.error(f"【Nacos】 配置解析失败: {e}")
def remove_config_callback(args):
data_id = args['data_id']
namespace = args['namespace']
print(f" remove_config_callback : {data_id} | namespace : {namespace}")
def load_nacos_config():
"""初始化 Nacos 配置并监听变化"""
logger.info(f"【Nacos】 配置订阅 - 初次获取配置信息")
try:
# 1. 第一次获取配置
content = client.get_config(data_id=NACOS_CONFIG_DATA_ID, group=NACOS_CONFIG_GROUP)
if content:
loaded = yaml.safe_load(content) or {}
for key, value in loaded.items():
if hasattr(settings, key):
setattr(settings, key, value)
logger.info(f"【Nacos】✅ 配置加载成功: {NACOS_CONFIG_DATA_ID} | 覆盖字段数量: {len(loaded)}")
else:
logger.warning("【Nacos】 返回配置为空,使用 .env + 默认值")
client.add_config_watcher(data_id=NACOS_CONFIG_DATA_ID, group=NACOS_CONFIG_GROUP, cb=listener_config_callback)
logger.info("【Nacos】✅ 配置监听器已注册(支持热更新)")
except Exception as e:
logger.error(f"【Nacos】❌ 初始化失败: {e},将仅使用 .env 配置")
finally:
client.remove_config_watcher(
data_id=NACOS_CONFIG_DATA_ID,
group=NACOS_CONFIG_GROUP,
cb=remove_config_callback
)
def register_server():
logger.info(f"nacos 服务注册")
try:
client.add_naming_instance(
service_name=NACOS_SERVICE_NAME,
ip=NACOS_SERVICE_IP,
port=NACOS_SERVICE_PORT,
metadata={"status": "ok"},
)
except Exception as e:
logger.warning(f"【Nacos】❌ 服务注册失败 : {e}")
def deregister_server():
logger.info(f"nacos 服务注册")
try:
client.remove_naming_instance(
service_name=NACOS_SERVICE_NAME,
ip=NACOS_SERVICE_IP,
port=NACOS_SERVICE_PORT
)
except Exception as e:
logger.warning(f"【Nacos】❌ 服务注销失败 : {e}")

View File

@@ -1,86 +0,0 @@
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
201,
242
],
"hand_point_right": [
222,
312
],
"waistband_left": [
114,
243
],
"hand_point_left": [
94,
310
],
"shoulder_left": [
102,
116
],
"shoulder_right": [
211,
115
]
},
"layer_order": true,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 264931,
"color": "145 220 232",
"image_id": 96844,
"offset": [
0,
0
],
"path": "aida-users/87/sketch/2aa7aad5-74bb-41fa-9cdf-f06611b3e89a-2-87.png",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"type": "Dress"
},
{
"body_path": "aida-sys-image/models/female/79805ec3-3f01-466d-91e0-36028d079699.png",
"image_id": 95444,
"type": "Body"
}
]
}
],
"process_id": "87",
"tasks_id": ""
}

View File

@@ -1,5 +1,8 @@
# 1. 这里的顺序至关重要!必须在最顶端
import sys
from contextlib import asynccontextmanager
# from app.core.nacos_config import load_nacos_config, register_server, deregister_server
try:
import asyncore
@@ -16,7 +19,7 @@ from fastapi.responses import JSONResponse
from app.api.api_route import router
from app.core.config import settings
from app.core.record_api_count import count_api_calls
# from app.core.record_api_count import count_api_calls
from app.schemas.response_template import ResponseModel
from logging_env import LOGGER_CONFIG_DICT
from dotenv import load_dotenv
@@ -30,8 +33,21 @@ logger = logging.getLogger(__name__)
load_dotenv()
# @asynccontextmanager
# async def lifespan(app: FastAPI):
# try:
# load_nacos_config()
# register_server()
#
# yield
# finally:
# deregister_server()
# logger.info("lifespan down")
def get_application() -> FastAPI:
application = FastAPI(
# lifespan=lifespan,
docs_url="/docs",
redoc_url='/re-docs',
openapi_url=f"/openapi.json",
@@ -48,7 +64,7 @@ def get_application() -> FastAPI:
allow_methods=["*"],
allow_headers=["*"],
)
application.middleware("http")(count_api_calls)
# application.middleware("http")(count_api_calls)
application.include_router(router=router)
return application
@@ -64,5 +80,11 @@ async def http_exception_handler(exc: HTTPException):
)
@app.get("/health", operation_id="health")
async def health():
logger.info("health check")
return {"ok": True, "env": settings.APP_ENV}
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=settings.PORT)

View File

@@ -1,10 +1,16 @@
from pydantic import BaseModel
from typing import List, Optional
from pydantic import BaseModel, Field
class SAMRequestModel(BaseModel):
image_path: str
points: list[list[float]]
labels: list[int]
bucket: str = Field(..., description="minio bucket name ")
object_name: str = Field(..., description="minio object name ")
image_path: str = Field(..., description="图片路径,必填字段")
type: str = Field(..., description="推理类型,必填字段")
points: Optional[List[List[float]]] | None = None
labels: Optional[List[int]] | None = None
box: Optional[List[int]] | None = None
class DesignModel(BaseModel):

View File

@@ -0,0 +1,27 @@
from typing import Optional
from pydantic import BaseModel, Field
class FashionAgentRequest(BaseModel):
"""服装设计 Agent 请求"""
message: str = Field(default="", description="用户输入的消息")
user_id: str = Field(default="test-agent", description="用户ID,用于生成图片存储路径")
enable_thinking: bool = Field(default=False, description="模型思考是否开启")
call_print: bool = Field(default=False, description="是否直接调用 print 生成印花")
print_need_prompt_generation: bool = Field(default=False, description="print 是否需要 LLM 生成 prompt")
call_logo: bool = Field(default=False, description="是否直接调用 logo 生成装饰图案")
call_sketch: bool = Field(default=False, description="是否直接调用 sketch 生成草图")
sketch_need_prompt_generation: bool = Field(default=False, description="sketch 是否需要 LLM 生成 prompt")
call_design: bool = Field(default=False, description="是否直接调用 design 生成设计系列")
design_request_data: dict = Field(default={}, description="design 请求参数")
call_trending: bool = Field(default=False, description="是否直接调用 trending 趋势分析")
call_explor: bool = Field(default=False, description="是否直接调用 explorer 灵感探索")
provider: Optional[str] = Field(default="unsplash", description="图片源: pexels 或 unsplash")

View File

@@ -1,6 +1,6 @@
from typing import List
from typing import List, Optional
from pydantic import BaseModel
from pydantic import BaseModel, Field
class GenerateMultiViewModel(BaseModel):
@@ -8,6 +8,17 @@ class GenerateMultiViewModel(BaseModel):
image_url: str
class GenerateImageFlux2KleinModel(BaseModel):
bucket_name: str = Field(..., description="OSS桶名不传则为None")
object_name: str = Field(..., description="OSS对象名文件路径不传则为None")
# input_image_paths: Optional[List[str]] = Field(default=[], description="输入图片路径列表")
width: Optional[int] = Field(default=1024, description="图片宽度默认512像素")
height: Optional[int] = Field(default=1024, description="图片高度默认512像素")
prompt: Optional[str] = Field(default="", description="文本提示词,用于模型推理等场景")
steps: Optional[int] = Field(default=4, description="推理步数,控制模型生成过程的迭代次数")
guidance: Optional[float] = Field(default=4.0, description="引导系数,调节提示词对生成结果的影响程度")
class GenerateImageModel(BaseModel):
tasks_id: str
prompt: str
@@ -24,6 +35,13 @@ class GenerateSingleLogoImageModel(BaseModel):
seed: str
class GenerateSloganImageModel(BaseModel):
num_point: int
tasks_id: str
prompt: str
image_url: str
class GenerateProductImageModel(BaseModel):
tasks_id: str
prompt: str
@@ -32,6 +50,13 @@ class GenerateProductImageModel(BaseModel):
product_type: str
class Flux2ToProductImgModel(BaseModel):
tasks_id: str
prompt: str
image_path: str
infer_step: int | None = None
class GenerateRelightImageModel(BaseModel):
tasks_id: str
prompt: str

View File

@@ -0,0 +1,12 @@
from typing import List
from pydantic import BaseModel, Field
class SketchToGarmentModel(BaseModel):
input_image_path: str = Field(..., description="输入图片路径列表")
bucket_name: str = Field(..., description="输入图片路径列表")
user_id: str = Field(..., description="用户id")
callback_url: str # 必填,客户端提供的回调地址
task_id: str = Field()
model: str = Field(default="single", description="模型类型: single 或 multi")

View File

@@ -3,7 +3,6 @@
from pprint import pprint
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
@@ -12,6 +11,7 @@ from minio import Minio
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@@ -109,10 +109,9 @@ class AttributeRecognition:
@staticmethod
def preprocess(img):
img = mmcv.imread(img)
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -10,7 +10,6 @@
from minio import Minio
from skimage import transform
import cv2
import mmcv
import numpy as np
import pandas as pd
import tritonclient.http as httpclient
@@ -18,6 +17,7 @@ import torch
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import CategoryRecognitionModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@@ -39,11 +39,10 @@ class CategoryRecognition:
@staticmethod
def preprocess(img):
img = mmcv.imread(img)
# ori_shape = img.shape[:2]
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -1,7 +1,6 @@
import logging
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
@@ -9,11 +8,12 @@ import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import DESIGN_MODEL_URL
from app.core.config import DESIGN_MODEL_URL, SEG_PRODUCT_MODEL_URL
from app.core.config import settings
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@@ -29,7 +29,7 @@ class BrandDna:
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
self.seg_client = httpclient.InferenceServerClient(url=SEG_PRODUCT_MODEL_URL)
self.const = const
# self.const = local_debug_const
@@ -202,7 +202,7 @@ class BrandDna:
# 服装分割预处理
@staticmethod
def seg_product_preprocess(image):
img = mmcv.imread(image)
img = image
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -211,9 +211,9 @@ class BrandDna:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
@@ -227,11 +227,10 @@ class BrandDna:
# 类别检测模型预处理
@staticmethod
def category_preprocess(img):
img = mmcv.imread(img)
# ori_shape = img.shape[:2]
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -1,19 +1,10 @@
import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
import uuid
import httpx
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.core.config import settings
@@ -26,14 +17,9 @@ class GenerateBrandInfo:
# user info init
self.user_id = request_data.user_id
self.category = "brand_logo"
# generate logo init
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
self.batch_size = 1
self.mode = 'txt2img'
# llm generate brand info init
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=settings.QWEN_API_KEY)
self.response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."),
@@ -63,38 +49,20 @@ class GenerateBrandInfo:
self.generate_logo_prompt = brand_data['brand_logo_prompt']
def generate_brand_logo(self):
prompts = [self.generate_logo_prompt] * self.batch_size
modes = [self.mode] * self.batch_size
images = [self.image.astype(np.float16)] * self.batch_size
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
input_text.set_data_from_numpy(text_obj)
input_image.set_data_from_numpy(image_obj)
input_mode.set_data_from_numpy(mode_obj)
inputs = [input_text, input_image, input_mode]
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
image = result.as_numpy("generated_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
logo_url = self.upload_logo_image(image_result, generate_uuid())
self.result_data['brand_logo'] = logo_url
def upload_logo_image(self, image, object_name):
try:
_, img_byte_array = cv2.imencode('.jpg', image)
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")
request_item = {
"bucket_name": "aida-users",
"object_name": f'{self.user_id}/{self.category}/{uuid.uuid4().hex}.png',
"prompt": self.generate_logo_prompt,
"height": 1024,
"width": 1024
}
with httpx.Client(timeout=120) as client:
resp = client.post(
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
json=request_item,
)
result = resp.json()
self.result_data['brand_logo'] = result.get("output_path", "")
if __name__ == '__main__':

View File

@@ -23,7 +23,7 @@ class ClothingSeg:
def __init__(self, request_data):
self.image_data = request_data.image_data
self.user_id = request_data.user_id
self.triton_client = grpcclient.InferenceServerClient(url="10.1.1.243:10071")
self.triton_client = grpcclient.InferenceServerClient(url=f"{settings.B_4_X_4090_SERVICE_HOST}:10071")
@RunTime
def get_result(self):
@@ -139,7 +139,7 @@ def get_bounding_box(mask):
if __name__ == "__main__":
test_data = ClothingSegModel(
user_id=89,
user_id="89",
image_data=[
# {
# "image_url": "test/clothing_seg/dress.jpg",

View File

@@ -13,7 +13,7 @@ from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import settings
from app.core.config import settings, PS_RABBITMQ_QUEUES
from app.schemas.comfyui_i2v import ComfyuiPose2VModel
from app.service.generate_image.utils.mq import publish_status
@@ -622,9 +622,9 @@ class ComfyUIServerPose2V:
# 推送消息
if not settings.DEBUG:
publish_status(json.dumps(self.pose_transform_data), settings.COMFYUI_SERVER_ADDRESS)
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
logger.info(
f" [x] Sent to {settings.COMFYUI_SERVER_ADDRESS} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
return "\n🎉 所有任务完成!"

View File

@@ -10,13 +10,13 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
from app.service.utils.image_normalize import my_imnormalize
"""
keypoint
@@ -25,13 +25,13 @@ from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
img_scale = (256, 256)
h, w = img.shape[:2]
img = cv2.resize(img, img_scale)
w_scale = img_scale[0] / w
h_scale = img_scale[1] / h
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, (w_scale, h_scale)
@@ -74,7 +74,7 @@ def keypoint_postprocess(output, scale_factor):
# KNet
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -83,9 +83,9 @@ def seg_preprocess(img_path):
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,7 @@ class BaseItem:
self.result['name'] = data['type'].lower()
self.result.pop("type")
self.result.update(basic)
self.result['design_type'] = basic.get('design_type', None)
class OthersItem(BaseItem):
@@ -14,10 +15,7 @@ class OthersItem(BaseItem):
super().__init__(data, basic)
self.Others_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
# ContourDetection(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
NoSegPrintPainting(minio_client),
PrintPainting(minio_client),
@@ -74,6 +72,65 @@ class BottomItem(BaseItem):
return self.result
"""merge"""
class OthersMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Others_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
# ContourDetection(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
# NoSegPrintPainting(minio_client),
# PrintPainting(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.Others_pipeline:
self.result = item(self.result)
return self.result
class TopMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.top_pipeline = [
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.top_pipeline:
self.result = item(self.result)
return self.result
class BottomMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.bottom_pipeline = [
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.bottom_pipeline:
self.result = item(self.result)
return self.result
class BodyItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)

View File

@@ -35,15 +35,9 @@ class LoadImage:
return cls.name
def __call__(self, result):
if result.get("merge_image_path"):
result['merge_image'], _ = self.read_image(result['merge_image_path'])
result['image'], result['pre_mask'] = self.read_image(result['path'])
# if 'extract_lines' in result.keys():
# if result['extract_lines']:
# result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
# else:
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
# else:
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY))
result['keypoint'] = self.get_keypoint(result['name'])
result['img_shape'] = result['image'].shape
@@ -61,21 +55,6 @@ class LoadImage:
mask = skeleton
result = np.ones_like(img) * 255
result[mask] = img[mask]
# 步骤2细化边缘可选让线条更干净
# kernel = np.ones((1, 1), np.uint8)
# clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
# thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
# mask = thinned > 0
# result = np.ones_like(img) * 255
# result[mask] = img[mask]
# 步骤3反转回 白底黑线
# lines = cv2.bitwise_not(thinned)
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Original_{path.replace('/', '-')}.png"), img)
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Line_{path.replace('/', '-')}.png"), result)
return result
def read_image(self, image_path):
@@ -96,19 +75,19 @@ class LoadImage:
@staticmethod
def get_keypoint(name):
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
if name in ['blouse', 'outwear', 'dress', 'tops', 'blouse_merge', 'outwear_merge', 'dress_merge', 'tops_merge']:
keypoint = 'shoulder'
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
elif name in ['trousers', 'skirt', 'bottoms', 'trousers_merge', 'skirt_merge', 'bottoms_merge']:
keypoint = 'waistband'
elif name == 'bag':
elif name in ['bag', 'bag_merge']:
keypoint = 'hand_point'
elif name == 'shoes':
elif name in ['shoes', 'shoes_merge']:
keypoint = 'toe'
elif name == 'hairstyle':
elif name in ['hairstyle', 'hairstyle_merge']:
keypoint = 'head_point'
elif name == 'earring':
elif name in ['earring', 'earring_merge']:
keypoint = 'ear_point'
elif name == 'others':
elif name in ['others', 'others_merge']:
keypoint = "others"
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "

View File

@@ -12,9 +12,13 @@ class NoSegPrintPainting:
self.minio_client = minio_client
def __call__(self, result):
single_print = result['print']['single']
# single_print = [result['print']['single']]
overall_print = result['print']['overall']
element_print = result['print']['element']
# element_print = result['print']['element'
single_print = None
element_print = None
result['single_image'] = None
result['print_image'] = None
@@ -23,9 +27,9 @@ class NoSegPrintPainting:
# 获取平铺 + 旋转 的overall print
painting_dict = self.painting_collection(painting_dict, overall_print)
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = self.printpaint(result, painting_dict, print_=True)
result['pattern_image'] = result['no_seg_sketch_overall']
# result['pattern_image'] = result['no_seg_sketch_overall']
if single_print['print_path_list']:
if single_print:
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])):
@@ -65,7 +69,7 @@ class NoSegPrintPainting:
single_image = cv2.add(tmp1, tmp2)
result['no_seg_sketch_print'] = single_image
if element_print['element_path_list']:
if element_print:
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])):
@@ -162,15 +166,17 @@ class NoSegPrintPainting:
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
gap = print_dict.get('gap', [[0, 0]])[0]
painting_dict['tile_print'] = tile_image(pattern=print_['image'],
dim=dim_pattern,
gap_x=gap[0],
gap_y=gap[1],
canvas_h=painting_dict['dim_image_h'],
canvas_w=painting_dict['dim_image_w'],
location=painting_dict['location'],
angle=45)
painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
painting_dict['tile_print'], painting_dict['mask_inv_print'] = tile_image(pattern=print_['image'],
mask=print_['mask'],
dim=dim_pattern,
gap_x=gap[0],
gap_y=gap[1],
canvas_h=painting_dict['dim_image_h'],
canvas_w=painting_dict['dim_image_w'],
location=painting_dict['location'],
angle=int(print_.get('print_angle_list', [0])[0]))
# painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
# painting_dict['mask_inv_print'] = self.get_mask_inv(painting_dict['tile_print'])
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
@@ -251,10 +257,15 @@ class NoSegPrintPainting:
image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image.mode == "RGBA":
mask_pil = image.split()[3]
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
image = new_background
else:
mask_pil = Image.new('L', image.size, 255) # L=灰度图255=纯白
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
print_dict['mask'] = cv2.threshold(np.array(mask_pil), 127, 255, cv2.THRESH_BINARY)[1]
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
@@ -404,9 +415,12 @@ class NoSegPrintPainting:
return cropped_img
def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
def tile_image(pattern, mask, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
"""
按照指定的 X/Y 间距平铺印花,并支持旋转
【修改版】以被平铺图案的【中心】作为平铺基准点
:param location: [[center_y, center_x]] → 第一个图案中心的坐标
:param angle: 旋转角度 (度数, 逆时针)
"""
# 1. 确保输入是 RGBA
@@ -418,35 +432,54 @@ def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0
rotated_p = rotate_image(resized_p, angle)
p_h, p_w = rotated_p.shape[:2]
# 3. 创建透明单元格
cell_h, cell_w = p_h + gap_y, p_w + gap_x
# 3. 创建透明单元格(图案放在单元格中心)
cell_h = p_h + gap_y
cell_w = p_w + gap_x
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
unit_cell[:p_h, :p_w, :] = rotated_p
# 计算图案在单元格中的左上角位置(让图案居中)
start_y = (cell_h - p_h) // 2
start_x = (cell_w - p_w) // 2
unit_cell[start_y:start_y + p_h, start_x:start_x + p_w, :] = rotated_p
# 4. 执行平铺
tiles_y = (canvas_h // cell_h) + 2
tiles_x = (canvas_w // cell_w) + 2
tiles_y = (canvas_h // cell_h) + 3 # 多加一点余量更安全
tiles_x = (canvas_w // cell_w) + 3
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
# 5. 裁剪平铺层
offset_x = int(location[0][1] % cell_w)
offset_y = int(location[0][0] % cell_h)
# 5. 计算偏移(关键修改:以中心为基准)
center_y, center_x = location[0][0], location[0][1] # 第一个图案的中心位置
# 计算从哪个位置开始裁剪,才能让中心落在指定坐标
offset_y = int((center_y - (p_h // 2)) % cell_h)
offset_x = int((center_x - (p_w // 2)) % cell_w)
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
# 6. 创建纯白色背景并合成
# 创建一个纯白色的 BGR 画布
# 6. 创建纯白色背景并合成(保持你原来的风格)
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
# 分离平铺层的颜色通道和 Alpha 通道
tiled_bgr = tiled_layer[:, :, :3]
alpha_mask = tiled_layer[:, :, 3] / 255.0 # 归一化到 0-1
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask]) # 扩展到 3 通道
alpha_mask = tiled_layer[:, :, 3] / 255.0
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask])
# 执行 Alpha 混合:结果 = 平铺层 * alpha + 背景 * (1 - alpha)
result = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
tiled_print = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
return result
# ====================== 处理 Mask ======================
# Mask 也同样居中处理
resized_mask = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
rotated_mask = rotate_image(resized_mask, angle) # 注意mask也需要旋转
unit_mask = np.zeros((cell_h, cell_w), dtype=np.uint8)
unit_mask[start_y:start_y + p_h, start_x:start_x + p_w] = rotated_mask
full_mask_tiled = np.tile(unit_mask, (tiles_y, tiles_x))
tiled_mask = full_mask_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
return tiled_print, cv2.bitwise_not(tiled_mask)
def rotate_image(image, angle):

View File

@@ -12,10 +12,14 @@ class PrintPainting:
self.minio_client = minio_client
def __call__(self, result):
single_print = result['print']['single']
# single_print = result['print']['single']
overall_print = result['print']['overall']
element_print = result['print']['element']
partial_path = result['print']['partial'] if 'partial' in result['print'] else None
# element_print = result['print']['element']
# partial_path = result['print']['partial'] if 'partial' in result['print'] else None
single_print = None
element_print = None
partial_path = None
result['single_image'] = None
result['print_image'] = None
# TODO 给result['pattern_image'] resize 到resize_scale的大小
@@ -37,13 +41,13 @@ class PrintPainting:
if overall_print['print_path_list']:
overall_print['location'][0] = [x * y for x, y in zip(overall_print['location'][0], result['resize_scale'])]
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image']
result['print_image'] = result['pattern_image'].copy()
# 获取平铺 + 旋转 的overall print
painting_dict = self.painting_collection(painting_dict, overall_print)
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
if single_print['print_path_list']:
if single_print:
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
sketch_resize_scale = result['resize_scale']
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
@@ -84,7 +88,7 @@ class PrintPainting:
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
if element_print['element_path_list']:
if element_print:
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
sketch_resize_scale = result['resize_scale']
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
@@ -225,15 +229,15 @@ class PrintPainting:
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
gap = print_dict.get('gap', [[0, 0]])[0]
painting_dict['tile_print'] = tile_image(pattern=print_['image'],
dim=dim_pattern,
gap_x=gap[0],
gap_y=gap[1],
canvas_h=painting_dict['dim_image_h'],
canvas_w=painting_dict['dim_image_w'],
location=painting_dict['location'],
angle=45)
painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
painting_dict['tile_print'], painting_dict['mask_inv_print'] = tile_image(pattern=print_['image'],
mask=print_['mask'],
dim=dim_pattern,
gap_x=gap[0],
gap_y=gap[1],
canvas_h=painting_dict['dim_image_h'],
canvas_w=painting_dict['dim_image_w'],
location=painting_dict['location'],
angle=int(print_.get('print_angle_list', [0])[0]))
return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
@@ -314,10 +318,15 @@ class PrintPainting:
image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
# 判断图片格式如果是RGBA 则贴在一张纯白图片上 防止透明转黑
if image.mode == "RGBA":
mask_pil = image.split()[3]
new_background = Image.new('RGB', image.size, (255, 255, 255))
new_background.paste(image, mask=image.split()[3])
image = new_background
else:
mask_pil = Image.new('L', image.size, 255) # L=灰度图255=纯白
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
print_dict['mask'] = cv2.threshold(np.array(mask_pil), 127, 255, cv2.THRESH_BINARY)[1]
return print_dict
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
@@ -467,9 +476,12 @@ class PrintPainting:
return cropped_img
def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
def tile_image(pattern, mask, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
"""
按照指定的 X/Y 间距平铺印花,并支持旋转
【修改版】以被平铺图案的【中心】作为平铺基准点
:param location: [[center_y, center_x]] → 第一个图案中心的坐标
:param angle: 旋转角度 (度数, 逆时针)
"""
# 1. 确保输入是 RGBA
@@ -481,35 +493,54 @@ def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0
rotated_p = rotate_image(resized_p, angle)
p_h, p_w = rotated_p.shape[:2]
# 3. 创建透明单元格
cell_h, cell_w = p_h + gap_y, p_w + gap_x
# 3. 创建透明单元格(图案放在单元格中心)
cell_h = p_h + gap_y
cell_w = p_w + gap_x
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
unit_cell[:p_h, :p_w, :] = rotated_p
# 计算图案在单元格中的左上角位置(让图案居中)
start_y = (cell_h - p_h) // 2
start_x = (cell_w - p_w) // 2
unit_cell[start_y:start_y + p_h, start_x:start_x + p_w, :] = rotated_p
# 4. 执行平铺
tiles_y = (canvas_h // cell_h) + 2
tiles_x = (canvas_w // cell_w) + 2
tiles_y = (canvas_h // cell_h) + 3 # 多加一点余量更安全
tiles_x = (canvas_w // cell_w) + 3
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
# 5. 裁剪平铺层
offset_x = int(location[0][1] % cell_w)
offset_y = int(location[0][0] % cell_h)
# 5. 计算偏移(关键修改:以中心为基准)
center_y, center_x = location[0][0], location[0][1] # 第一个图案的中心位置
# 计算从哪个位置开始裁剪,才能让中心落在指定坐标
offset_y = int((center_y - (p_h // 2)) % cell_h)
offset_x = int((center_x - (p_w // 2)) % cell_w)
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
# 6. 创建纯白色背景并合成
# 创建一个纯白色的 BGR 画布
# 6. 创建纯白色背景并合成(保持你原来的风格)
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
# 分离平铺层的颜色通道和 Alpha 通道
tiled_bgr = tiled_layer[:, :, :3]
alpha_mask = tiled_layer[:, :, 3] / 255.0 # 归一化到 0-1
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask]) # 扩展到 3 通道
alpha_mask = tiled_layer[:, :, 3] / 255.0
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask])
# 执行 Alpha 混合:结果 = 平铺层 * alpha + 背景 * (1 - alpha)
result = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
tiled_print = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
return result
# ====================== 处理 Mask ======================
# Mask 也同样居中处理
resized_mask = cv2.resize(mask, dim, interpolation=cv2.INTER_NEAREST)
rotated_mask = rotate_image(resized_mask, angle) # 注意mask也需要旋转
unit_mask = np.zeros((cell_h, cell_w), dtype=np.uint8)
unit_mask[start_y:start_y + p_h, start_x:start_x + p_w] = rotated_mask
full_mask_tiled = np.tile(unit_mask, (tiles_y, tiles_x))
tiled_mask = full_mask_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
return tiled_print, cv2.bitwise_not(tiled_mask)
def rotate_image(image, angle):

View File

@@ -34,22 +34,25 @@ class Segmentation:
result['mask'] = result['front_mask'] + result['back_mask']
else:
# preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果
if result.get("design_type", None) == "merge":
seg_result = get_seg_result(result['image'])
# submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 推理获得seg 结果
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
# 默认design 模式 - 过模型 缓存
# elif result.get("design_type", None) == "submit":
# 推理获得seg 结果
# seg_result = get_seg_result(result['image'])
# self.save_seg_result(seg_result, result['image_id'])
# 默认模式- 加载模型,找不到则过模型推理,推理后保存到本地
else:
# 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"])
# 判断缓存和实际图片size是否相同
_ = False
if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果
seg_result = get_seg_result(result['image'])
if result['name'] == 'others':
seg_result = seg_result.clip(max=1)
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result

View File

@@ -4,6 +4,7 @@ import logging
import cv2
import numpy as np
from PIL import Image
from celery.bin.result import result
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
@@ -19,105 +20,122 @@ class Split(object):
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
ori_front_mask = result['front_mask'].copy()
ori_back_mask = result['back_mask'].copy()
if result.get('design_type', None) == 'merge':
ori_front_mask = result['front_mask'].copy()
ori_back_mask = result['back_mask'].copy()
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
front_mask = result['front_mask']
back_mask = result['back_mask']
else:
height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# 预处理用户自选区mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_AREA)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# 创建红色和绿色掩码
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
front_mask = result['front_mask']
back_mask = result['back_mask']
else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
# 前片部分 (红图部分)
# height, width = front_mask.shape
# mask_image = np.zeros((height, width, 3))
# mask_image[front_mask != 0] = [0, 0, 255]
front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
result['merge_image'] = cv2.resize(result['merge_image'], (new_width, new_height), interpolation=cv2.INTER_AREA)
# 切换为原始图片尺寸-------------------------------
height, width = ori_front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[ori_front_mask != 0] = [0, 0, 255]
# -----------------------------------------------
rgba_image = rgb_to_rgba(result['merge_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
# result_back_image = np.zeros_like(rgba_image)
# back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
#
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# else:
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# result['back_image'] = None
# result["back_image_url"] = None
# # result["back_mask_url"] = None
# # result['back_mask_image'] = None
height, width = ori_front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[ori_front_mask != 0] = [0, 0, 255]
mask_image[ori_back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask)
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
mask_image[ori_back_mask != 0] = [0, 255, 0]
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
return result
else:
ori_front_mask = result['front_mask'].copy()
ori_back_mask = result['back_mask'].copy()
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask)
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
front_mask = result['front_mask']
back_mask = result['back_mask']
else:
height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# 预处理用户自选区mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_AREA)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# 创建红色和绿色掩码
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = ori_front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[ori_front_mask != 0] = [0, 0, 255]
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
mask_image[ori_back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask)
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
result_pattern_overall_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_overall_image'], result['pattern_overall_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_overall_image_pil, f'{generate_uuid()}')
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
return result
else:
ori_front_mask, ori_back_mask = None, None
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
@@ -127,5 +145,6 @@ class Split(object):
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
return result
except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -10,12 +10,12 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
from app.service.utils.image_normalize import my_imnormalize
"""
keypoint
@@ -24,14 +24,14 @@ from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
def keypoint_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
img_scale = (256, 256)
h, w = img.shape[:2]
img = cv2.resize(img, img_scale)
w_scale = img_scale[0] / w
h_scale = img_scale[1] / h
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, (w_scale, h_scale)
@@ -78,7 +78,7 @@ def keypoint_postprocess(output, scale_factor):
# KNet
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -87,12 +87,12 @@ def seg_preprocess(img_path):
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
# 扩充25的白边
img = cv2.copyMakeBorder(img, 25, 25, 25, 25, cv2.BORDER_CONSTANT, value=[255, 255, 255])
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape

View File

@@ -23,19 +23,20 @@ def organize_clothing(layer):
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
name=f'{layer["name"].lower()}_front',
image=layer["front_image"],
merge_image=layer["front_image"],
# mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get("mask_url", None),
sacle=layer['scale'],
clothes_keypoint=layer['clothes_keypoint'],
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
pattern_image=layer['pattern_image'],
pattern_image=layer.get('pattern_image', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
@@ -46,17 +47,17 @@ def organize_clothing(layer):
image=layer["back_image"],
# mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get('mask_url', None),
sacle=layer['scale'],
clothes_keypoint=layer['clothes_keypoint'],
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
)
return front_layer, back_layer
@@ -80,17 +81,19 @@ def organize_others(layer):
image=layer["front_image"],
# mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get('mask_url', None),
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_image=layer['pattern_image'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
pattern_image=layer.get('pattern_image', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
)
# 后片数据
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
@@ -98,16 +101,18 @@ def organize_others(layer):
image=layer["back_image"],
# mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'],
mask_url=layer['mask_url'],
mask_url=layer.get('mask_url', None),
sacle=layer['scale'],
clothes_keypoint=(0, 0),
position=start_point,
resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'],
pattern_print_image_url=layer['pattern_print_image_url'],
pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer.get('pattern_print_image_url', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
)
return front_layer, back_layer

View File

@@ -187,6 +187,111 @@ def synthesis(data, size, basic_info):
logging.warning(f"synthesis runtime exception : {e}")
def merge(data, size, basic_info):
# out_of_bounds_control: 是否允许服装越界 True 允许 False 不允许 默认情况允许
out_of_bounds_control = basic_info.get('out_of_bounds_control', True)
# 创建底图
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
try:
all_mask_shape = (size[1], size[0])
body_mask = None
for d in data:
if d['name'] == 'body' or d['name'] == 'mannequin':
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值所以使用下标获取position
body_mask = np.array(transparent_image.split()[3])
# 根据新的坐标获取新的肩点
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
top_outer_mask = np.array(binary_body_mask)
bottom_outer_mask = np.array(binary_body_mask)
others_outer_mask = np.array(binary_body_mask)
top = True
bottom = True
others = True
i = len(data)
while i:
i -= 1
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
if out_of_bounds_control:
top = True
else:
top = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
top_outer_mask = background + top_outer_mask
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
# bottom = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
bottom_outer_mask = background + bottom_outer_mask
elif others and data[i]['name'] in ['others_front']:
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
others_outer_mask = background + others_outer_mask
pass
elif bottom is False and top is False:
break
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
all_mask = cv2.bitwise_or(all_mask, others_outer_mask)
for layer in data:
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
paste_img, position = transpose_rotate(layer, layer['image'])
test_image.paste(paste_img, position, paste_img)
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
mask_alpha.paste(paste_img.getchannel('A'), position, paste_img.getchannel('A'))
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
else:
base_image.paste(layer['merge_image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['merge_image'])
result_image = base_image
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
# oss upload
image_bytes = image_data.read()
bucket_name = "aida-results"
object_name = f'result_{generate_uuid()}.png'
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
except Exception as e:
logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
result_image = None
if front_image:
@@ -242,7 +347,8 @@ def transpose_rotate(layer, image):
rotate = layer.get('rotate', 0)
paste_x, paste_y = layer['adaptive_position'][1], layer['adaptive_position'][0]
original_w = image.width
original_h = image.height
# transpose左右是1 上下是-1
if transpose[0] != 1:
# 左右
@@ -256,8 +362,8 @@ def transpose_rotate(layer, image):
image = image.rotate(-rotate, expand=True)
# 4. 计算粘贴位置以保持视觉中心一致
# 原本 (15, 36) 是 288*288 的左上角,我们计算其中心点
target_center_x = 15 + 288 // 2
target_center_y = 36 + 288 // 2
target_center_x = paste_x + original_w // 2
target_center_y = paste_y + original_h // 2
# 获取旋转后图像的新尺寸
new_w, new_h = image.size
@@ -265,4 +371,4 @@ def transpose_rotate(layer, image):
# 计算新的左上角坐标,使得旋转后的图像中心依然在原定的中心位置
paste_x = target_center_x - new_w // 2
paste_y = target_center_y - new_h // 2
return image, (paste_x, paste_y)
return image, (paste_x, paste_y)

View File

@@ -7,7 +7,7 @@ import numpy as np
import torch
import tritonclient.grpc as grpcclient
from minio import Minio
from pymilvus import MilvusClient
# from pymilvus import MilvusClient
from urllib3.exceptions import ResponseError
from app.core.config import settings, SR_MODEL_NAME, SR_TRITON_URL, MILVUS_TABLE_KEYPOINT, KEYPOINT_RESULT_TABLE_FIELD_SET
@@ -58,7 +58,21 @@ class DesignPreprocessing:
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
elif image.shape[2] == 4: # 如果是四通道 mask
image = image[:, :, :3]
# 分离RGB和Alpha通道
bgr = image[:, :, :3]
alpha = image[:, :, 3]
# 创建白色背景(也可改为其他颜色,如(255,255,255)就是白色)
background_color = (255, 255, 255)
background = np.full_like(bgr, background_color)
# 将Alpha通道转换为掩码0=透明255=不透明)
alpha_mask = alpha / 255.0 # 归一化到0-1
alpha_mask = np.expand_dims(alpha_mask, axis=-1) # 扩展维度,方便广播计算
# 混合背景和原图:透明区域显示背景色,不透明区域显示原图
image = (bgr * alpha_mask + background * (1 - alpha_mask)).astype(np.uint8)
# 此时image已经是3通道RGB无需再执行image = image[:, :, :3]
obj["image_obj"] = image
return image_list
@@ -174,8 +188,9 @@ class DesignPreprocessing:
scale = 0.4
if waist_width / scale >= image_width:
add_width = int((waist_width / scale - image_width) / 2)
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(256, 256, 256))
image_bytes = cv2.imencode(".jpg", ret)[1].tobytes()
ret = cv2.copyMakeBorder(image['obj'], 0, 0, add_width, add_width, cv2.BORDER_CONSTANT, value=(255, 255, 255))
img_rgba = cv2.cvtColor(ret, cv2.COLOR_RGB2RGBA)
image_bytes = cv2.imencode(".png", img_rgba)[1].tobytes()
# image['show_image_url'] = f"{image['image_url'].split('/', 1)[0]}/{self.minio_client.put_object(image['image_url'].split('/', 1)[0], image['image_url'].split('/', 1)[1].replace('.', '-show.'), io.BytesIO(image_bytes), len(image_bytes), content_type='image/jpeg').object_name}"
bucket_name = image['image_url'].split('/', 1)[0]
object_name = image['image_url'].split('/', 1)[1].replace('.', '-show.')
@@ -261,14 +276,15 @@ class DesignPreprocessing:
def keypoint_cache(self, sketch):
try:
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
keypoint_id = sketch['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site']
)
# res = client.query(
# collection_name=MILVUS_TABLE_KEYPOINT,
# # ids=[keypoint_id],
# filter=f"keypoint_id == {keypoint_id}",
# output_fields=['keypoint_vector', 'keypoint_site']
# )
res = []
if len(res) == 0:
# 没有结果 直接推理拿结果 并保存
keypoint_infer_result = self.infer_keypoint_result(sketch)

View File

View File

@@ -0,0 +1,159 @@
import logging
from typing import Annotated, Required, TypedDict
from langchain_core.messages import AIMessage, AnyMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool # noqa: E402
logger = logging.getLogger()
"""定义状态"""
class DesignState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
design_request_data: dict = {}
request_objects: list[dict] = []
results: list[dict] = []
"""节点"""
def run_design_node(state: DesignState) -> dict:
"""调用 design_tool 执行设计任务,逐个推送结果"""
from langgraph.config import get_stream_writer
writer = get_stream_writer()
request_data = state.get("design_request_data")
request_objects = request_data.get("objects")
results = []
for item in design_tool.invoke({"objects": request_objects}):
logger.info(f"design result: {item}")
results.append(item)
writer({"event": "tool-output-delta", "tool_name": "design_tool", "type": "design_result", "data": item})
writer({"event": "tool-finished", "tool_name": "design_tool", "type": "design_result", "data": results})
result_text = f"设计完成,共处理 {len(results)} 个对象"
return {"results": results, "messages": [AIMessage(content=result_text)]}
"""构建 Graph"""
def build_design_graph():
"""构建 design graph"""
workflow = StateGraph(DesignState)
workflow.add_node("run_design", run_design_node)
workflow.add_edge(START, "run_design")
workflow.add_edge("run_design", END)
return workflow.compile()
design_graph = build_design_graph()
if __name__ == "__main__":
request_data = {
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [203, 249],
"hand_point_right": [229, 343],
"waistband_left": [119, 248],
"hand_point_left": [97, 343],
"shoulder_left": [108, 107],
"relation_type": "System",
"shoulder_right": [212, 107],
"relation_id": 1020356,
},
"layer_order": False,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": False,
"single_overall": "overall",
"switch_category": "",
},
"items": [
{
"color": "209 196 171",
"image_id": 84093,
"offset": [1, 1],
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
"print": {
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
"overall": {
"location": [[0.0, 0.0]],
"print_angle_list": [0.0, 0.0],
"print_path_list": [],
"print_scale_list": [[0.0, 0.0]],
},
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
},
"resize_scale": [1.0, 1.0],
"type": "Outwear",
},
{
"color": "63 71 73",
"image_id": 100496,
"offset": [1, 1],
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
"print": {
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
"overall": {
"location": [[0.0, 0.0]],
"print_angle_list": [0.0, 0.0],
"print_path_list": [],
"print_scale_list": [[0.0, 0.0]],
},
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
},
"resize_scale": [1.0, 1.0],
"type": "Blouse",
},
{
"color": "111 78 63",
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
"image_id": 92193,
"offset": [1, 1],
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
"print": {
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
"overall": {
"location": [[0.0, 0.0]],
"print_angle_list": [0.0, 0.0],
"print_path_list": [],
"print_scale_list": [[0.0, 0.0]],
},
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
},
"resize_scale": [1.0, 1.0],
"type": "Trousers",
},
{
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"offset": [1, 1],
"resize_scale": [1.0, 1.0],
"type": "Body",
},
],
"objectSign": "65830966",
}
],
"process_id": "4802946666428422",
"requestId": "1d1e7641-0d62-4da2-adc0-b4404910723c",
"callback_url": "https://api.aida.com.hk/api/third/party/receiveDesignResults",
}
result = design_graph.invoke({"design_request_data": request_data})
print(result)

View File

@@ -0,0 +1,206 @@
import logging
import queue
import threading
from langchain.tools import tool
from pydantic import BaseModel, Field
from app.service.design_fast.design_generate import process_item, process_layer
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
logger = logging.getLogger()
class DesignModel(BaseModel):
objects: list[dict] = Field(description="")
@tool(args_schema=DesignModel, description="design tool")
def design_tool(objects: list[dict]):
"""design tool"""
result_queue = queue.Queue()
def process_object(obj):
basic = obj["basic"]
design_type = basic.get("design_type", "default")
items_response = {
"layers": [],
"objectSign": obj["objectSign"] if "objectSign" in obj.keys() else "",
}
if basic["single_overall"] == "overall":
item_results = []
for item in obj["items"]:
item_results.append(process_item(item, basic, design_type))
layers = []
for item in item_results:
process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float("inf")))
layers, new_size = update_base_size_priority(layers)
for lay in layers:
items_response["layers"].append(
{
"image_category": "body" if lay["name"] == "mannequin" else lay["name"],
"position": lay["position"],
"priority": lay.get("priority", None),
"resize_scale": lay["resize_scale"] if "resize_scale" in lay.keys() else None,
"image_size": lay["image"] if lay["image"] is None else lay["image"].size,
"gradient_string": lay["gradient_string"] if "gradient_string" in lay.keys() else "",
"mask_url": lay["mask_url"],
"image_url": lay["image_url"] if "image_url" in lay.keys() else None,
"pattern_overall_image_url": (
lay["pattern_overall_image_url"] if "pattern_overall_image_url" in lay.keys() else None
),
"pattern_print_image_url": lay["pattern_print_image_url"] if "pattern_print_image_url" in lay.keys() else None,
}
)
items_response["synthesis_url"] = synthesis(layers, new_size, basic)
else:
item_result = process_item(obj["items"][0], basic, design_type)
items_response["layers"].append(
{
"image_category": f"{item_result['name']}_front",
"image_size": item_result["back_image"].size if item_result["back_image"] else None,
"position": None,
"priority": 0,
"image_url": item_result["front_image_url"],
"mask_url": item_result["mask_url"],
"gradient_string": item_result["gradient_string"] if "gradient_string" in item_result.keys() else "",
"pattern_overall_image_url": (
item_result["pattern_overall_image_url"] if "pattern_overall_image_url" in item_result.keys() else None
),
"pattern_print_image_url": (
item_result["pattern_print_image_url"] if "pattern_print_image_url" in item_result.keys() else None
),
}
)
items_response["layers"].append(
{
"image_category": f"{item_result['name']}_back",
"image_size": item_result["front_image"].size if item_result["front_image"] else None,
"position": None,
"priority": 0,
"image_url": item_result["back_image_url"],
"mask_url": item_result["mask_url"],
"gradient_string": item_result["gradient_string"] if "gradient_string" in item_result.keys() else "",
"pattern_overall_image_url": (
item_result["pattern_overall_image_url"] if "pattern_overall_image_url" in item_result.keys() else None
),
"pattern_print_image_url": (
item_result["pattern_print_image_url"] if "pattern_print_image_url" in item_result.keys() else None
),
}
)
items_response["synthesis_url"] = synthesis_single(item_result["front_image"], item_result["back_image"])
logger.info(items_response)
result_queue.put(items_response)
# 启动所有线程
threads = []
for obj in objects:
t = threading.Thread(target=process_object, args=(obj,))
threads.append(t)
t.start()
# 主线程逐个取出结果 yield
finished = 0
total = len(objects)
while finished < total:
result = result_queue.get()
yield result
finished += 1
if __name__ == "__main__":
request_objects = [
{
"basic": {
"body_point_test": {
"waistband_right": [203, 249],
"hand_point_right": [229, 343],
"waistband_left": [119, 248],
"hand_point_left": [97, 343],
"shoulder_left": [108, 107],
"relation_type": "System",
"shoulder_right": [212, 107],
"relation_id": 1020356,
},
"layer_order": False,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": False,
"single_overall": "overall",
"switch_category": "",
},
"items": [
{
"color": "209 196 171",
"image_id": 84093,
"offset": [1, 1],
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
"print": {
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
"overall": {
"location": [[0.0, 0.0]],
"print_angle_list": [0.0, 0.0],
"print_path_list": [],
"print_scale_list": [[0.0, 0.0]],
},
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
},
"resize_scale": [1.0, 1.0],
"type": "Outwear",
},
{
"color": "63 71 73",
"image_id": 100496,
"offset": [1, 1],
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
"print": {
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
"overall": {
"location": [[0.0, 0.0]],
"print_angle_list": [0.0, 0.0],
"print_path_list": [],
"print_scale_list": [[0.0, 0.0]],
},
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
},
"resize_scale": [1.0, 1.0],
"type": "Blouse",
},
{
"color": "111 78 63",
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
"image_id": 92193,
"offset": [1, 1],
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
"print": {
"element": {"element_angle_list": [], "element_path_list": [], "element_scale_list": [], "location": []},
"overall": {
"location": [[0.0, 0.0]],
"print_angle_list": [0.0, 0.0],
"print_path_list": [],
"print_scale_list": [[0.0, 0.0]],
},
"single": {"location": [], "print_angle_list": [], "print_path_list": [], "print_scale_list": []},
},
"resize_scale": [1.0, 1.0],
"type": "Trousers",
},
{
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"offset": [1, 1],
"resize_scale": [1.0, 1.0],
"type": "Body",
},
],
"objectSign": "65830966",
}
]
result = design_tool.invoke({"objects": request_objects})
for item in result:
print(item)

View File

@@ -0,0 +1,138 @@
import asyncio
import logging
from typing import Annotated, Required, TypedDict
from langchain.tools import ToolRuntime
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, SystemMessage
from langchain_qwq import ChatQwen
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from langchain_core.runnables import RunnableConfig
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
from app.service.fashion_agent.init_llm import build_llm
logger = logging.getLogger()
"""定义状态"""
class ExplorerState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
search_query: str
image_results: list[dict] # 每项包含 image_url 和 minio_path
provider: str = "unsplash" # 图片源: "pexels" 或 "unsplash"
"""节点"""
def extract_input_node(state: ExplorerState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
return {"input_text": input_text}
class SearchQuery(BaseModel):
"""搜索关键词"""
query: str = Field(description="用于搜索灵感图片的英文关键词,简洁有力")
# TODO 要考虑搜索图片失败或者图片不存在的情况, 搜索不到 需要调整搜索词或者拆分搜索词最终失败的话调用mood board生成工具生成 保证绝对有图片
async def generate_query_node(state: ExplorerState) -> dict:
"""使用 LLM 分析用户输入,生成搜索关键词"""
input_text = state["input_text"]
logger.info(f"[Explorer] 用户输入: {input_text}")
llm = build_llm()
structured_llm = llm.with_structured_output(SearchQuery)
messages = [
SystemMessage(content="""你是一个专业的服装设计师助手。
根据用户输入生成一个英文搜索关键词用于在图片库中搜索服装设计灵感图片moodboard
要求:
1. 使用英文,简洁有力
2. 适合搜索高质量的设计灵感图片
例如:
用户输入:"夏季连衣裙,清新风格"
输出summer dress fresh style"""),
HumanMessage(content=input_text),
]
result = structured_llm.invoke(messages)
logger.info(f"[Explorer] LLM 生成的搜索关键词: {result.query}")
return {"search_query": result.query}
async def search_and_upload_node(state: ExplorerState, config: RunnableConfig) -> dict:
"""使用搜索关键词获取图片并上传到 minio"""
query = state.get("search_query", "")
user_id = state.get("user_id", "agent")
provider = state.get("provider", "unsplash")
try:
results = await explor_tool.ainvoke({"query": query, "per_page": 4, "user_id": user_id, "method": provider}, config=config)
except Exception as e:
logger.error(f"[Explorer] 搜索失败 '{query}': {e}")
results = []
return {"image_results": results}
def summarize_node(state: ExplorerState) -> dict:
"""汇总结果"""
input_text = state.get("input_text", "")
query = state.get("search_query", "")
results = state.get("image_results", [])
result_text = f"【灵感探索 Moodboard】\n\n"
result_text += f"基于您的需求:「{input_text}\n"
result_text += f"搜索关键词:{query}\n\n"
result_text += f"已为您找到 {len(results)} 张灵感图片:\n"
for i, item in enumerate(results, 1):
result_text += f" {i}. 原图: {item.get('image_url', '')}\n"
result_text += f" Minio: {item.get('minio_path', '')}\n"
return {"messages": [AIMessage(content=result_text)]}
"""构建图"""
def build_explorer_graph():
"""构建灵感探索图"""
workflow = StateGraph(ExplorerState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("generate_query", generate_query_node)
workflow.add_node("search_and_upload", search_and_upload_node)
workflow.add_node("summarize", summarize_node)
workflow.add_edge(START, "extract_input")
workflow.add_edge("extract_input", "generate_query")
workflow.add_edge("generate_query", "search_and_upload")
workflow.add_edge("search_and_upload", "summarize")
workflow.add_edge("summarize", END)
return workflow.compile()
if __name__ == "__main__":
async def test():
graph = build_explorer_graph()
result = await graph.ainvoke(
{
"messages": [HumanMessage(content="夏季连衣裙,清新自然风格")],
"provider": "unsplash",
}
)
print(result["messages"][-1].content)
asyncio.run(test())

View File

@@ -0,0 +1,56 @@
from langchain.tools import ToolRuntime, tool
from pydantic import BaseModel, Field
from langchain_core.runnables import RunnableConfig
from app.service.fashion_agent.graph_node.node_tools.pexels_search import search_photos
from app.service.fashion_agent.graph_node.node_tools.unsplash_search import get_random_photos
class SearchInput(BaseModel):
"""Input schema for Pexels Search Tool."""
query: str = Field(description="Search query for Pexels, e.g., 'minimalist fashion moodboard', 'summer dress inspiration'")
per_page: int = Field(description="Number of images to return (1-80)", default=4)
user_id: str = Field(description="User ID for image storage", default="agent")
method: str = Field(description="", default="unsplash")
@tool(args_schema=SearchInput)
async def explor_tool(
query: str, per_page: int = 4, user_id: str = "agent", method: str = "unsplash", config: RunnableConfig = None
) -> list[dict]:
"""Search for fashion inspiration images on Unsplash and upload to minio. Returns a list of dicts with image_url and minio_path."""
if config:
# 方式 1从 configurable 获取
user_id = config.get("configurable", {}).get("user_id", "agent")
results = []
if method == "unsplash":
results = await get_random_photos(query, count=per_page, user_id=user_id)
elif method == "pexels":
results = await search_photos(query, per_page=per_page, user_id=user_id)
else:
results = []
return results
if __name__ == "__main__":
import asyncio
async def test():
urls = await get_random_photos("summer dress fresh natural style", count=4)
print(f"Uploaded {len(urls)} images to minio:")
for url in urls:
print(f" {url}")
asyncio.run(test())
if __name__ == "__main__":
import asyncio
async def test():
urls = await search_photos("minimalist fashion moodboard", per_page=4)
print(f"Uploaded {len(urls)} images to minio:")
for url in urls:
print(f" {url}")
asyncio.run(test())

View File

@@ -0,0 +1,152 @@
import asyncio
from typing import Annotated, Required, TypedDict
from langchain_core.messages import AIMessage, AnyMessage
from langchain_qwq import ChatQwen
from langchain_core.messages import HumanMessage, SystemMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
from app.service.fashion_agent.init_llm import qwen_plus_llm
"""初始化 LLM TODO 将 API Key 替换为环境变量或者配置文件中的值,避免在代码中硬编码敏感信息"""
"""定义状态"""
class LogoState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
user_id: str = "agent"
role: str = ""
gender: str = ""
style: str = ""
need_prompt_generation: bool = True # 是否需要使用 prompt 生成节点
logo_num: int = 1
logo_prompts: list[str] = []
logo_img_urls: list[str] = []
"""生成 Logo 的提示词节点"""
# 定义输出结构
class LogoPrompt(BaseModel):
"""生成的 Logo 图像提示词"""
prompts: list[str] = Field(description="用于生成 Logo 的详细提示词")
def extract_input_node(state: LogoState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
return {"input_text": input_text}
def generate_logo_prompt_node(state: LogoState) -> dict:
"""根据用户输入生成 Logo 的图像生成提示词"""
structured_llm = qwen_plus_llm.with_structured_output(LogoPrompt)
messages = [
SystemMessage(content="""从用户输入中提取核心主题词,只输出一个简单的英文单词。
例如:
- "我想要一个猫咪图案" -> "cat"
- "设计一个花朵" -> "flower"
- "可爱的狗" -> "dog"
只输出单词,不要其他内容。"""),
HumanMessage(content=state["input_text"]),
]
result = structured_llm.invoke(messages)
prompts = result.prompts
return {
"logo_prompts": prompts,
}
"""生成 Logo 图案节点"""
async def generate_logo_img_node(state: LogoState) -> dict:
"""根据生成的提示词,生成 Logo 图案"""
# 如果 logo_prompts 为空,使用 input_text 作为 prompt
prompts = state["logo_prompts"] if state["logo_prompts"] else [state["input_text"]]
logo_img_urls = []
for i in range(state.get("logo_num", 1)):
image_url = await generate_logo_tool.ainvoke({"prompt": prompts[i], "user_id": state.get("user_id", "agent")})
logo_img_urls.append(image_url)
result_text = f"Logo 生成完成,共生成 {len(logo_img_urls)} 张图片:\n"
return {"logo_img_urls": logo_img_urls, "messages": [AIMessage(content=result_text)]}
"""条件分支 判断是否需要生成 prompt"""
def should_generate_prompt(state: LogoState) -> str:
"""条件分支:判断是否需要生成 prompt"""
if state.get("need_prompt_generation", True):
return "gen_prompt"
else:
return "gen_logo"
def build_logo_graph():
"""构建独立的画像收集 Graph"""
workflow = StateGraph(LogoState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("gen_prompt", generate_logo_prompt_node)
workflow.add_node("gen_logo", generate_logo_img_node)
# 添加边
workflow.add_edge(START, "extract_input")
workflow.add_conditional_edges(
"extract_input",
should_generate_prompt,
{
"gen_prompt": "gen_prompt",
"gen_logo": "gen_logo",
},
)
workflow.add_edge("gen_prompt", "gen_logo")
workflow.add_edge("gen_logo", END)
graph = workflow.compile()
return graph
async def main(test_input, user_id="agent", need_prompt_generation=True):
graph = build_logo_graph()
result = await graph.ainvoke(
{
"input_text": test_input,
"user_id": user_id,
"logo_prompts": [] if need_prompt_generation else [test_input],
"need_prompt_generation": need_prompt_generation,
"role": "",
"gender": "",
"style": "",
}
)
return result
if __name__ == "__main__":
# 测试示例 1: 需要 prompt 生成(默认)- 简单关键词输入
test_input = "我想要一个金毛图案"
result = asyncio.run(main(test_input, need_prompt_generation=True))
print("=== 需要 prompt 生成 ===")
print(f"Result: {result}")
# 测试示例 2: 直接使用用户提供的 prompt
user_prompt = "golden retriever"
result = asyncio.run(main(user_prompt, need_prompt_generation=False))
print("\n=== 直接使用 prompt ===")
print(f"Result: {result}")

View File

@@ -0,0 +1,79 @@
import asyncio
import concurrent.futures
import random
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.tools import tool
from PIL import Image
from pydantic import BaseModel, Field
from tritonclient.utils import np_to_triton_dtype
from uuid_utils import uuid7
from app.core.config import settings
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
# 模型配置
GSL_MODEL_URL = f"{settings.B_4_X_4090_SERVICE_HOST}:10041"
GSL_MODEL_NAME = "stable_diffusion_xl_transparent"
# 线程池用于执行同步推理
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
def _generate_logo_sync(prompt: str) -> Image.Image:
"""同步生成 Logo 的内部函数"""
seed = random.randint(0, 2**32 - 1)
grpc_client = grpcclient.InferenceServerClient(url=GSL_MODEL_URL)
# 准备输入
prompts = [prompt]
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
input_text.set_data_from_numpy(text_obj)
negative_prompts = "bad, ugly"
text_obj_neg = np.array(negative_prompts, dtype="object").reshape((-1, 1))
input_text_neg = grpcclient.InferInput("negative_prompt", text_obj_neg.shape, np_to_triton_dtype(text_obj_neg.dtype))
input_text_neg.set_data_from_numpy(text_obj_neg)
seed_input = np.array(seed, dtype="object").reshape((-1, 1))
input_seed = grpcclient.InferInput("seed", seed_input.shape, np_to_triton_dtype(seed_input.dtype))
input_seed.set_data_from_numpy(seed_input)
inputs = [input_text, input_text_neg, input_seed]
# 同步推理
result = grpc_client.infer(model_name=GSL_MODEL_NAME, inputs=inputs)
image = result.as_numpy("generated_image")
return Image.fromarray(np.squeeze(image.astype(np.uint8)))
async def generate_logo(prompt: str) -> Image.Image:
"""异步生成透明背景的 Logo 图片"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(executor, _generate_logo_sync, prompt)
class GenerateLogoToolInput(BaseModel):
"""Input schema for the Generate Logo Tool."""
prompt: str = Field(description="Simple keyword for logo generation, e.g., 'cat', 'flower', 'dog'")
user_id: str = Field(description="User ID for image storage", default="agent")
@tool(args_schema=GenerateLogoToolInput)
async def generate_logo_tool(prompt: str, user_id: str = "agent") -> str:
"""Generate a transparent background logo image based on a simple keyword."""
image = await generate_logo(prompt=prompt)
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.png"
loop = asyncio.get_event_loop()
image_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "logo", file_name)
return [image_url]
if __name__ == "__main__":
result = asyncio.run(generate_logo_tool.ainvoke({"prompt": "golden retriever"}))
print(f"Logo saved to: {result}")

View File

@@ -0,0 +1,27 @@
import httpx
async def generate_image(
bucket_name="fida-public-bucket",
object_name=f"furniture/sketches/123456.png",
prompt="Generate a modern minimalist dining chair made of light "
"oak wood and white leather, with slim metal legs, photographed "
"in a bright Scandinavian living room with natural sunlight, high detail, "
"8k resolution.",
):
request_data = {
"input_image_paths": [],
"prompt": prompt,
"bucket_name": bucket_name,
"object_name": object_name,
"width": 1024,
"height": 1024,
}
async with httpx.AsyncClient(timeout=120) as client:
resp = await client.post(
f"http://20.1.1.33:14202/predict",
json=request_data,
)
result = resp.json()
image_url = result.get("output_path", None)
return image_url

View File

@@ -0,0 +1,72 @@
import asyncio
import concurrent.futures
import io
import logging
import os
import httpx
from dotenv import load_dotenv
from PIL import Image
from uuid_utils import uuid7
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
load_dotenv()
logger = logging.getLogger()
PEXELS_API_KEY = os.environ.get("PEXELS_API_KEY", "")
PEXELS_BASE_URL = os.environ.get("PEXELS_BASE_URL", "")
# 线程池用于执行同步上传
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
async def search_photos(query: str, per_page: int = 4, user_id: str = "agent") -> list[dict]:
"""从 Pexels 搜索图片并上传到 minio
Args:
query: 搜索关键词
per_page: 返回图片数量 (1-80)
user_id: 用户 ID
Returns:
图片信息列表,每项包含 image_url 和 minio_path
"""
# 搜索图片
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(
f"{PEXELS_BASE_URL}/search",
headers={"Authorization": PEXELS_API_KEY},
params={
"query": query,
"per_page": per_page,
"orientation": "square",
"size": "medium",
},
)
if response.status_code != 200:
raise Exception(f"Pexels API error: {response.status_code} - {response.text}")
data = response.json()
photos = data.get("photos", [])
# 下载并上传到 minio
results = []
for photo in photos:
try:
# 下载图片(使用 large 尺寸)
image_url = photo["src"]["original"]
async with httpx.AsyncClient(timeout=60) as dl_client:
dl_response = await dl_client.get(image_url)
image = Image.open(io.BytesIO(dl_response.content))
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.jpg"
loop = asyncio.get_event_loop()
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
results.append({"image_url": image_url, "minio_path": minio_url})
logger.info(f"[Explorer] 上传成功: {minio_url}")
except Exception as e:
logger.error(f"[Explorer] 上传失败: {e}")
return results

View File

@@ -0,0 +1,90 @@
import asyncio
import concurrent.futures
import io
import logging
import os
import httpx
from PIL import Image
from uuid_utils import uuid7
from dotenv import load_dotenv
from app.service.generate_image.utils.upload_sd_image import upload_SDXL_image
load_dotenv()
# Unsplash API 配置
UNSPLASH_ACCESS_KEY = os.environ.get("UNSPLASH_ACCESS_KEY", "")
UNSPLASH_BASE_URL = os.environ.get("UNSPLASH_BASE_URL", "")
logger = logging.getLogger()
# 线程池用于执行同步上传
executor = concurrent.futures.ThreadPoolExecutor(max_workers=4)
async def get_random_photos(query: str, count: int = 4, user_id: str = "agent") -> list[dict]:
"""从 Unsplash 获取随机图片并上传到 minio
Args:
query: 搜索关键词
count: 返回图片数量 (1-30)
user_id: 用户 ID
Returns:
图片信息列表,每项包含 image_url 和 minio_path
"""
# 获取随机图片
async with httpx.AsyncClient(timeout=30) as client:
response = await client.get(
f"{UNSPLASH_BASE_URL}/search/photos",
headers={"Authorization": f"Client-ID {UNSPLASH_ACCESS_KEY}"},
params={
"query": query,
"per_page": count,
"page": 1,
},
)
if response.status_code != 200:
raise Exception(f"Unsplash API error: {response.status_code} - {response.text}")
data = response.json()
# /search/photos 返回 {"results": [...], "total": ...}
photos = data.get("results", [])
# 下载并上传到 minio
results = []
for photo in photos:
try:
# 下载图片
image_url = photo["urls"]["raw"]
async with httpx.AsyncClient(timeout=60) as dl_client:
dl_response = await dl_client.get(image_url)
image = Image.open(io.BytesIO(dl_response.content))
# 上传到 minio使用线程池避免阻塞事件循环
file_name = f"{uuid7()}.jpg"
loop = asyncio.get_event_loop()
minio_url = await loop.run_in_executor(executor, upload_SDXL_image, image, user_id, "explorer", file_name)
results.append({"image_url": image_url, "minio_path": minio_url})
logger.info(f"[Explorer] 上传成功: {minio_url}")
except Exception as e:
logger.error(f"[Explorer] 上传失败: {e}")
return results
if __name__ == "__main__":
import asyncio
async def test():
"""测试 Unsplash 搜索"""
query = "summer dress fresh natural style"
print(f"搜索关键词: {query}")
print("=" * 50)
results = await get_random_photos(query, count=4, user_id="test")
print(f"\n找到 {len(results)} 张图片:")
for i, item in enumerate(results, 1):
print(f" {i}. 原图: {item.get('image_url', '')}")
print(f" Minio: {item.get('minio_path', '')}")
asyncio.run(test())

View File

@@ -0,0 +1,158 @@
import asyncio
import logging
from typing import Annotated, Required, TypedDict
from langchain_qwq import ChatQwen
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from app.service.fashion_agent.init_llm import qwen_plus_llm
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool, test
logger = logging.getLogger()
"""定义状态"""
class PrintState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
role: str = ""
gender: str = ""
style: str = ""
print_need_prompt_generation: bool = False # 是否需要使用 prompt 生成节点
print_num: int = 1
print_prompts: list[str] = []
print_img_urls: list[str] = []
"""生成印花图案的提示词节点"""
# 定义输出结构
class PrintPrompt(BaseModel):
"""生成的印花图像提示词"""
prompts: list[str] = Field(description="用于生成印花图案的详细提示词")
def extract_input_node(state: PrintState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
return {"input_text": input_text}
def generate_print_prompt_node(state: PrintState) -> dict:
"""根据用户输入生成印花图案的图像生成提示词"""
structured_llm = qwen_plus_llm.with_structured_output(PrintPrompt)
messages = [
SystemMessage(content=f"""你是一个专业的印花图案设计师。
请根据用户输入生成用于AI图像生成的印花图案提示词。
要求:
1. 提示词应该详细描述印花图案的样式、元素、颜色、布局
2. 提示词应该适合用于 Stable Diffusion 图像生成模型
3. 提示词应该使用英文,因为图像生成模型对英文理解更好
4. 提示词数量为 {state.get("print_num", 1)}
"""),
HumanMessage(content=state["input_text"]),
]
result = structured_llm.invoke(messages)
prompts = result.prompts
logger.info(f"[Print Graph] Generated print prompts: {prompts}")
return {
"print_prompts": prompts,
}
"""生成印花图案节点"""
async def generate_print_img_node(state: PrintState) -> dict:
"""根据生成的提示词,生成印花图案"""
# 如果 print_prompts 为空,使用 input_text 作为 prompt
if state.get("print_need_prompt_generation", False):
prompts = state["print_prompts"] if state["print_prompts"] else [state["input_text"]]
else:
input_text = state.get("input_text", "")
prompts = [input_text]
print_img_urls = []
for prompt in prompts:
image_url = await generate_print_tool.ainvoke({"prompt": prompt})
print_img_urls.append(image_url)
logger.info(f"[Print Graph] Generated print image URL: {image_url}")
return {"print_img_urls": print_img_urls}
"""条件分支 判断是否需要生成 prompt"""
def should_generate_prompt(state: PrintState) -> str:
"""条件分支:判断是否需要生成 prompt"""
logger.info(
f"[Print Graph] should_generate_prompt: print_need_prompt_generation={state.get('print_need_prompt_generation')}, print_prompts={state.get('print_prompts')}"
)
if state.get("print_need_prompt_generation", True):
return "gen_prompt"
else:
return "gen_print"
def build_print_graph():
workflow = StateGraph(PrintState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("gen_prompt", generate_print_prompt_node)
workflow.add_node("gen_print", generate_print_img_node)
# 添加边
workflow.add_edge(START, "extract_input")
workflow.add_conditional_edges(
"extract_input",
should_generate_prompt,
{
"gen_prompt": "gen_prompt",
"gen_print": "gen_print",
},
)
workflow.add_edge("gen_prompt", "gen_print")
workflow.add_edge("gen_print", END)
graph = workflow.compile()
return graph
async def main(test_input, print_need_prompt_generation=True):
graph = build_print_graph()
result = await graph.ainvoke(
{
"input_text": test_input,
"print_prompts": [] if print_need_prompt_generation else [test_input],
"print_need_prompt_generation": print_need_prompt_generation,
"role": "",
"gender": "",
"style": "",
}
)
return result
if __name__ == "__main__":
# 测试示例 1: 需要 prompt 生成(默认)
test_input = "我想要一个优雅的花卉印花,适合用于连衣裙,颜色以粉色和白色为主"
result = asyncio.run(main(test_input, print_need_prompt_generation=True))
print("=== 需要 prompt 生成 ===")
print(f"Result: {result}")
# 测试示例 2: 直接使用用户提供的 prompt
user_prompt = "Elegant floral print pattern, pink and white colors, suitable for dress fabric, seamless tileable design"
result = asyncio.run(main(user_prompt, print_need_prompt_generation=False))
print("\n=== 直接使用 prompt ===")
print(f"Result: {result}")

View File

@@ -0,0 +1,39 @@
import asyncio
from langchain.tools import tool
from langsmith import uuid7
from pydantic import BaseModel, Field
from app.service.fashion_agent.graph_node.node_tools.generate_image import generate_image
class GenerateImageToolInput(BaseModel):
"""Input schema for the Generate Image Tool."""
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
@tool(args_schema=GenerateImageToolInput)
async def generate_print_tool(prompt: str) -> str:
"""Generate an image based on the provided prompt."""
bucket_name = "aida-users"
object_name = f"agent_generate_print/{uuid7()}.png"
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
return [image_url]
@tool
async def test(text: str):
"""测试工具函数,返回固定字符串"""
return text
async def run_test():
result = await generate_print_tool.ainvoke({"prompt": "A cozy living room with warm lighting and natural textures."})
return result
if __name__ == "__main__":
result = asyncio.run(run_test())
print(result)

View File

@@ -0,0 +1,176 @@
import asyncio
import logging
from typing import Annotated, Required, TypedDict
from langchain_qwq import ChatQwen
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage, AIMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from pydantic import BaseModel, Field
from app.service.fashion_agent.init_llm import qwen_plus_llm
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
logger = logging.getLogger()
"""定义状态"""
class SketchState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
role: str = ""
gender: str = ""
style: str = ""
sketch_need_prompt_generation: bool = False # 是否需要使用 prompt 生成节点
sketch_num: int = 1
sketch_prompts: list[str] = []
sketch_img_urls: list[str] = []
"""生成服装草图的提示词节点"""
# 定义输出结构
class SketchPrompt(BaseModel):
"""生成的印花图像提示词"""
prompts: list[str] = Field(description="用于生成服装草图的详细提示词")
def extract_input_node(state: SketchState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
return {"input_text": input_text}
def generate_sketch_prompt_node(state: SketchState) -> dict:
"""根据用户输入生成服装草图的图像生成提示词"""
structured_llm = qwen_plus_llm.with_structured_output(SketchPrompt)
messages = [
SystemMessage(content=f"""你是一个专业的服装设计师。
请根据用户输入生成用于AI图像生成的服装草图提示词。
要求:
1. 提示词必须包含clean black and white line drawing only, pure white background, centered composition
2. 提示词应该详细描述服装的廓形、结构、细节
3. 提示词应该适合用于 Stable Diffusion 图像生成模型
4. 提示词应该使用英文,因为图像生成模型对英文理解更好
5. 草图风格必须是黑白线稿,不要添加颜色
6. 提示词数量为 {state.get("sketch_num", 1)}
"""),
HumanMessage(content=state["input_text"]),
]
result = structured_llm.invoke(messages)
prompts = result.prompts
return {
"sketch_prompts": prompts,
}
"""生成服装草图节点"""
async def generate_sketch_img_node(state: SketchState) -> dict:
"""根据生成的提示词,生成服装草图"""
# 如果 sketch_need_prompt_generation=False 且 sketch_prompts 为空,使用模板生成 prompt
if not state.get("sketch_need_prompt_generation", False) and not state.get("sketch_prompts"):
input_text = state.get("input_text", "")
prompts = [build_sketch_template_prompt(input_text)]
else:
prompts = state["sketch_prompts"] if state["sketch_prompts"] else [state["input_text"]]
sketch_img_urls = []
for prompt in prompts:
image_url = await generate_sketch_tool.ainvoke({"prompt": prompt})
sketch_img_urls.append(image_url)
return {"sketch_img_urls": sketch_img_urls}
"""条件分支 判断是否需要生成 prompt"""
def should_generate_prompt(state: SketchState) -> str:
"""条件分支:判断是否需要生成 prompt"""
if state.get("sketch_need_prompt_generation", False):
return "gen_prompt"
else:
return "gen_sketch"
def build_sketch_graph():
workflow = StateGraph(SketchState)
workflow.add_node("gen_sketch", generate_sketch_img_node)
workflow.add_edge(START, "gen_sketch")
workflow.add_edge("gen_sketch", END)
graph = workflow.compile()
return graph
# workflow = StateGraph(SketchState)
# workflow.add_node("extract_input", extract_input_node)
# workflow.add_node("gen_prompt", generate_sketch_prompt_node)
# workflow.add_node("gen_sketch", generate_sketch_img_node)
# # 添加边
# workflow.add_edge(START, "extract_input")
# workflow.add_conditional_edges(
# "extract_input",
# should_generate_prompt,
# {
# "gen_prompt": "gen_prompt",
# "gen_sketch": "gen_sketch",
# },
# )
# workflow.add_edge("gen_prompt", "gen_sketch")
# workflow.add_edge("gen_sketch", END)
# graph = workflow.compile()
# return graph
def build_sketch_template_prompt(input_text: str) -> str:
"""构建 sketch prompt 模板"""
return f"{input_text}, clean black and white line drawing only, pure white background, centered composition, fashion sketch style"
async def main(test_input, sketch_need_prompt_generation=False):
graph = build_sketch_graph()
# 如果不需要 LLM 生成 prompt使用模板
if not sketch_need_prompt_generation:
sketch_prompts = [build_sketch_template_prompt(test_input)]
else:
sketch_prompts = []
result = await graph.ainvoke(
{
"input_text": test_input,
"sketch_prompts": sketch_prompts,
"sketch_need_prompt_generation": sketch_need_prompt_generation,
"role": "",
"gender": "",
"style": "",
}
)
return result
if __name__ == "__main__":
# 测试示例 1: 直接使用模板 prompt默认
test_input = "dress"
result = asyncio.run(main(test_input, sketch_need_prompt_generation=False))
print("=== 使用模板 prompt ===")
print(f"Result: {result}")
# # 测试示例 2: 使用 LLM 生成 prompt
# test_input = "设计一条优雅的A字廓形连衣裙V领设计收腰裙摆到膝盖适合日常穿着"
# result = asyncio.run(main(test_input, sketch_need_prompt_generation=True))
# print("\n=== 使用 LLM 生成 prompt ===")
# print(f"Result: {result}")

View File

@@ -0,0 +1,33 @@
import asyncio
from langchain.tools import tool
from langsmith import uuid7
from pydantic import BaseModel, Field
from app.service.fashion_agent.graph_node.node_tools.generate_image import generate_image
class GenerateImageToolInput(BaseModel):
"""Input schema for the Generate Image Tool."""
prompt: str = Field(description="Description of the desired image, e.g., 'A cozy living room with warm lighting and natural textures.'")
@tool(args_schema=GenerateImageToolInput)
async def generate_sketch_tool(prompt: str) -> str:
"""Generate an image based on the provided prompt."""
bucket_name = "fida-public-bucket"
object_name = f"test/{uuid7()}.png"
image_url = await generate_image(prompt=prompt, bucket_name=bucket_name, object_name=object_name)
return [image_url]
async def run_test():
result = await generate_sketch_tool.ainvoke({"prompt": "A cozy living room with warm lighting and natural textures."})
return result
if __name__ == "__main__":
result = asyncio.run(run_test())
print(result)

View File

@@ -0,0 +1,69 @@
import asyncio
from typing import Annotated, Required, TypedDict
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
"""定义状态"""
class TrendingState(TypedDict):
messages: Required[Annotated[list[AnyMessage], add_messages]]
input_text: str
"""节点"""
def extract_input_node(state: TrendingState) -> dict:
"""从 messages 中提取用户输入"""
input_text = state["messages"][0].content if state.get("messages") else ""
return {"input_text": input_text}
async def trending_node(state: TrendingState) -> dict:
"""趋势分析节点(占位)"""
input_text = state.get("input_text", "")
# TODO: 接入真实的趋势分析逻辑
result_text = (
f"【趋势分析】\n基于您的输入「{input_text}」,以下是当前服装设计趋势:\n\n"
"1. 极简主义持续流行,黑白灰为主色调\n"
"2. 可持续时尚成为主流,环保面料受青睐\n"
"3. 复古风格回潮90年代元素重新流行\n"
"4. 功能性与美学结合,运动休闲风持续升温"
)
return {"messages": [AIMessage(content=result_text)]}
"""构建图"""
def build_trending_graph():
"""构建趋势分析图"""
workflow = StateGraph(TrendingState)
workflow.add_node("extract_input", extract_input_node)
workflow.add_node("trending", trending_node)
workflow.add_edge(START, "extract_input")
workflow.add_edge("extract_input", "trending")
workflow.add_edge("trending", END)
return workflow.compile()
if __name__ == "__main__":
async def test():
graph = build_trending_graph()
result = await graph.ainvoke(
{
"messages": [HumanMessage(content="女装连衣裙")],
}
)
print(result["messages"][-1].content)
asyncio.run(test())

View File

@@ -0,0 +1,33 @@
import os
from dotenv import load_dotenv
from langchain_qwq import ChatQwen
load_dotenv()
QWEN_API_KEY_INTL = os.environ.get("QWEN_API_KEY_INTL", "")
def build_llm(enable_thinking: bool = False):
llm = ChatQwen(
model="qwen3.6-plus",
timeout=None,
max_retries=2,
enable_thinking=enable_thinking,
streaming=True,
api_key=QWEN_API_KEY_INTL,
)
return llm
qwen_plus_llm = ChatQwen(
model="qwen-plus",
timeout=None,
max_retries=2,
streaming=False,
temperature=0.25,
top_p=0.8,
api_key=QWEN_API_KEY_INTL,
)
# response = qwen_plus_llm.invoke("你好")
# print(response)

View File

@@ -0,0 +1,144 @@
import sys
from pathlib import Path
from typing import Annotated, Required, TypedDict
from deepagents import CompiledSubAgent, create_deep_agent
from langchain.agents import create_agent
from langchain.tools import tool
from langchain_core.messages import AnyMessage, HumanMessage
from langchain_qwq import ChatQwen
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from app.service.fashion_agent.graph_node.design_graph.graph import build_design_graph
from app.service.fashion_agent.graph_node.design_graph.tools import design_tool
from app.service.fashion_agent.graph_node.explorer_graph.tools import explor_tool
from app.service.fashion_agent.graph_node.logo_graph.graph import build_logo_graph
from app.service.fashion_agent.graph_node.logo_graph.tools import generate_logo_tool
from app.service.fashion_agent.graph_node.print_graph.graph import build_print_graph
from app.service.fashion_agent.graph_node.print_graph.tools import generate_print_tool
from app.service.fashion_agent.graph_node.sketch_graph.graph import build_sketch_graph
from app.service.fashion_agent.graph_node.sketch_graph.tools import generate_sketch_tool
from app.service.fashion_agent.graph_node.trending_graph.trending_graph import build_trending_graph
from app.service.fashion_agent.graph_node.explorer_graph.graph import build_explorer_graph
from app.service.fashion_agent.init_llm import build_llm
print_graph = build_print_graph()
logo_graph = build_logo_graph()
sketch_graph = build_sketch_graph()
design_graph = build_design_graph()
trending_graph = build_trending_graph()
explorer_graph = build_explorer_graph()
class MainState(TypedDict):
# 消息
messages: Required[Annotated[list[AnyMessage], add_messages]]
# 模块控制
call_design: bool = False
call_print: bool = False
call_logo: bool = False
call_sketch: bool = False
call_design: bool = False
call_trending: bool = False
call_explor: bool = False
# design参数
design_request_data: dict = {}
# 模块需求标志
print_need_prompt_generation: bool = False
sketch_need_prompt_generation: bool = False
# 公共参数
role: str = ""
gender: str = ""
style: str = ""
# print模块结果
print_img_urls: list[str] = []
tools = [explor_tool, generate_logo_tool, generate_print_tool, generate_sketch_tool]
def route_node(state: MainState) -> str:
"""根据标志决定走哪条路径"""
if state.get("call_print"):
return "direct_print"
if state.get("call_logo"):
return "direct_logo"
if state.get("call_sketch"):
return "direct_sketch"
if state.get("call_design"):
return "direct_design"
if state.get("call_trending"):
return "direct_trending"
if state.get("call_explor"):
return "direct_explor"
return "llm_agent"
def build_main_graph(enable_thinking: bool = False):
llm = build_llm(enable_thinking=enable_thinking)
chat_agent = create_agent(
model=llm, tools=tools, state_schema=MainState, system_prompt="你是一个专业的服装设计助手。根据用户需求,调用合适的工具完成任务."
)
"""构建主图"""
workflow = StateGraph(MainState)
# 添加节点
workflow.add_node("llm_agent", chat_agent)
workflow.add_node("direct_print", print_graph)
workflow.add_node("direct_logo", logo_graph)
workflow.add_node("direct_sketch", sketch_graph)
workflow.add_node("direct_design", design_graph)
workflow.add_node("direct_trending", trending_graph)
workflow.add_node("direct_explor", explorer_graph)
# 条件分支
workflow.add_conditional_edges(
START,
route_node,
{
"llm_agent": "llm_agent",
"direct_print": "direct_print",
"direct_logo": "direct_logo",
"direct_sketch": "direct_sketch",
"direct_design": "direct_design",
"direct_trending": "direct_trending",
"direct_explor": "direct_explor",
},
)
# 所有路径都到 END
workflow.add_edge("llm_agent", END)
workflow.add_edge("direct_print", END)
workflow.add_edge("direct_logo", END)
workflow.add_edge("direct_sketch", END)
workflow.add_edge("direct_design", END)
workflow.add_edge("direct_trending", END)
workflow.add_edge("direct_explor", END)
return workflow.compile()
agent = build_main_graph()
if __name__ == "__main__":
import asyncio
async def test_direct():
# 直接调用 sketch跳过 LLM
result = await agent.ainvoke(
{
"messages": [HumanMessage(content="我想设计衬衫,带有猫咪的图案")],
"call_sketch": True,
"sketch_need_prompt_generation": False,
}
)
print("=== 直接调用 sketch ===")
print(result["messages"][-1].content)
asyncio.run(test_direct())

View File

@@ -0,0 +1,134 @@
import json
import logging
import sys
from pathlib import Path
from langgraph.stream import ProtocolEvent, StreamChannel, StreamTransformer
from app.service.fashion_agent.main_agent import build_main_graph
from langgraph.prebuilt import ToolCallTransformer
from typing import AsyncGenerator, TypedDict
from langchain_core.messages import HumanMessage, ToolMessage
from app.schemas.fashion_agent import FashionAgentRequest
logger = logging.getLogger()
class CustomTransformer(StreamTransformer):
required_stream_modes = ("custom",)
def __init__(self, scope: tuple[str, ...] = ()) -> None:
super().__init__(scope)
self.log = StreamChannel()
def init(self) -> dict:
return {"custom": self.log}
def process(self, event: ProtocolEvent) -> bool:
if event["method"] == "custom":
self.log.push(event["params"]["data"])
return True
class FashionAgentService:
async def run_stream(self, request: FashionAgentRequest) -> AsyncGenerator[str, None]:
"""流式运行 agent - 使用 v3 projections"""
config = {"configurable": {"user_id": request.user_id}}
agent = build_main_graph(enable_thinking=request.enable_thinking)
state = {
"messages": [HumanMessage(content=request.message)],
"call_print": request.call_print,
"call_logo": request.call_logo,
"call_sketch": request.call_sketch,
"call_design": request.call_design,
"call_trending": request.call_trending,
"call_explor": request.call_explor,
"print_need_prompt_generation": request.print_need_prompt_generation,
"sketch_need_prompt_generation": request.sketch_need_prompt_generation,
"design_request_data": request.design_request_data,
}
stream = await agent.astream_events(state, config=config, version="v3", transformers=[ToolCallTransformer, CustomTransformer])
tool_names = {}
filter_tool_name = ["design_tool"]
async for event in stream:
if event["method"] == "tools":
data = event["params"]["data"]
tool_call_id = data.get("tool_call_id")
# 记录 tool_name
if data.get("event") == "tool-started":
tool_names[tool_call_id] = data.get("tool_name")
# 通过 ID 查找 tool_name
elif data.get("event") == "tool-finished":
tool_name = tool_names.get(tool_call_id, "unknown")
if tool_name in filter_tool_name:
continue
data["tool_name"] = tool_name
if isinstance(data["output"], ToolMessage):
data["output"] = json.loads(data["output"].content)
response_event = {"event_type": "tool", "data": data}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif event["method"] == "custom":
data = event["params"]["data"]
response_event = {"event_type": "tool", "data": data}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif event["method"] == "messages":
data = event["params"]["data"][0]
if not isinstance(data, dict):
continue
if data.get("event") != "content-block-delta":
continue
block = data.get("delta") or {}
if block.get("type") == "text-delta":
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
elif block.get("type") == "reasoning-delta":
response_event = {"event_type": "messages", "data": {"event": data["event"]} | block}
yield f"data: {json.dumps(response_event, ensure_ascii=False)}\n\n"
else:
pass
# print(f"----------------{event}")
response_event = {"event_type": "done"}
yield f"data: {response_event}"
if __name__ == "__main__":
import asyncio
async def test_stream():
"""测试流式调用"""
with open("app/service/fashion_agent/graph_node/design_graph/design_request_data.json", encoding="utf-8") as f:
request_data = json.load(f)
service = FashionAgentService()
print("=" * 50)
print("测试流式输出")
print("=" * 50)
request = FashionAgentRequest(
message="生成一张草莓图案",
call_print=True,
# print_need_prompt_generation=False,
# call_sketch=True,
# sketch_need_prompt_generation=False,
# call_logo=True,
# call_explor=True,
# call_design=True,
# design_request_data=request_data,
)
async for chunk in service.run_stream(request):
print(chunk, end="")
# 运行测试
asyncio.run(test_stream())

View File

@@ -11,7 +11,6 @@ import logging
import uuid
import cv2
import mmcv
import numpy as np
import pandas as pd
import torch
@@ -21,6 +20,7 @@ from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import settings, FAST_GI_MODEL_URL, GI_MODEL_URL, DESIGN_MODEL_URL, FAST_GI_MODEL_NAME, GI_MODEL_NAME
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_upload_image
logger = logging.getLogger()
@@ -86,10 +86,9 @@ class AgentToolGenerateImage:
@staticmethod
def preprocess(img):
img = mmcv.imread(img)
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(
img = my_imnormalize(
img,
mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]),
to_rgb=True)

View File

@@ -189,10 +189,10 @@ if __name__ == '__main__':
tasks_id="123-89",
prompt="a single item of sketch of dress, 4k, white background",
image_url="aida-collection-element/89/Sketchboard/95f20cdc-e059-435c-b8b1-d04cc9e80c3d.png",
mode='img2img',
mode='txt2img',
category="sketch",
gender="Female",
version="fast"
version="hight"
)
server = GenerateImage(rd)
print(server.get_result())

View File

@@ -2,23 +2,23 @@ import logging
import time
import cv2
import mmcv
import numpy as np
import torch
import tritonclient.http as httpclient
from app.core.config import settings, DESIGN_MODEL_URL, DESIGN_MODEL_NAME
from app.service.generate_image.utils.upload_sd_image import upload_stain_png_sd, upload_face_png_sd
from app.service.utils.image_normalize import my_imnormalize
logger = logging.getLogger()
def seg_preprocess(img_path):
img = mmcv.imread(img_path)
img = img_path
ori_shape = img.shape[:2]
img_scale = ori_shape
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape
@@ -242,10 +242,9 @@ def stain_detection(image, user_id, category, tasks_id, spot_size=100):
def generate_category_recognition(image, gender):
def preprocess(img):
img = mmcv.imread(img)
img_scale = (224, 224)
img = cv2.resize(img, img_scale)
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img

View File

@@ -1,7 +1,6 @@
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
@@ -10,6 +9,7 @@ from minio import Minio
from app.core.config import settings
from app.core.config import DESIGN_MODEL_URL
from app.schemas.image2sketch import Image2SketchModel
from app.service.utils.image_normalize import my_imnormalize
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
logger = logging.getLogger()
@@ -67,7 +67,7 @@ class LineArtService:
@staticmethod
def line_art_preprocess(image):
img = mmcv.imread(image)
img = image
ori_shape = img.shape[:2]
img_scale_w, img_scale_h = ori_shape
if ori_shape[0] > 1024:
@@ -76,9 +76,9 @@ class LineArtService:
img_scale_h = 1024
# 如果图片size任意一边 大于 1024 则会resize 成1024
if ori_shape != (img_scale_w, img_scale_h):
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
# my_imnormalize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
img = cv2.resize(img, (img_scale_h, img_scale_w))
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
img = my_imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img, ori_shape

View File

@@ -90,7 +90,7 @@ def get_response(messages):
def get_translation_from_llama3(text):
start_time = time.time()
url = "http://10.1.1.240:11434/api/generate"
url = f"http://{settings.A6000_SERVICE_HOST}:12434/api/generate"
# url = "http://10.1.1.240:1143/api/generate"
# prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
@@ -103,8 +103,8 @@ def get_translation_from_llama3(text):
# 创建请求的负载 translator是自定义的翻译模型
payload = {
"model": "translator",
"prompt": f"[{text}]",
"model": "AiDA-translator:latest",
"prompt": text,
"stream": False
}
# 将负载转换为 JSON 格式
@@ -148,7 +148,7 @@ def get_translation_from_llama3(text):
def get_prompt_from_image(image_path, text):
start_time = time.time()
# url = "http://localhost:11434/api/generate"
url = "http://10.1.1.243:11434/api/generate"
url = f"http://{settings.B_4_X_4090_SERVICE_HOST}:11434/api/generate"
image_base64 = minio_util.minio_url_to_base64(image_path.img)
# image_base64 = minio_url_to_base64(image_path)
@@ -180,7 +180,7 @@ def get_prompt_from_image(image_path, text):
def main():
"""Main function"""
text = get_translation_from_llama3("[火焰]")
text = get_translation_from_llama3("火焰")
print(text)

View File

@@ -0,0 +1,170 @@
import json
import logging
import time
import requests
from dashscope import Generation
from requests import RequestException
from retry import retry
from app.core.config import settings
from app.service.chat_robot.script.prompt import GET_LANGUAGE_PREFIX
from app.service.prompt_generation.util import minio_util
logger = logging.getLogger(__name__)
def get_assistant_response(messages):
response = Generation.call(
model='qwen-max',
api_key=settings.QWEN_API_KEY,
messages=messages,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
result_format='message', # 将输出设置为message形式
enable_search='false'
)
return response
def get_language(message: str) -> str:
messages = [
{
"content": GET_LANGUAGE_PREFIX, # ai message
"role": "system"
},
{
"content": "Tree", # 用户message
"role": "user"
},
{
"content": "English", # 用户message
"role": "assistant"
},
{
"content": "玩具", # 用户message
"role": "user"
},
{
"content": "Chinese", # 用户message
"role": "assistant"
},
{
"content": message, # 用户message
"role": "user"
}
]
first_response = get_assistant_response(messages)
assistant_output = first_response.output.choices[0].message.content
logging.info(f"大模型输出信息:{first_response}\n判断用户输入的语言为:{assistant_output}")
# print(f"大模型输出信息:{first_response}\n判断用户输入的语言为{assistant_output}")
return assistant_output
@retry(exceptions=RequestException, tries=3, delay=1)
def get_response(messages):
response = Generation.call(
model='qwen-turbo',
api_key=settings.QWEN_API_KEY,
messages=messages,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
result_format='message', # 将输出设置为message形式
enable_search='True'
)
return response
def get_translation_from_llama3(text):
start_time = time.time()
url = f"http://{settings.A6000_SERVICE_HOST}:12434/api/generate"
# 先获取用户输入文本的语言
language = get_language(text)
if 'English' in language:
return text
# 创建请求的负载 translator是自定义的翻译模型
payload = {
"model": "AiDA-translator:latest",
"prompt": f"[{text}]",
"stream": False
}
# 将负载转换为 JSON 格式
headers = {'Content-Type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
# 处理响应
if response.status_code == 200:
# print("Response from server:")
# print(response.json())
resp = json.loads(response.content).get("response")
logger.info(f"translation server runtime is {time.time() - start_time} , response is {resp}")
print("input : {}, translate result : {}".format(text, resp))
return resp
else:
logger.info(f"translation server runtime is {time.time() - start_time} , response is {response.content}")
print(f"Request failed with status code {response.status_code}")
print(response.text)
return ""
# 在llama3中创建一个翻译模型
# def create_model_with_llama(text):
# url = "http://localhost:11434/api/create"
# # url = "http://20.1.1.43:1143/api/generate"
#
# # prompt = f"System: {prefix_for_llama}\nUser:[{text}]"
#
# # 创建翻译器的配置文件
# payload = {
# "model": "translator",
# "modelfile": "FROM llama3\nSYSTEM Translate everything within the brackets [] into English."
# "Never translate or modify any English input."
# "The input must be fully translated into coherent English sentences."
# }
#
# # 将负载转换为 JSON 格式
# headers = {'Content-Type': 'application/json'}
# response = requests.post(url, data=json.dumps(payload), headers=headers)
def get_prompt_from_image(image_path, text):
start_time = time.time()
# url = "http://localhost:11434/api/generate"
url = f"http://{settings.B_4_X_4090_SERVICE_HOST}:11434/api/generate"
image_base64 = minio_util.minio_url_to_base64(image_path.img)
# image_base64 = minio_url_to_base64(image_path)
# 创建请求的负载 translator是自定义的翻译模型
payload = {
"model": "llama3.2-vision",
"images": [image_base64],
"prompt": f"{text}",
"stream": False
}
# 将负载转换为 JSON 格式
headers = {'Content-Type': 'application/json'}
response = requests.post(url, data=json.dumps(payload), headers=headers)
# 处理响应
if response.status_code == 200:
# print("Response from server:")
# print(response.json())
resp = json.loads(response.content).get("response")
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} \n, response is {resp}")
# print("input : {}, sketch re-generate result : {}".format(text, resp))
return resp
else:
logger.info(f"sketch re-generate server runtime is {time.time() - start_time} , response is {response.content}")
print(f"Request failed with status code {response.status_code}")
print(response.text)
return ""
def main():
"""Main function"""
text = get_translation_from_llama3("[火焰]")
print(text)
if __name__ == "__main__":
main()

View File

@@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref"
RECOMMENDATION_CONFIG = {
# 时间衰减半衰期(用于计算时间衰减权重)
# 值越小,最近的行为权重越大
"K_half": 20,
"K_half": 10,
# 探索与利用的比例 (0.0-1.0)
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
@@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = {
# 向量检索返回的候选数量
# 值越大,候选池越大,但计算成本也越高
# 建议范围: 100-1000
"topk": 1000,
"topk": 200,
# Style 加分系数(同 style 的候选进行加分)
# 值越大,匹配 style 的候选被选中的概率越大
@@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = {
}
# 数据库表名
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test"
TABLE_USER_PREFERENCE_LOG = "user_preference"
TABLE_SYS_FILE = "t_sys_file"
# MySQL 连接配置(用于推荐系统)

View File

@@ -1,6 +1,6 @@
"""
增量监听模块
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量
实时监听 user_preference 表的新增记录,更新用户偏好向量
"""
import logging
import math
@@ -48,7 +48,7 @@ class IncrementalListener:
if self.last_process_time is None:
# 第一次运行查询最近30分钟的数据
cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
ORDER BY data_time
@@ -56,7 +56,7 @@ class IncrementalListener:
else:
# 基于上次处理时间查询
cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > %s
ORDER BY data_time
@@ -258,7 +258,7 @@ class IncrementalListener:
}
else:
# 用户图
# 从 user_preference_log_test 获取 category如果有
# 从 user_preference 获取 category如果有
cursor.execute(f"""
SELECT category
FROM {TABLE_USER_PREFERENCE_LOG}
@@ -308,6 +308,10 @@ class IncrementalListener:
def start_background_listener(scheduler: BackgroundScheduler):
"""将增量监听任务注册到后台调度器"""
# 降低 apscheduler 的日志级别,避免大量刷屏
logging.getLogger('apscheduler.executors.default').setLevel(logging.WARNING)
logging.getLogger('apscheduler.scheduler').setLevel(logging.WARNING)
listener = IncrementalListener()
scheduler.add_job(
listener.process_once,

View File

@@ -203,39 +203,74 @@ def search_similar_vectors(
query_vector: np.ndarray,
category: str,
topk: int = 500,
style: Optional[str] = None
style: Optional[str] = None,
style_boost_ratio: float = 0.2
) -> List[Dict]:
"""
向量相似度检索
Args:
query_vector: 查询向量2048维
category: 类别过滤
topk: 返回数量
style: 风格过滤(可选)
style: 风格过滤(可选)- 当提供时会给对应style的结果加分
style_boost_ratio: 风格加分比例默认0.1即10%
Returns:
检索结果列表,每个元素包含 path, score, style, category 等字段
"""
client = get_milvus_client()
try:
# 构建过滤表达式
# 使用 filter 参数而不是 expr根据 pymilvus MilvusClient API
filter_expr = f"category == '{category}' && deprecated == 0"
if style:
filter_expr += f" && style == '{style}'"
# 如果没有指定style使用原始逻辑
if not style:
filter_expr = f"category == '{category}' && deprecated == 0"
results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
else:
# 有style参数时使用两阶段搜索策略
# 搜索
results = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第一阶段搜索匹配style的向量使用boosted query vector
filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
boosted_query = query_vector * (1 + style_boost_ratio)
results_style = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[boosted_query.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_style,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第二阶段搜索其他style的向量
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
results_others = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_others,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 合并结果
results = []
if results_style and len(results_style) > 0:
results.extend(results_style[0])
if results_others and len(results_others) > 0:
results.extend(results_others[0])
# 转换为单个结果列表格式
results = [results] if results else []
# 格式化结果
formatted_results = []
@@ -249,7 +284,10 @@ def search_similar_vectors(
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
})
return formatted_results
# 按分数排序并返回topk
formatted_results.sort(key=lambda x: x["score"], reverse=True)
return formatted_results[:topk]
except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True)
return []
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr,
output_fields=["path", "style", "category"],
limit=10000 # 先查询大量数据,然后随机选择
limit=10000
)
# 随机选择

View File

@@ -6,6 +6,7 @@ import logging
import math
import pymysql
import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
def optimize_database_table():
"""
优化 user_preference_log_test 表结构
优化 user_preference 表结构
添加冗余字段和索引
"""
conn = None
@@ -317,8 +318,8 @@ def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int =
def compute_user_preference_vector(
account_id: int,
category: str,
conn: Optional[pymysql.connections.Connection] = None
# max_date: Optional[datetime] = None
conn: Optional[pymysql.connections.Connection] = None,
max_date: Optional[datetime] = None
) -> Optional[np.ndarray]:
"""
计算用户偏好向量
@@ -419,8 +420,8 @@ def compute_user_preference_vector(
p_i = 1 + math.log(1 + like_count)
# 综合权重
# w_i = d_k * p_i
w_i = p_i
w_i = d_k * p_i
# w_i = p_i
vectors.append(feature_vector)
weights.append(w_i)
@@ -518,16 +519,16 @@ def run_precompute():
logger.info("=" * 50)
# 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table()
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
# # 2. 创建 Milvus 集合
# logger.info("\n[2/5] 创建 Milvus 集合...")
# create_collection()
# 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data()
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
# # 4. 系统图向量预计算
# logger.info("\n[4/5] 系统图向量预计算...")
@@ -543,13 +544,13 @@ def run_precompute():
if __name__ == "__main__":
# 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table()
# 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data()
# # 1. 优化数据库表结构
# logger.info("\n[1/5] 优化数据库表结构...")
# optimize_database_table()
#
# # 3. 历史数据迁移
# logger.info("\n[3/5] 历史数据迁移...")
# migrate_historical_data()
# 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...")

View File

@@ -0,0 +1,35 @@
import logging
import httpx
logger = logging.getLogger("app")
async def notify_callback(callback_url: str, task_id: str, status: str, result: dict, ):
"""
调用客户端提供的回调接口
"""
try:
payload = {
"task_id": task_id,
"status": status,
"result": result
}
logger.info(payload)
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
str(callback_url),
json=payload,
headers={"Content-Type": "application/json"}
)
if 200 <= resp.status_code < 300:
logger.info(f"回调成功 | task_id: {task_id} | status: {status} | url: {callback_url}")
return True
else:
logger.warning(f"回调返回非2xx状态码 | task_id: {task_id} | status: {resp.status_code} | url: {callback_url}")
return False
except Exception as e:
logger.error(f"回调失败 | task_id: {task_id} | url: {callback_url} | error: {e}", exc_info=True)
return False

View File

@@ -0,0 +1,46 @@
from celery import Celery
from kombu import Queue, Exchange
from app.core.config import settings
celery_app = Celery(
"sketch_to_garment",
broker=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/2",
backend=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}",
include=["app.service.sketch2garment.tasks"]
)
print(f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/3")
print(f"celery_app: {celery_app}")
celery_app.conf.update(
task_serializer="json",
accept_content=["json"],
result_serializer="json",
timezone="Asia/Hong_Kong",
enable_utc=True,
task_track_started=True,
task_time_limit=300, # 单个任务最长 5 分钟
task_soft_time_limit=280,
# 定义队列
task_queues=(
Queue("sketch_to_garment_queue",
exchange=Exchange("sketch_to_garment_exchange", type="direct"),
durable=True),
),
task_routes={
'app.service.sketch2garment.tasks.sketch_to_garment':
{
'queue': 'sketch_to_garment_queue',
'exchange': 'sketch_to_garment_exchange', # ← 修改这里
},
},
task_default_queue="sketch_to_garment_queue",
worker_concurrency=1,
worker_prefetch_multiplier=1,
worker_max_tasks_per_child=1,
task_acks_late=True,
task_reject_on_worker_lost=True,
)

View File

@@ -0,0 +1,44 @@
import logging
from app.service.sketch2garment.tasks import sketch_to_garment
logger = logging.getLogger(__name__)
def submit_sketch_to_garment_task(model: str = "single", task_id: str = "", callback_url: str = "", bucket_name: str = "test", user_id: str = "123", input_image_path: str = ""):
"""提交 img_to_3D 任务(带队列长度限制)"""
queue_name = "img_to_3d_queue"
max_queue_length = 10
try:
# current_length = get_queue_length(queue_name)
# if current_length >= max_queue_length:
# return {
# "state": "queue_full",
# "message": "当前 3D 生成请求较多,请稍后重试。",
# "queue_length": current_length,
# "max_length": max_queue_length
# }
# 提交任务
task = sketch_to_garment.apply_async(
args=(task_id, callback_url, bucket_name, input_image_path, user_id, model),
task_id=task_id,
queue="sketch_to_garment_queue")
# logger.info(f"img_to_3d_task 已提交 | task_id: {task_id} | 当前队列长度: {current_length}")
return {
"state": "success",
"task_id": task_id,
"message": "任务已成功提交,正在后台处理...",
}
except Exception as e:
logger.error(f"提交 img_to_3d_task 失败: {e}", exc_info=True)
return {
"state": "fail",
"message": "提交失败,请稍后重试。",
"error": str(e)
}

View File

@@ -0,0 +1,57 @@
import asyncio
import logging
from app.core.config import settings
from app.service.sketch2garment.callback import notify_callback
import httpx
from app.service.sketch2garment.celery_app import celery_app
logger = logging.getLogger(__name__)
@celery_app.task(bind=True, queue="sketch_to_garment_queue", max_retries=3, name='app.service.sketch2garment.tasks.sketch_to_garment')
def sketch_to_garment(self, task_id: str, callback_url: str, bucket_name: str, input_image_path: str, user_id: str, category: str = None):
payload = {
"bucket_name": bucket_name,
"category": category or settings.DEFAULT_CATEGORY,
"input_image_path": input_image_path,
"user_id": user_id
}
logger.info(f"payload: {payload}")
try:
with httpx.Client(timeout=300.0) as client: # 注意这里用 AsyncClient 配合 Celery
# 如果你的 LitServe 是同步 endpoint也可以用 httpx.Client()
response = client.post(settings.SKETCH_TO_GARMENT_URL, json=payload)
if response.status_code == 200:
result = response.json()
result_json = {
"pattern": result[1],
"texture": result[2],
"glb": result[3],
"texture_fabric": result[4]
}
asyncio.run(
notify_callback(callback_url=callback_url, task_id=task_id, result=result_json, status="success")
)
else:
asyncio.run(
notify_callback(
callback_url=callback_url,
task_id=task_id,
result={
"status": "fail",
"task_id": task_id,
"message": "fail",
"error": "fail"
},
status="fail")
)
except Exception as e:
return {
"status": "failed",
"task_id": task_id,
"input": payload,
"error": str(e)
}

View File

@@ -0,0 +1,27 @@
import cv2
import numpy as np
def my_imnormalize(img, mean, std, to_rgb=True):
"""Inplace normalize an image with mean and std.
Args:
img (ndarray): Image to be normalized.
mean (ndarray): The mean to be used for normalize.
std (ndarray): The std to be used for normalize.
to_rgb (bool): Whether to convert to rgb.
Returns:
ndarray: The normalized image.
"""
# cv2 inplace normalization does not accept uint8
img = img.copy().astype(np.float32)
assert img.dtype != np.uint8
mean = np.float64(mean.reshape(1, -1))
stdinv = 1 / np.float64(std.reshape(1, -1))
if to_rgb:
cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace
cv2.subtract(img, mean, img) # inplace
cv2.multiply(img, stdinv, img) # inplace
return img

View File

@@ -11,7 +11,12 @@ from minio import Minio
from app.core.config import settings
from app.service.utils.decorator import RunTime
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
minio_client = Minio(
settings.MINIO_URL,
access_key=settings.MINIO_ACCESS,
secret_key=settings.MINIO_SECRET,
secure=settings.MINIO_SECURE,
)
# 自定义 Retry 类
@@ -30,7 +35,7 @@ http_client = urllib3.PoolManager(
num_pools=10, # 设置连接池大小
maxsize=10,
timeout=timeout,
cert_reqs='CERT_REQUIRED', # 需要证书验证
cert_reqs="CERT_REQUIRED", # 需要证书验证
retries=CustomRetry(
total=5,
backoff_factor=0.2,
@@ -51,7 +56,7 @@ def oss_get_image(oss_client, bucket, object_name, data_type):
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
if image_object.dtype == np.uint16:
image_object = (image_object / 256).astype('uint8')
image_object = (image_object / 256).astype("uint8")
else:
data_bytes = BytesIO(image_data.read())
image_object = Image.open(data_bytes)
@@ -63,13 +68,19 @@ def oss_get_image(oss_client, bucket, object_name, data_type):
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
req = None
try:
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
req = oss_client.put_object(
bucket_name=bucket,
object_name=object_name,
data=io.BytesIO(image_bytes),
length=len(image_bytes),
content_type="image/png",
)
except Exception as e:
logger.warning(f" | 上传图片出现异常 ######: {e}")
return req
if __name__ == '__main__':
if __name__ == "__main__":
# url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
# url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
# url = "aida-sys-image/images/female/outwear/0628000054.jpg"
@@ -81,16 +92,26 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png"
url = "lanecarford/lc_stylist_agent_outfit_items/141/ee25ec85-d504-4b42-9a18-db6682fe9e3b-6.jpg"
url = "aida-users/agent/logo/019e91c4-7b1f-74e2-b716-d652713101cd.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "2"
if read_type == "cv2":
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
img = oss_get_image(
oss_client=minio_client,
bucket=url.split("/")[0],
object_name=url[url.find("/") + 1 :],
data_type=read_type,
)
cv2.imshow("", img)
cv2.waitKey(0)
else:
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
img = oss_get_image(
oss_client=minio_client,
bucket=url.split("/")[0],
object_name=url[url.find("/") + 1 :],
data_type=read_type,
)
draw = ImageDraw.Draw(img)
# 获取图片尺寸
width, height = img.size
@@ -103,7 +124,7 @@ if __name__ == '__main__':
draw.line(
[(center_x, 0), (center_x, height)], # 从顶部到底部的垂直线
fill=(255, 0, 0), # 红色 (R, G, B)
width=2 # 线宽
width=2, # 线宽
)
img.show()

View File

@@ -91,6 +91,21 @@ class Redis(object):
r = cls._get_r()
r.expire(name, expire_in_seconds)
@classmethod
def scan_keys(cls, pattern="*"):
"""
扫描匹配模式的key
"""
r = cls._get_r()
keys = []
cursor = 0
while True:
cursor, partial_keys = r.scan(cursor, match=pattern, count=1000)
keys.extend(partial_keys)
if cursor == 0:
break
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
if __name__ == '__main__':
redis_client = Redis()

View File

@@ -1,25 +1,20 @@
services:
aida_server:
container_name: "AiDA_${SERVE_ENV}_Server"
build:
context: .
dockerfile: Dockerfile
working_dir: /app
volumes:
- ./app:/app/app
- ./.env_prod:/app/.env
- ./.env:/app/.env
- /etc/localtime:/etc/localtime:ro
- ./seg_cache:/seg_cache
ports:
- "10200:80"
depends_on:
- redis
redis:
image: redis
container_name: aida_redis
restart: always
ports:
- "6400:6379"
volumes:
- ./redis/data:/data
- ./redis/conf/redis.conf:/etc/redis/redis.conf
command: redis-server /etc/redis/redis.conf --appendonly yes
- "${SERVE_PORT}:80"
networks:
- aida_app_net
networks:
aida_app_net:
external: true
name: aida_app_net

View File

@@ -13,18 +13,26 @@ dependencies = [
"celery-types>=0.23.0",
"chromadb>=1.3.7",
"dashscope>=1.25.5",
"deepagents>=0.6.7",
"dominate>=2.9.1",
"dotenv>=0.9.9",
"fastapi[standard]>=0.125.0",
"image>=1.5.33",
"langchain>=1.2.0",
"langchain-community>=0.4.1",
"langchain-openai>=1.2.2",
"langchain-qwq>=0.3.5",
"langgraph>=1.0.5",
"langgraph-api>=0.4.28",
"langgraph-cli[inmem,redis]<=0.4.26",
"langsmith>=0.8.11",
"load>=1.0.14",
"load-dotenv>=0.1.0",
"loguru>=0.7.3",
"minio>=7.2.20",
"mmcv>=2.2.0",
"moviepy==1.0.3",
"nacos-sdk-python==2.0.1",
"np>=1.0.2",
"numpy<2",
"ollama>=0.6.1",
"opencv-python>=4.11.0.86",
@@ -49,6 +57,6 @@ dependencies = [
"tool>=0.8.0",
"torch>=2.9.1",
"torchvision>=0.24.1",
"tritonclient[all]>=2.63.0",
"tritonclient[all]>=2.69.0",
"uvicorn>=0.38.0",
]

Binary file not shown.

Binary file not shown.

2675
uv.lock generated

File diff suppressed because it is too large Load Diff