85 Commits

Author SHA1 Message Date
litianxiang
fb46a9521d Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 13:57:28 +08:00
litianxiang
b90688f835 更改增量更新日志级别 2026-01-13 13:57:15 +08:00
zcr
7e30779aec feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:43:30 +08:00
zcr
f7294f5966 feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:32:18 +08:00
zcr
0ac5a4e0a8 Merge remote-tracking branch 'origin/develop' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 16:18:15 +08:00
zcr
40b57b749c feat: 新增design模式 merge,前端CV python 合成 2026-01-12 16:18:04 +08:00
litianxiang
b8a538a8a1 fix:增量更新向量问题修改
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:59:06 +08:00
litianxiang
29b4f43a27 debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:34:56 +08:00
litianxiang
69dc20207d debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:03:58 +08:00
litianxiang
18979af604 debug:推荐接口返回redis值
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:01:26 +08:00
litianxiang
74406f9be4 推荐接口更新向量接口注册
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:59:01 +08:00
litianxiang
df99e3ac76 新增查看redis内容接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:51:37 +08:00
litianxiang
19346c2eb7 Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 09:51:52 +08:00
litianxiang
2af9cbfe78 fix:推荐接口 2026-01-12 09:49:07 +08:00
zcr
fe12b5697d fix: design 镜像默认值修改,旋转方向和前端保持一致
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:40:49 +08:00
zcr
c04d4877b0 fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:12:53 +08:00
zcr
91016e6cae fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:08:16 +08:00
zcr
0f4bb260ad fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:06:39 +08:00
zcr
c792106f02 fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 15:42:42 +08:00
zcr
deac5a4cab fix: design item sketch旋转参数为none
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 12:31:34 +08:00
zcr
15682036b3 feat : 新增seg anything 接口 ,接口文档补充
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 17:39:27 +08:00
zcr
9ba3a0ca49 feat : 新增seg anything 接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 17:33:54 +08:00
zcr
f6963070fb feat : 支持上下左右同时镜像
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 13:47:44 +08:00
zcr
12f5ca3ca3 feat : design 示例说明
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 10:44:02 +08:00
zcr
19110f51bf feat : design 示例说明
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 10:29:31 +08:00
zcr
e04636ce21 feat : design overall print 新增平铺间距和旋转角度
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-07 17:03:02 +08:00
zcr
2a50e7040e feat : design overall print 新增平铺间距和旋转角度
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-07 16:22:19 +08:00
zcr
a6f3bda9f7 feat : design 单品新增 镜像旋转功能
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-06 12:21:10 +08:00
zcr
c18f45e549 feat : design 单品新增 镜像旋转功能
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-06 12:00:58 +08:00
zcr
4951fab71a 代码整理
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:49:22 +08:00
zcr
aa57478852 新推荐接口first commit
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:35:32 +08:00
zcr
2a6c48d937 新推荐接口first commit
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:23:36 +08:00
litianxiang
fed3fcdf85 新推荐接口first commit 2025-12-30 17:18:12 +08:00
zcr
417528f8cd feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 16:52:20 +08:00
zcr
18024a2d70 feat : 代码梳理 移除所有敏感密钥 通过环境变量方式配置
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 16:49:08 +08:00
litianxiang
1be716e414 Merge remote-tracking branch 'origin/dev-ltx' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
# Conflicts:
#	app/api/api_recommendation.py
#	app/service/design_fast/utils/organize.py
2025-12-30 10:19:19 +08:00
litianxiang
826bdcf9c1 mysql更改库名 2025-12-29 16:19:52 +08:00
litianxiang
f351184630 新推荐接口first commit 2025-12-29 10:52:33 +08:00
zcr
fac1eab1bc feat : design 新增 callback 模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-16 11:59:28 +08:00
832ca6fd05 更新 .gitea/workflows/develop_build_scheduled.yaml
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-11-28 17:29:43 +08:00
673423131a 更新 .gitea/workflows/develop_build_scheduled.yaml
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-11-28 17:25:32 +08:00
6e15430a83 更新 .gitea/workflows/develop_build_scheduled.yaml
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-11-28 17:25:21 +08:00
51068d2215 上传文件至「.gitea/workflows」
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-11-28 17:23:33 +08:00
d493d9eff6 更新 .gitea/workflows/develop_build_commit.yaml
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-11-28 17:19:54 +08:00
7d970a7bba 上传文件至「.gitea/workflows」
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 14s
2025-11-28 17:19:14 +08:00
3fc6720bf7 更新 .gitea/workflows/develop_build_commit.yaml
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Successful in 19s
2025-11-28 17:11:49 +08:00
efa2e3a4a9 更新 .gitea/workflows/develop_build_commit.yaml
Some checks failed
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Failing after 7s
2025-11-28 17:10:11 +08:00
c6af01bc51 上传文件至「.gitea/workflows」
Some checks failed
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Failing after 8s
2025-11-28 17:02:47 +08:00
zhh
448af4ab6b feat(新功能):
fix(修复bug):  推荐接口 请求参数打印
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-24 15:32:25 +08:00
zhh
8a9f160cfa feat(新功能):
fix(修复bug):  推荐接口 请求参数打印
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-24 15:18:02 +08:00
zchengrong
6e06c8b516 feat(新功能): 新增图层越界控制参数
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-17 16:14:22 +08:00
zchengrong
322fb9c46b feat(新功能):
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试): 新增天祥回调url
2025-11-17 14:54:04 +08:00
zchengrong
30bfd22e3e feat(新功能):
fix(修复bug):   骨架生成视频 首帧曝光问题
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-07 17:44:19 +08:00
zchengrong
e8d8b715ae feat(新功能):
fix(修复bug):   accessories 替换为 others
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-07 10:56:34 +08:00
zchengrong
7d2149dcaf feat(新功能):
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试): 天祥design回调
2025-11-06 11:59:57 +08:00
zchengrong
fee9334b1f feat(新功能):
fix(修复bug):  修复骨架生成视频前两帧过曝问题
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-06 11:59:04 +08:00
zchengrong
85c486c3dc feat(新功能): 新增 骨架/图片+prompt/首尾帧+prompt -> 视频接口
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-05 17:08:40 +08:00
zchengrong
0e7ef80eed feat(新功能):
fix(修复bug):  印花部分,修改no_seg_sketch_print 丢失overall 印花问题
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-04 16:53:27 +08:00
zhh
8ccbbe41b1 feat(新功能): 新增wan2.2 pose-transform模型接口,comfyui-api形式
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-03 17:37:33 +08:00
zhh
98468ea7aa feat(新功能): 新增wan2.2 pose-transform模型接口,comfyui-api形式
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-03 16:39:42 +08:00
zhh
a9d9bdcb71 feat(新功能): 新增wan2.2 pose-transform模型接口,comfyui-api形式
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-03 16:37:41 +08:00
zhh
7459583377 feat(新功能): 新增wan2.2 pose-transform模型接口,comfyui-api形式
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-11-03 11:52:39 +08:00
zhh
385ff2d4aa test 2025-10-16 14:35:47 +08:00
zhh
02ad5db269 feat(新功能): fix(修复bug):两种未分割sketch修复 refactor(重构): test(增加测试): 2025-09-26 22:58:11 +08:00
zhh
1d90963ded feat(新功能):回溯-single 前后片优先级默认为20,-20 fix(修复bug): refactor(重构): test(增加测试): 2025-09-26 18:05:46 +08:00
zhh
d1fefceebf feat(新功能):single 前后片优先级默认为20,-20 fix(修复bug): refactor(重构): test(增加测试): 2025-09-26 17:52:24 +08:00
zhh
242ebfc1df Revert "feat(新功能):更换翻译模型,关闭语种判断逻辑 fix(修复bug): refactor(重构): test(增加测试):"
This reverts commit b8cf3d25b4.
2025-09-26 14:51:24 +08:00
zhh
b8cf3d25b4 feat(新功能):更换翻译模型,关闭语种判断逻辑 fix(修复bug): refactor(重构): test(增加测试): 2025-09-26 14:46:03 +08:00
zhh
95647be610 feat(新功能): fix(修复bug):design的宽度自适应-宽度算法修复 refactor(重构): test(增加测试): 2025-09-26 13:20:26 +08:00
zhh
e966ed5aa5 feat(新功能): 1、design-print为解决sketch原图太灰导致印花颜色便暗 上色部分使用原始方案 2、cv2.resize 插值算法更换,提升resize后图片质量 fix(修复bug): refactor(重构): test(增加测试): 2025-09-26 10:44:29 +08:00
zhh
0d4d464e3f feat(新功能): 1、design-print为解决sketch原图太灰导致印花颜色便暗 2、cv2.resize 插值算法更换,提升resize后图片质量 fix(修复bug): refactor(重构): test(增加测试): 2025-09-26 10:29:39 +08:00
zhh
4bc79e62ca feat(新功能):design 新增两个中间结果(未分割图层) 1.color + overall_print 2.color + overall_print + print fix(修复bug): refactor(重构): test(增加测试): 2025-09-25 16:01:28 +08:00
zhh
bf1fb8e514 feat(新功能):design 新增两个中间结果(未分割图层) 1.color + overall_print 2.color + overall_print + print fix(修复bug): refactor(重构): test(增加测试): 2025-09-25 15:39:17 +08:00
zhh
d720bf2209 feat(新功能):design 新增两个中间结果(未分割图层) 1.color + overall_print 2.color + overall_print + print fix(修复bug): refactor(重构): test(增加测试): 2025-09-25 15:32:23 +08:00
zhh
8f486867d5 feat(新功能): fix(修复bug): : refactor(重构): test(增加测试): 徐佩design测试 2025-09-23 10:18:06 +08:00
zhh
1f45fe48a3 feat(新功能): fix(修复bug): : refactor(重构): test(增加测试): 徐佩design测试 2025-09-23 10:12:19 +08:00
zhh
79865d9a96 feat(新功能): fix(修复bug): : refactor(重构): test(增加测试): 2025-09-22 11:17:59 +08:00
zhh
a9a5964127 feat(新功能): fix(修复bug): : refactor(重构): test(增加测试): 2025-09-22 11:16:37 +08:00
zhh
47e991cd76 feat(新功能): fix(修复bug): : refactor(重构): test(增加测试): 2025-09-22 10:53:46 +08:00
zhh
8bc1ea576e feat(新功能):
fix(修复bug):  overall 坐标算法新增比例参数
docs(文档变更):
refactor(重构):
test(增加测试):
2025-09-22 10:48:58 +08:00
zhh
31e848e8bb feat(新功能): 灰色sketch图使print颜色变暗 解决方案测试
fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试):
2025-09-22 10:41:09 +08:00
zhh
6da3712a76 feat(新功能):
fix(修复bug):  印花坐标计算方式新增sketch拉伸比例
docs(文档变更):
refactor(重构):
test(增加测试):
2025-09-19 10:40:21 +08:00
zhh
e6da512a31 feat(新功能):
fix(修复bug):  pattern_image (上色无印花sketch图),修改为不拉伸
docs(文档变更):
refactor(重构):
test(增加测试):
2025-09-18 15:33:01 +08:00
zhh
16d4844cca feat(新功能):
fix(修复bug):  pattern_image (上色无印花sketch图),修改为不拉伸
docs(文档变更):
refactor(重构):
test(增加测试):
2025-09-18 15:17:51 +08:00
zhh
978e0d998d feat(新功能):
fix(修复bug):  pattern_image (上色无印花sketch图),修改为不拉伸
docs(文档变更):
refactor(重构):
test(增加测试):
2025-09-18 15:05:31 +08:00
179 changed files with 10867 additions and 8523 deletions

View File

@@ -0,0 +1,44 @@
name: git commit AiDA python develop 分支构建部署
on:
workflow_dispatch:
push:
branches:
- develop
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
if: "contains(github.event.head_commit.message, '[run build]')"
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'develop'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
script: |
# 进入项目目录
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d --build --remove-orphans 2>&1
docker image prune -f 2>&1

View File

@@ -0,0 +1,40 @@
name: 手动 AiDA python develop 分支构建部署
on:
workflow_dispatch:
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'develop'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: 3.重启docker-compose
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
script: |
# 进入项目目录
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d --build --remove-orphans 2>&1
docker image prune -f 2>&1

View File

@@ -0,0 +1,42 @@
name: 定时 AiDA python develop 分支构建部署
on:
# 使用 schedule 触发器,遵循标准的 Cron 格式 (分钟 小时 日期 月份 星期);小时为 UTC 时间,换算北京时间需减 8 小时
schedule:
- cron: '30 9 * * *'
jobs:
scheduled_deploy:
runs-on: ubuntu-latest
env:
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
steps:
- name: 1.检出代码
uses: actions/checkout@v4
with:
ref: 'develop'
- name: 2.复制文件到服务器
uses: appleboy/scp-action@v0.1.7
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
source: "."
target: ${{ env.REMOTE_DEPLOY_PATH }}
- name: Restart Docker containers
uses: appleboy/ssh-action@v0.1.10
with:
host: ${{ secrets.SERVER_HOST }}
username: ${{ secrets.SERVER_USER }}
password: ${{ secrets.SERVER_PASSWORD }}
script: |
# 进入项目目录
cd ${{ env.REMOTE_DEPLOY_PATH }}
docker-compose down 2>&1
docker-compose up -d --build --remove-orphans 2>&1
docker image prune -f 2>&1

2
.gitignore vendored
View File

@@ -149,3 +149,5 @@ app/logs/*
*.csv
*.avi
*.json
*.env*
config.backup.py

22
Dockerfile Normal file
View File

@@ -0,0 +1,22 @@
FROM python:3.12-slim
# Install uv.
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
# Copy the application into the container.
COPY . /app
# Install the application dependencies.
WORKDIR /app
RUN mkdir /seg_cache
# 更新索引并安装替代包
RUN apt-get update && apt-get install -y \
vim \
libgl1 \
libglib2.0-0 \
&& rm -rf /var/lib/apt/lists/*
RUN uv sync --frozen --no-cache
# Run the application.
CMD ["/app/.venv/bin/fastapi", "run", "app/main.py", "--port", "80", "--host", "0.0.0.0"]

View File

@@ -23,11 +23,11 @@
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
2. 启动服务器
1. 启动服务器
$ uvicorn app.main:app --host 0.0.0.0 --port 8000
3. 打开 http://127.0.0.1:8000/docs
2. 打开 http://127.0.0.1:8000/docs
Docker 部署
---------------

View File

@@ -2,8 +2,7 @@ import json
import logging
from fastapi import APIRouter, HTTPException
from app.core.config import DEBUG
from app.core.config import settings
from app.schemas.attribute_retrieve import *
from app.schemas.response_template import ResponseModel
from app.service.attribute.config import const, local_debug_const
@@ -35,13 +34,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
"""
try:
for item in request_item:
logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
if DEBUG:
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
if settings.DEBUG:
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
else:
service = AttributeRecognition(const=const, request_data=request_item)
data = service.get_result()
logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -67,10 +66,10 @@ def category_recognition(request_item: list[CategoryRecognitionModel]):
"""
try:
for item in request_item:
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
service = CategoryRecognition(request_data=request_item)
data = service.get_result()
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -26,7 +26,7 @@ def seg_product(request_item: BrandDnaModel):
}
"""
try:
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = BrandDna(request_item)
result_url = service.get_result()
except Exception as e:
@@ -36,7 +36,7 @@ def seg_product(request_item: BrandDnaModel):
@router.post("/GenerateBrand")
def GenerateBrand(request_data: GenerateBrandModel):
def generate_brand(request_data: GenerateBrandModel):
"""
通过prompt 生成 brand name brand slogan , brand logo。
创建一个具有以下参数的请求体:

View File

@@ -9,7 +9,6 @@ from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
import pymysql
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
from minio import Minio

View File

@@ -5,10 +5,11 @@ import time
from PIL import ImageEnhance
from fastapi import APIRouter, HTTPException
from minio import Minio
from app.core.config import settings
from app.schemas.brighten import BrightenModel
from app.schemas.response_template import ResponseModel
from app.service.utils.oss_client import oss_get_image, oss_upload_image
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
router = APIRouter()
logger = logging.getLogger()
@@ -20,6 +21,9 @@ def increase_brightness(img, factor):
return bright_img
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
@router.post("/brighten")
async def brighten(request_item: BrightenModel):
"""
@@ -35,14 +39,14 @@ async def brighten(request_item: BrightenModel):
"""
try:
start_time = time.time()
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict())}")
image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
image = oss_get_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
new_image = increase_brightness(image, request_item.brighten_value)
image_data = io.BytesIO()
new_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
req = oss_upload_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
brighten_url = f"{req.bucket_name}/{req.object_name}"
logger.info(f"run time is : {time.time() - start_time}")
except Exception as e:

View File

@@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel):
}
"""
try:
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
data = chat(post_data=request_data)
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -42,7 +42,7 @@ def clothing_seg(request_item: ClothingSegModel):
}
"""
try:
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
server = ClothingSeg(request_item)
result_url = server.get_result()
except Exception as e:

View File

@@ -1,64 +1,76 @@
import json
import logging
import os
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
import requests
from fastapi import APIRouter, HTTPException, BackgroundTasks
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel, SAMRequestModel
from app.schemas.response_template import ResponseModel
from app.service.design.model_process_service import model_transpose
from app.service.design_batch.service import start_design_batch_generate
from app.service.design_fast.design_generate import design_generate, design_generate_v2
from app.service.design_fast.utils.redis_utils import Redis
from app.service.design_fast.model_process_service import model_transpose
router = APIRouter()
logger = logging.getLogger()
@router.post("/design")
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
def design(request_data: DesignModel):
"""
objects.items.transparent:
- **objects.items.transparent**:
```json
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
mask_url 为空"" -> 单件衣服透明
mask_url"mask_url" -> 区域透明
```
- **mask_url** 为"" -> 单件衣服透明
- **mask_url** 非空"mask_url" -> 区域透明
- **transpose**: 镜像模式,可选值:"top_bottom"、"left_right"
- **rotate**: 旋转角度(单位:度),例如 45
创建一个具有以下参数的请求体:
- ** design 参数变更:
design detail 请求参数中 basic -> preview_submit 替换为design_type 可选参数 default ,merge (移除preview和submit)
design_type 参数说明:
default模式下 请求参数不变
merge模式下 items -> 每个item需要新增 merge_image_path , merge_image_path为前端处理 print color等操作后的单件结果图
**
- 创建一个具有以下参数的请求体:
示例参数:
```json
{
"objects": [
{
"basic": {
"body_point_test": {
"waistband_right": [
200,
241
203,
249
],
"hand_point_right": [
223,
297
229,
343
],
"waistband_left": [
112,
241
119,
248
],
"hand_point_left": [
92,
305
97,
343
],
"shoulder_left": [
99,
116
108,
107
],
"shoulder_right": [
215,
116
212,
107
]
},
"layer_order": true,
"design_type": "preview",
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
@@ -67,14 +79,19 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
},
"items": [
{
"businessId": 270372,
"color": "30 28 28",
"image_id": 69780,
"businessId": 2115382,
"color": "",
"image_id": 61686,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
"path": "aida-sys-image/images/female/dress/0628000564.jpg",
"transpose": [
1,
1
],
"rotate": 45,
"print": {
"element": {
"element_angle_list": [],
@@ -83,10 +100,30 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
"location": [
[
53.0,
118.5
]
],
"print_angle_list": [
0.0
],
"print_path_list": [
"aida-users/89/print/02d57aa8-f342-4e1d-b02c-b278f94dcfe6-3-89.png"
],
"print_scale_list": [
[
0.5,
0.5
]
],
"gap": [
[
10,
10
]
]
},
"single": {
"location": [],
@@ -100,104 +137,30 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
1.0,
1.0
],
"type": "Trousers"
"seg_mask_url": "aida-clothing/mask/mask_9698b428-eb93-11f0-9327-0242c0a80003.png",
"type": "Dress"
},
{
"businessId": 270373,
"color": "30 28 28",
"image_id": 98243,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 11,
"resize_scale": [
1.0,
1.0
],
"type": "Blouse"
},
{
"businessId": 270374,
"color": "172 68 68",
"image_id": 98244,
"offset": [
0,
0
],
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
"print": {
"element": {
"element_angle_list": [],
"element_path_list": [],
"element_scale_list": [],
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 12,
"resize_scale": [
1.0,
1.0
],
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
"type": "Outwear"
},
{
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
"image_id": 96090,
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"type": "Body"
}
]
}
],
"process_id": "83"
"process_id": "89"
}
```
"""
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
# data = generate(request_data=request_data)
# logger.info(f"design response @@@@@@:{json.dumps(data)}")
# logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
#
try:
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = design_generate(request_data=request_data)
logger.info(f"design response @@@@@@:{json.dumps(data)}")
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
@@ -215,47 +178,48 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"basic": {
"body_point_test": {
"waistband_right": [
200,
241
203,
249
],
"hand_point_right": [
223,
297
229,
343
],
"waistband_left": [
112,
241
119,
248
],
"hand_point_left": [
92,
305
97,
343
],
"shoulder_left": [
99,
116
108,
107
],
"relation_type": "System",
"shoulder_right": [
215,
116
]
212,
107
],
"relation_id": 1020356
},
"layer_order": true,
"layer_order": false,
"scale_bag": 0.7,
"scale_earrings": 0.16,
"self_template": true,
"self_template": false,
"single_overall": "overall",
"switch_category": ""
},
"items": [
{
"businessId": 270372,
"color": "30 28 28",
"image_id": 69780,
"color": "209 196 171",
"image_id": 84093,
"offset": [
0,
0
1,
1
],
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
"print": {
"element": {
"element_angle_list": [],
@@ -264,10 +228,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"location": [
[
0.0,
0.0
]
],
"print_angle_list": [
0.0,
0.0
],
"print_path_list": [],
"print_scale_list": []
"print_scale_list": [
[
0.0,
0.0
]
]
},
"single": {
"location": [],
@@ -276,22 +253,20 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"print_scale_list": []
}
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"type": "Trousers"
"type": "Outwear"
},
{
"businessId": 270373,
"color": "30 28 28",
"image_id": 98243,
"color": "63 71 73",
"image_id": 100496,
"offset": [
0,
0
1,
1
],
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
"print": {
"element": {
"element_angle_list": [],
@@ -300,10 +275,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"location": [
[
0.0,
0.0
]
],
"print_angle_list": [
0.0,
0.0
],
"print_path_list": [],
"print_scale_list": []
"print_scale_list": [
[
0.0,
0.0
]
]
},
"single": {
"location": [],
@@ -312,7 +300,6 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"print_scale_list": []
}
},
"priority": 11,
"resize_scale": [
1.0,
1.0
@@ -320,14 +307,14 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"type": "Blouse"
},
{
"businessId": 270374,
"color": "172 68 68",
"image_id": 98244,
"color": "111 78 63",
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
"image_id": 92193,
"offset": [
0,
0
1,
1
],
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
"print": {
"element": {
"element_angle_list": [],
@@ -336,10 +323,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"location": []
},
"overall": {
"location": [],
"print_angle_list": [],
"location": [
[
0.0,
0.0
]
],
"print_angle_list": [
0.0,
0.0
],
"print_path_list": [],
"print_scale_list": []
"print_scale_list": [
[
0.0,
0.0
]
]
},
"single": {
"location": [],
@@ -348,31 +348,37 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
"print_scale_list": []
}
},
"priority": 12,
"resize_scale": [
1.0,
1.0
],
"transparent":{
"mask_url":"test/transparent_test/transparent_mask.png",
"scale":0.1
},
"type": "Outwear"
"type": "Trousers"
},
{
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
"image_id": 96090,
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"offset": [
1,
1
],
"resize_scale": [
1.0,
1.0
],
"type": "Body"
}
]
],
"objectSign": "65830966"
}
],
"process_id": "83"
"process_id": "4802946666428422",
"requestId": "1d1e7641-0d62-4da2-adc0-b4404910723c",
"callback_url": "https://api.aida.com.hk/api/third/party/receiveDesignResults"
}
"""
try:
# 异步
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
background_tasks.add_task(design_generate_v2, request_data)
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
@@ -380,30 +386,76 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
return ResponseModel()
@router.post('/get_progress')
def get_progress(request_data: DesignProgressModel):
@router.post("/seg_anything")
async def seg_anything(request_data: SAMRequestModel):
"""
获取design 进度
创建一个具有以下参数的请求体:
- **process_id**: 进度id
**Segment Anything 交互式分割接口**
示例参数:
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
### 参数说明:
- **user_id**:用户id 用于存储分割图
- **image_path**: 图片在服务器或云端的相对路径。
- **type**: 推理类型
- **box**: 框选矩形点位信息
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
- 1: **前景点** (代表想要分割出的区域)
- 0: **背景点** (代表想要排除的区域)
### 请求体示例:
```json
point
{
"process_id": "6878547032381675"
"user_id": 1,
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"point",
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
"labels": [1, 1, 0, 1]
}
box
{
"user_id": 1,
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"box",
"box": [350, 286, 544, 520]
}
```
"""
try:
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
process_id = request_data.process_id
r = Redis()
data = r.read(key=process_id)
if data is None:
raise ValueError(f"No progress ID: {process_id}")
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = requests.post("http://10.1.1.240:10075/predict", json=request_data.dict())
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
return ResponseModel(data=json.loads(data.content))
except Exception as e:
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
logger.warning(f"seg_anything Run Exception @@@@@@:{e}")
# @router.post('/get_progress')
# def get_progress(request_data: DesignProgressModel):
# """
# 获取design 进度
# 创建一个具有以下参数的请求体:
# - **process_id**: 进度id
#
# 示例参数:
# {
# "process_id": "6878547032381675"
# }
# """
# try:
# logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
# process_id = request_data.process_id
# r = Redis()
# data = r.read(key=process_id)
# if data is None:
# raise ValueError(f"No progress ID: {process_id}")
# logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data, indent=4)}")
# except Exception as e:
# logger.warning(f"get_progress Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data)
@router.post('/model_process')
@@ -419,44 +471,42 @@ def model_process(request_data: ModelProgressModel):
}
"""
try:
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = model_transpose(image_path=request_data.model_path)
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"model_process Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
# ##############################################################
@router.post("/design_batch_generate")
async def design_batch(file: UploadFile = File(...),
                       tasks_id: str = Form(...),
                       user_id: str = Form(...),
                       file_name: str = Form(...),
                       total: int = Form(...)
                       ):
    """Start a batch design-generation job from an uploaded request file.

    Form fields:
    - **tasks_id** / **user_id** / **file_name** / **total**: batch metadata
      packed into a DBGConfigModel.
    - **file**: the uploaded payload; it is saved to disk and then forwarded
      to start_design_batch_generate.
    """
    dbg_config = DBGConfigModel(
        tasks_id=tasks_id,
        user_id=user_id,
        file_name=file_name,
        total=total
    )
    contents = await file.read()
    # NOTE(review): the `file_name` form field is overwritten here by the
    # upload's own filename, so the file on disk is named after file.filename
    # while dbg_config keeps the form value — confirm this is intentional.
    file_name = file.filename
    await save_request_file(contents, file_name)
    return await start_design_batch_generate(dbg_config, contents)
async def save_request_file(contents, file_name):
    """Persist an uploaded request payload under service/design_batch/request_data.

    Args:
        contents: raw bytes of the uploaded file.
        file_name: name to store the payload under; any directory components
            are stripped so writes cannot escape the request_data directory.
    """
    # Create the destination directory if it does not exist yet.
    save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # basename() prevents path traversal via a crafted upload filename
    # (file_name comes straight from the client's UploadFile.filename).
    file_path = os.path.join(save_dir, os.path.basename(file_name))
    with open(file_path, "wb") as f:
        f.write(contents)
"""design 批量处理 停用"""
# @router.post("/design_batch_generate")
# async def design_batch(file: UploadFile = File(...),
# tasks_id: str = Form(...),
# user_id: str = Form(...),
# file_name: str = Form(...),
# total: int = Form(...)
# ):
# dbg_config = DBGConfigModel(
# tasks_id=tasks_id,
# user_id=user_id,
# file_name=file_name,
# total=total
# )
# contents = await file.read()
# file_name = file.filename
# await save_request_file(contents, file_name)
# return await start_design_batch_generate(dbg_config, contents)
#
#
# async def save_request_file(contents, file_name):
# # 创建保存文件的目录(如果不存在)
# save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
# if not os.path.exists(save_dir):
# os.makedirs(save_dir)
# # 处理文件
# file_path = os.path.join(save_dir, file_name)
# with open(file_path, "wb") as f:
# f.write(contents)

View File

@@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel):
}
"""
try:
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
server = DesignPreprocessing()
data = server.pipeline(image_list=request_data.sketches)
logger.info(f"design response @@@@@@:{json.dumps(data)}")
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"design Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -33,18 +33,30 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
- **version**: 使用模型版本 fast 或者 high
示例参数:
1. txt 2 img
{
"tasks_id": "123-89",
"prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic",
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
"mode": "img2img",
"tasks_id": "bd2cf809-24bc-49a6-91c9-193c6272a52e-2-89",
"prompt": "a single item of sketch of dress, 4k, white background",
"image_url": "",
"mode": "txt2img",
"category": "sketch",
"gender": "male",
"gender": "Female",
"version": "fast"
}
2. img 2 img
{
"tasks_id": "b861d4fa-5ae3-4a30-9c7a-7ba6bb9aa37b-1-89",
"prompt": "a single item of sketch of dress, 4k, white background",
"image_url": "aida-collection-element/89/Sketchboard/548da3a2-834f-49a7-b52c-e729c5ab5062.png",
"mode": "img2img",
"category": "sketch",
"gender": "Female",
"version": "fast"
}
"""
try:
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
service = GenerateImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -65,42 +77,41 @@ def generate_image(tasks_id: str):
return ResponseModel(data=data['data'])
'''multi view'''
'''multi view 停用'''
# @router.post("/generate_multi_view")
# def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **image_url**: 前视角图的输入minio或S3 url 地址
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
# }
# """
# try:
# logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
# service = GenerateMultiView(request_item)
# background_tasks.add_task(service.get_result)
# except Exception as e:
# logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel()
@router.post("/generate_multi_view")
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
    """
    Create a request body with the following parameters:
    - **tasks_id**: task id, used to cancel the task and fetch the result
    - **image_url**: minio/S3 url of the front-view input image

    Example:
    {
        "tasks_id": "123-89",
        "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
    }
    """
    try:
        # Schedule generation in the background; the endpoint returns immediately.
        logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
        service = GenerateMultiView(request_item)
        background_tasks.add_task(service.get_result)
    except Exception as e:
        logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    return ResponseModel()
@router.get("/generate_multi_view_cancel/{tasks_id}")
def generate_multi_view(tasks_id: str):
    """Cancel the multi-view generation task identified by *tasks_id*."""
    # NOTE(review): this handler reuses the name `generate_multi_view`, shadowing
    # the POST handler above at module level; both routes still register, but
    # renaming to generate_multi_view_cancel would avoid confusion.
    try:
        logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
        data = generate_multi_view_cancel(tasks_id)
        logger.info(f"generate_cancel response @@@@@@:{data}")
    except Exception as e:
        logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    return ResponseModel(data=data['data'])
# @router.get("/generate_multi_view_cancel/{tasks_id}")
# def generate_multi_view(tasks_id: str):
# try:
# logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
# data = generate_multi_view_cancel(tasks_id)
# logger.info(f"generate_cancel response @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data['data'])
'''single logo'''
@@ -122,7 +133,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_
}
"""
try:
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
service = GenerateSingleLogoImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -167,7 +178,7 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
}
"""
try:
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = GenerateProductImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
@@ -188,166 +199,164 @@ def generate_product_image(tasks_id: str):
return ResponseModel(data=data['data'])
'''relight image'''
'''relight image 停用'''
# @router.post("/generate_relight_image")
# def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_url**: 被生成图片的S3或minio url地址
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
# - **product_type**: 输入single item 还是 overall item
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "direction": "Right Light",
# "product_type": "overall"
# }
# """
# try:
# logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
# service = GenerateRelightImage(request_item)
# background_tasks.add_task(service.get_result)
# except Exception as e:
# logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel()
#
#
# @router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
# def generate_relight_image(tasks_id: str):
# try:
# logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
# data = generate_relight_image_cancel(tasks_id)
# logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data['data'])
@router.post("/generate_relight_image")
def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **prompt**: 想要生成图片的描述词
- **image_url**: 被生成图片的S3或minio url地址
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
- **product_type**: 输入single item 还是 overall item
"""batch generate img 停用"""
# @router.post("/batch_generate_product_image")
# async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_url**: 被生成图片的S3或minio url地址
# - **image_strength**: 生成强度,越低越接近原图
# - **product_type**: 输入single item 还是 overall item
# - **batch_size**: 批生成数量
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
# "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
# "image_strength": 0.8,
# "product_type": "overall",
# "batch_size": 1
# }
# """
# return await start_product_batch_generate(request_batch_item)
#
#
# @router.post("/batch_generate_relight_image")
# async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于获取生成结果
# - **prompt**: 想要生成图片的描述词
# - **image_url**: 被生成图片的S3或minio url地址
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
# - **product_type**: 输入single item 还是 overall item
# - **batch_size**: 批生成数量
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "direction": "Right Light",
# "product_type": "overall",
# "batch_size": 1
# }
# """
# return await start_relight_batch_generate(request_batch_item)
示例参数:
{
"tasks_id": "123-89",
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
"direction": "Right Light",
"product_type": "overall"
}
"""
try:
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = GenerateRelightImage(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
def generate_relight_image(tasks_id: str):
try:
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
data = generate_relight_image_cancel(tasks_id)
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
except Exception as e:
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data['data'])
"""batch generate img"""
@router.post("/batch_generate_product_image")
async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
    """
    Create a request body with the following parameters:
    - **tasks_id**: task id, used to fetch the generated results
    - **prompt**: description of the desired image
    - **image_url**: S3/minio url of the source image
    - **image_strength**: generation strength; lower stays closer to the source image
    - **product_type**: "single" item or "overall" item
    - **batch_size**: number of images to generate

    Example:
    {
        "tasks_id": "123-89",
        "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
        "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
        "image_strength": 0.8,
        "product_type": "overall",
        "batch_size": 1
    }
    """
    return await start_product_batch_generate(request_batch_item)
@router.post("/batch_generate_relight_image")
async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
    """
    Create a request body with the following parameters:
    - **tasks_id**: task id, used to fetch the generated results
    - **prompt**: description of the desired image
    - **image_url**: S3/minio url of the source image
    - **direction**: light direction — Right Light / Left Light / Top Light / Bottom Light
    - **product_type**: "single" item or "overall" item
    - **batch_size**: number of images to generate

    Example:
    {
        "tasks_id": "123-89",
        "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
        "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
        "direction": "Right Light",
        "product_type": "overall",
        "batch_size": 1
    }
    """
    return await start_relight_batch_generate(request_batch_item)
@router.post("/batch_generate_pose_transform_image")
async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
    """
    Create a request body with the following parameters:
    - **tasks_id**: task id, used to cancel the task and fetch the results
    - **image_url**: S3/minio url of the source image
    - **pose_id**: target pose identifier, e.g. "1"
    - **batch_size**: number of images to generate

    Example:
    {
        "tasks_id": "123-89",
        "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
        "pose_id": "1",
        "batch_size": 1
    }
    """
    return await start_pose_transform_batch_generate(request_batch_item)
"""agent tool"""
@router.post("/agent_tool_generate_image")
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
    """
    Create a request body with the following parameters:
    - **prompt**: description of the desired image
    - **category**: image category — sketch, print, etc.
    - **gender**: clothing gender used for sketch generation
    - **size**: number of images to generate
    - **version**: model version, "fast" or "high"

    Example:
    {
        "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
        "category": "sketch",
        "gender": "male",
        "size":2,
        "version":"high"
    }
    """
    try:
        # NOTE(review): unlike the other endpoints, generation here runs
        # synchronously; the `background_tasks` parameter is unused.
        logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
        request_data = request_item.dict()
        service = AgentToolGenerateImage(request_data['version'])
        image_url_list, clothing_category_list = service.get_result(
            prompt=request_data['prompt'],
            size=request_data['size'],
            version=request_data['version'],
            category=request_data['category'],
            gender=request_data['gender']
        )
        data = {
            "image_url_list": image_url_list,
            "clothing_category_list": clothing_category_list
        }
        logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
    except Exception as e:
        logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    return ResponseModel(data=data)
# @router.post("/batch_generate_pose_transform_image")
# async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **image_url**: 被生成图片的S3或minio url地址
# - **pose_id**: 1
# - **batch_size**: 批生成数量
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "pose_id": "1",
# "batch_size": 1
# }
# """
# return await start_pose_transform_batch_generate(request_batch_item)
#
#
# """agent tool"""
#
#
# @router.post("/agent_tool_generate_image")
# def agent_tool_generate_image(request_item: AgentTollGenerateImageModel):
# """
# 创建一个具有以下参数的请求体:
# - **prompt**: 想要生成图片的描述词
# - **category**: 生成图片的类别sketch print 等等
# - **gender**: 生成sketch专用服装类别
# - **version**: 使用模型版本 fast 或者 high
# - **size**: 生成数量
# - **version**: 使用模型版本 fast 或者 high
#
#
# 示例参数:
# {
# "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
# "category": "sketch",
# "gender": "male",
# "size":2,
# "version":"high"
# }
# """
# try:
# logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
# request_data = request_item.dict()
# service = AgentToolGenerateImage(request_data['version'])
# image_url_list, clothing_category_list = service.get_result(
# prompt=request_data['prompt'],
# size=request_data['size'],
# version=request_data['version'],
# category=request_data['category'],
# gender=request_data['gender']
# )
# data = {
# "image_url_list": image_url_list,
# "clothing_category_list": clothing_category_list
# }
# logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data)

View File

@@ -29,7 +29,7 @@ def image2sketch(request_item: Image2SketchModel):
}
"""
try:
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = LineArtService(request_item)
result_url = service.get_result()
except Exception as e:

View File

@@ -0,0 +1,116 @@
import logging
import sys
from typing import Optional
from fastapi import APIRouter, HTTPException, Query
from concurrent.futures import ThreadPoolExecutor
import threading
from app.schemas.response_template import ResponseModel
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
logger = logging.getLogger()
router = APIRouter()
# 使用线程池执行器来运行长时间任务
executor = ThreadPoolExecutor(max_workers=1)
# 用于跟踪任务状态
task_status = {"running": False}
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
    """Run the Milvus sketch-import job in a worker thread.

    The underlying script parses argparse-style flags, so we temporarily
    rewrite sys.argv to pass parameters through, restoring it afterwards.

    Args:
        batch_size: rows handled per batch.
        retry_times: retries on a failed batch.
        limit: optional cap on processed rows (testing aid); None = no cap.
        offset: starting row offset; only forwarded when > 0.
        skip_create_collection: skip collection creation if it already exists.
    """
    original_argv = None
    task_status["running"] = True
    try:
        # Preserve the process-wide argv before faking CLI arguments.
        original_argv = sys.argv.copy()
        sys.argv = [
            "import_sys_sketch_to_milvus.py",
            "--batch-size", str(batch_size),
            "--retry-times", str(retry_times),
        ]
        if limit is not None:
            sys.argv.extend(["--limit", str(limit)])
        if offset > 0:
            sys.argv.extend(["--offset", str(offset)])
        if skip_create_collection:
            sys.argv.append("--skip-create-collection")
        import_main()
        logger.info("导入任务完成")
    except Exception as e:
        logger.error(f"导入任务失败: {e}", exc_info=True)
        # Re-raise so the failure is recorded on the executor's Future.
        raise
    finally:
        # Previously the running flag was reset separately in both the try and
        # except paths; clearing it here (together with the argv restore)
        # guarantees it is reset exactly once on every exit path.
        task_status["running"] = False
        if original_argv is not None:
            sys.argv = original_argv
@router.post("/import-sys-sketch", response_model=ResponseModel)
async def import_sys_sketch(
    batch_size: int = Query(1000, description="批量处理大小默认1000"),
    retry_times: int = Query(3, description="失败重试次数默认3"),
    limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
    offset: int = Query(0, description="起始偏移量默认0"),
    skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
):
    """
    Import system-sketch vectors from t_sys_file into Milvus.

    The import itself runs in a background thread; this endpoint only
    schedules it and returns immediately. At most one import may run at
    a time.
    """
    # Guard clause: reject if an import is already in flight.
    if task_status["running"]:
        raise HTTPException(
            status_code=409,
            detail="已有导入任务正在运行,请等待完成后再试"
        )
    try:
        # Hand the long-running job to the single-worker executor.
        executor.submit(
            run_import_task,
            batch_size,
            retry_times,
            limit,
            offset,
            skip_create_collection,
        )
        started_info = {
            "status": "started",
            "batch_size": batch_size,
            "retry_times": retry_times,
            "limit": limit,
            "offset": offset,
            "skip_create_collection": skip_create_collection,
        }
        return ResponseModel(
            code=200,
            msg="导入任务已启动,正在后台执行",
            data=started_info,
        )
    except Exception as e:
        logger.error(f"启动导入任务失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
async def get_import_status():
    """Report whether an import task is currently running."""
    status_payload = {"running": task_status["running"]}
    return ResponseModel(code=200, msg="OK", data=status_payload)

View File

@@ -35,10 +35,10 @@ def mannequins_edit(request_data: MannequinModel):
}**
"""
try:
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
service = MannequinEditService(request_data)
data = service()
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data)}")
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"mannequins_edit Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -1,18 +1,67 @@
import json
import logging
import requests
from fastapi import APIRouter, BackgroundTasks, HTTPException
from app.core.config import settings
from app.schemas.comfyui_i2v import ComfyuiI2VModel, ComfyuiFLF2VModel
from app.schemas.pose_transform import PoseTransformModel
from app.schemas.response_template import ResponseModel
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
from app.service.comfyui_I2V.flf2v_server import ComfyUIServerFLF2V
from app.service.comfyui_I2V.i2v_server import ComfyUIServerI2V
from app.service.comfyui_I2V.pose2v_server import ComfyUIServerPose2V
router = APIRouter()
logger = logging.getLogger()
"""停用"""
@router.post("/pose_transform")
def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
# @router.post("/pose_transform")
# def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
# """
# 创建一个具有以下参数的请求体:
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
# - **image_url**: 被生成图片的S3或minio url地址
# - **pose_id**: 1
#
#
# 示例参数:
# {
# "tasks_id": "123-89",
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
# "pose_id": "1"
# }
# """
# try:
# logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
# service = PoseTransformService(request_item)
# background_tasks.add_task(service.get_result)
# except Exception as e:
# logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel()
# @router.get("/pose_transform_cancel/{tasks_id}")
# def pose_transform_cancel(tasks_id: str):
# try:
# logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
# data = pose_transform_infer_cancel(tasks_id)
# logger.info(f"pose_transform_cancel response @@@@@@:{data}")
# except Exception as e:
# logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
# raise HTTPException(status_code=404, detail=str(e))
# return ResponseModel(data=data['data'])
"""
骨架 + 产品图 => 视频
"""
@router.post("/comfyui_image_pose_2_video")
def comfyui_image_pose_2_video(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
@@ -28,22 +77,92 @@ def pose_transform(request_item: PoseTransformModel, background_tasks: Backgroun
}
"""
try:
logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}")
service = PoseTransformService(request_item)
logger.info(f"image_pose_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = ComfyUIServerPose2V(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
logger.warning(f"image_pose_2_video Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
@router.get("/pose_transform_cancel/{tasks_id}")
def pose_transform_cancel(tasks_id: str):
"""
产品图 + 文 => 视频
"""
@router.post("/comfyui_image_2_video")
def comfyui_image_2_video(request_item: ComfyuiI2VModel, background_tasks: BackgroundTasks):
"""
创建一个具有以下参数的请求体:
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
- **image_url**: 被生成图片的S3或minio url地址
- **prompt**: 动作表述
示例参数:
{
"tasks_id": "12222515151123-89111",
"image_url": "aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
"prompt": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
}
"""
try:
logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
data = pose_transform_infer_cancel(tasks_id)
logger.info(f"pose_transform_cancel response @@@@@@:{data}")
logger.info(f"image_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = ComfyUIServerI2V(request_item)
background_tasks.add_task(service.get_result)
except Exception as e:
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
logger.warning(f"image_2_video Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel()
"""
首尾帧 + 文 => 视频
"""
@router.post("/comfyui_flf_2_video")
def comfyui_flf_2_video(request_item: ComfyuiFLF2VModel, background_tasks: BackgroundTasks):
    """
    First frame + last frame + text => video.

    Request body parameters:
    - **tasks_id**: task id, used to cancel the task and fetch the result
    - **start_image_url**: first frame
    - **end_image_url**: last frame
    - **prompt**: motion description

    Example:
    {
        "tasks_id": "202511051619-89111",
        "start_image_url": "test/start.png",
        "end_image_url": "test/end.png",
        "prompt": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
    }
    """
    try:
        payload = request_item.dict()
        logger.info(f"flf_2_video request item is : @@@@@@:{json.dumps(payload,indent=4)}")
        # Run generation in the background; respond immediately.
        background_tasks.add_task(ComfyUIServerFLF2V(request_item).get_result)
    except Exception as err:
        logger.warning(f"flf_2_video Run Exception @@@@@@:{err}")
        raise HTTPException(status_code=404, detail=str(err))
    return ResponseModel()
@router.get("/comfyui_i_2_video_cancel/{tasks_id}")
def comfyui_i_2_video_cancel(tasks_id: str):
    """Interrupt a running ComfyUI video-generation task.

    Posts the prompt/task id to the ComfyUI /interrupt endpoint and returns a
    message describing whether the interruption request succeeded.

    Raises:
        HTTPException: 404 with the underlying error message on any failure.
    """
    try:
        logger.info(f"comfyui_i_2_video_cancel request item is : @@@@@@:{tasks_id}")
        response = requests.post(
            f"http://{settings.COMFYUI_SERVER_ADDRESS}/interrupt",
            json={"prompt_id": tasks_id}
        )
        # BUG FIX: `data` used to be initialised as {}, so data['data'][...]
        # raised KeyError and every call ended up as a 404. Pre-create the
        # nested dict expected by the assignments and the return below.
        data = {'data': {}}
        if response.status_code == 200:
            data['data']['message'] = "任务已成功中断"
        else:
            data['data']['message'] = f"中断失败:{response.text}"
        logger.info(f"comfyui_i_2_video_cancel response @@@@@@:{data}")
    except Exception as e:
        logger.warning(f"comfyui_i_2_video_cancel Run Exception @@@@@@:{e}")
        raise HTTPException(status_code=404, detail=str(e))
    return ResponseModel(data=data['data'])

85
app/api/api_precompute.py Normal file
View File

@@ -0,0 +1,85 @@
import logging
from fastapi import APIRouter, HTTPException
from concurrent.futures import ThreadPoolExecutor
from app.schemas.response_template import ResponseModel
from app.service.recommendation_system.precompute import run_precompute
logger = logging.getLogger()
router = APIRouter()
# 使用线程池执行器来运行长时间任务
executor = ThreadPoolExecutor(max_workers=1)
# 用于跟踪任务状态
task_status = {"running": False}
def run_precompute_task():
    """Run the recommendation-system precompute job in a worker thread.

    Sets task_status["running"] for the duration of the job so the API can
    reject concurrent runs.
    """
    task_status["running"] = True
    try:
        logger.info("开始执行预计算任务...")
        run_precompute()
        logger.info("预计算任务完成")
    except Exception as e:
        logger.error(f"预计算任务失败: {e}", exc_info=True)
        # Re-raise so the failure is recorded on the executor's Future.
        raise
    finally:
        # Previously reset separately in both the try and except paths;
        # clearing the flag here guarantees it happens on every exit path.
        task_status["running"] = False
@router.post("/precompute", response_model=ResponseModel)
async def precompute():
    """
    Run the precompute job.

    Executed asynchronously in a background thread; the job covers:
    1. database table-structure optimisation
    2. historical data migration
    3. initial user-preference vector generation
    """
    # Guard clause: only one precompute job may run at a time.
    if task_status["running"]:
        raise HTTPException(
            status_code=409,
            detail="已有预计算任务正在运行,请等待完成后再试"
        )
    try:
        executor.submit(run_precompute_task)
        started_info = {
            "status": "started",
            "tasks": [
                "优化数据库表结构",
                "历史数据迁移",
                "初始用户偏好向量生成"
            ]
        }
        return ResponseModel(
            code=200,
            msg="预计算任务已启动,正在后台执行",
            data=started_info,
        )
    except Exception as e:
        logger.error(f"启动预计算任务失败: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
@router.get("/precompute/status", response_model=ResponseModel)
async def get_precompute_status():
    """Report whether a precompute task is currently running."""
    status_payload = {"running": task_status["running"]}
    return ResponseModel(code=200, msg="OK", data=status_payload)

View File

@@ -1,13 +1,10 @@
import json
import logging
import time
from fastapi import APIRouter, HTTPException
from app.schemas.prompt_generation import PromptGenerationImageModel, ImageRequest
from app.schemas.response_template import ResponseModel
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, \
get_prompt_from_image
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, get_prompt_from_image
router = APIRouter()
logger = logging.getLogger()
@@ -34,19 +31,19 @@ def prompt_generation(request_data: PromptGenerationImageModel):
raise HTTPException(status_code=404, detail=str(e))
return ResponseModel(data=data)
@router.post("/img2prompt")
def get_prompt_from_img(img: ImageRequest):
"""
自动识别图片并输出为prompt
:param img: 图片的minio地址
:return: 图片的文字描述
"""
text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
"The description should allow for the reconstruction of the corresponding line art based on the details "
"given.")
logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
description = get_prompt_from_image(img, text)
logger.info(f"生成的图片描述 response @@@@@@:{description}")
return description
# 停用
# @router.post("/img2prompt")
# def get_prompt_from_img(img: ImageRequest):
# """
# 自动识别图片并输出为prompt
#
# :param img: 图片的minio地址
# :return: 图片的文字描述
# """
# text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
# "The description should allow for the reconstruction of the corresponding line art based on the details "
# "given.")
# logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
# description = get_prompt_from_image(img, text)
# logger.info(f"生成的图片描述 response @@@@@@:{description}")
# return description

View File

@@ -26,9 +26,9 @@ def query_image(request_data: QueryImageModel):
}
"""
try:
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = query(request_data.gender, request_data.content)
logger.info(f"query_image response @@@@@@:{json.dumps(data)}")
logger.info(f"query_image response @@@@@@:{json.dumps(data, indent=4)}")
except Exception as e:
logger.warning(f"query_image Run Exception @@@@@@:{e}")
raise HTTPException(status_code=404, detail=str(e))

View File

@@ -1,204 +1,206 @@
import io
import logging
import sys
import time
from typing import List
import os
import json
import math
import random
import numpy as np
from typing import List, Optional
from fastapi import HTTPException, APIRouter, Query
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import HTTPException, APIRouter
from app.service.recommend.service import load_resources, matrix_data
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
from app.service.recommendation_system.incremental_listener import start_background_listener
from app.service.recommendation_system.milvus_client import create_collection
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
logger = logging.getLogger()
router = APIRouter()
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
# """
# :param user_id: 4
# :param category: female_skirt
# :param num_recommendations: 1
# :return:
# [
# "aida-sys-image/images/female/skirt/903000017.jpg"
# ]
# """
# try:
# start_time = time.time()
# cache_key = (user_id, category)
# # === 新增:用户存在性检查 ===
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
# user_exists_feat = user_id in matrix_data["user_index_feature"]
#
# # 任一矩阵不存在用户则返回随机推荐
# if not (user_exists_inter and user_exists_feat):
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
# return get_random_recommendations(category, num_recommendations)
#
# # 检查缓存
# if cache_key in matrix_data["cached_scores"]:
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
# else:
# # 实时计算逻辑(同原代码)
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
#
# category_iids = matrix_data["category_to_iids"].get(category, [])
# valid_sketch_idxs_inter = [
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
# if iid in category_iids
# ]
#
# # 处理交互分数
# raw_inter_scores = []
# if user_idx_inter is not None and valid_sketch_idxs_inter:
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
# processed_inter = raw_inter_scores * 0.7
#
# # 处理特征分数
# valid_sketch_idxs_feature = [
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
# if iid in category_iids
# ]
# raw_feat_scores = []
# if user_idx_feature is not None and valid_sketch_idxs_feature:
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
# processed_feat = raw_feat_scores
# else:
# processed_feat = np.array([])
#
# # 更新缓存
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
#
# # 合并分数
# if brand_id is not None:
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
#
# brand_feat_valid = (
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
# brand_idx_feature is not None and
# valid_sketch_idxs_feature # 有可用索引
# )
#
# if brand_feat_valid:
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
# brand_idx_feature, valid_sketch_idxs_feature
# ]
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
# )
# processed_brand_feat = raw_brand_feat_scores
#
# # 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
# if processed_feat.size == 0:
# processed_feat = np.zeros_like(processed_brand_feat)
#
# final_scores = processed_inter + 0.3 * (
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
# )
# else:
# # brand 信息不可用
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
# else:
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
#
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
#
# # 概率采样
# scores = np.array(final_scores)
#
# # 调整后的概率转换带温度控制的softmax
# def calibrated_softmax(scores, temperature=1.0):
# scores = scores / temperature
# scale = scores - max(scores)
# exps = np.exp(scale)
# return exps / np.sum(exps)
#
# probs = calibrated_softmax(scores, 0.09)
#
# chosen_indices = np.random.choice(
# len(valid_sketch_idxs),
# size=min(num_recommendations, len(valid_sketch_idxs)),
# p=probs,
# replace=False
# )
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
#
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
# return recommendations
# except Exception as e:
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e))
@router.on_event("startup")
async def startup_event():
# 初始加载
load_resources()
"""启动时初始化增量监听任务"""
try:
# 屏蔽 apscheduler 的 INFO 日志
logging.getLogger("apscheduler").setLevel(logging.WARNING)
# 确保 Milvus 集合已创建(若已存在则直接返回)
try:
create_collection()
except Exception as exc:
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
# 配置定时任务
scheduler = BackgroundScheduler()
scheduler.add_job(
load_resources,
trigger=CronTrigger(hour=0, minute=30),
name="每日资源刷新"
)
start_background_listener(scheduler)
scheduler.start()
logger.info("定时任务已启动")
def softmax(scores):
max_score = max(scores)
exp_scores = [math.exp(s - max_score) for s in scores]
sum_exp = sum(exp_scores)
return [s / sum_exp for s in exp_scores]
# def get_random_recommendations(category: str, num: int) -> List[str]:
# """根据预加载热度向量推荐(冷启动)"""
# try:
# heat_data = matrix_data.get("heat_data", {})
#
# if category not in heat_data:
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
#
# heat_dict = heat_data[category] # {url: score}
# urls = list(heat_dict.keys())
# scores = list(heat_dict.values())
#
# if not urls:
# raise ValueError("该类别下无热度记录,使用随机推荐")
#
# probs = softmax(scores)
# sample_size = min(num, len(urls))
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
#
# return sampled_urls
#
# except Exception as e:
# # 回退:完全随机推荐
# all_iids = list(matrix_data["iid_to_sketch"].keys())
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
# sample_size = min(num, len(category_iids))
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
def get_random_recommendations(category: str, num: int) -> List[str]:
"""全品类随机推荐"""
all_iids = list(matrix_data["iid_to_sketch"].keys())
# 优先从当前品类选择
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
# 确保不超出实际数量
sample_size = min(num, len(category_iids))
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
logger.info("增量监听定时任务已启动")
except Exception as e:
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
async def recommend(
    user_id: int,
    category: str,
    style: Optional[str] = Query(
        None,
        description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
    ),
):
    """New-style recommendation endpoint (Milvus + Redis preference vectors).

    Delegates to the recommendation-system service and returns a
    one-element list containing the top result's path ("" when the
    service returned nothing), preserving the List[str] response shape.

    Raises:
        HTTPException: 500 wrapping any error from the service call.
    """
    try:
        results = get_new_recommendations(user_id, category, style)
        # Only the top hit is exposed; an empty result maps to "" so the
        # response stays a one-element list.
        path = results[0] if results else ""
        return [path]
    except Exception as e:
        logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
@router.get("/redis/user_pref")
async def get_all_user_preferences():
"""
:param user_id: 4
:param category: female_skirt
:param num_recommendations: 1
:return:
[
"aida-sys-image/images/female/skirt/903000017.jpg"
]
获取所有以 user_pref 为前缀的 Redis key 信息
"""
try:
start_time = time.time()
cache_key = (user_id, category)
# === 新增:用户存在性检查 ===
user_exists_inter = user_id in matrix_data["user_index_interaction"]
user_exists_feat = user_id in matrix_data["user_index_feature"]
from app.service.utils.redis_utils import Redis
from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX
# 任一矩阵不存在用户则返回随机推荐
if not (user_exists_inter and user_exists_feat):
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
return get_random_recommendations(category, num_recommendations)
# 扫描所有匹配 user_pref:* 的 key
pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*"
keys = Redis.scan_keys(pattern)
# 检查缓存
if cache_key in matrix_data["cached_scores"]:
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
else:
# 实时计算逻辑(同原代码)
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
# 直接返回所有 key 和原始 value
result = {}
for key in keys:
# 读取对应的值
value = Redis.read(key)
if value:
result[key] = value
category_iids = matrix_data["category_to_iids"].get(category, [])
valid_sketch_idxs_inter = [
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
if iid in category_iids
]
# 处理交互分数
raw_inter_scores = []
if user_idx_inter is not None and valid_sketch_idxs_inter:
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
processed_inter = raw_inter_scores * 0.7
# 处理特征分数
valid_sketch_idxs_feature = [
idx for iid, idx in matrix_data["sketch_index_feature"].items()
if iid in category_iids
]
raw_feat_scores = []
if user_idx_feature is not None and valid_sketch_idxs_feature:
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
processed_feat = raw_feat_scores
else:
processed_feat = np.array([])
# 更新缓存
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
# 合并分数
if brand_id is not None:
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
brand_feat_valid = (
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
brand_idx_feature is not None and
valid_sketch_idxs_feature # 有可用索引
)
if brand_feat_valid:
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
brand_idx_feature, valid_sketch_idxs_feature
]
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
)
processed_brand_feat = raw_brand_feat_scores
# 如果 processed_feat 是空的,替换为全 0避免 shape 不一致
if processed_feat.size == 0:
processed_feat = np.zeros_like(processed_brand_feat)
final_scores = processed_inter + 0.3 * (
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
)
else:
# brand 信息不可用
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
else:
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
# 概率采样
scores = np.array(final_scores)
# 调整后的概率转换带温度控制的softmax
def calibrated_softmax(scores, temperature=1.0):
scores = scores / temperature
scale = scores - max(scores)
exps = np.exp(scale)
return exps / np.sum(exps)
probs = calibrated_softmax(scores, 0.09)
chosen_indices = np.random.choice(
len(valid_sketch_idxs),
size=min(num_recommendations, len(valid_sketch_idxs)),
p=probs,
replace=False
)
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}")
return recommendations
return result
except Exception as e:
logger.error(f"推荐失败: {str(e)}", exc_info=True)
logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,38 +1,42 @@
from fastapi import APIRouter
from app.api import api_attribute_retrieve, api_query_image
from app.api import api_attribute_retrieve
from app.api import api_brand_dna
from app.api import api_brighten
from app.api import api_chat_robot
from app.api import api_clothing_seg
from app.api import api_design
from app.api import api_design_pre_processing
from app.api import api_extraction_project_info
from app.api import api_generate_image
from app.api import api_image2sketch
from app.api import api_mannequins_edit
from app.api import api_pose_transform
from app.api import api_precompute
from app.api import api_prompt_generation
from app.api import api_recommendation
from app.api import api_super_resolution
from app.api import api_test
router = APIRouter()
router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
router.include_router(api_design.router, tags=['design'], prefix="/api")
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
"""停用"""
# from app.api import api_chat_robot
# from app.api import api_query_image
# from app.api import api_brighten
# from app.api import api_extraction_project_info
# from app.api import api_image2sketch
# from app.api import api_super_resolution
# router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
# router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
# router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
# router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
# router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
# router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")

View File

@@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg
}
"""
try:
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
service = SuperResolution(request_item)
background_tasks.add_task(service.sr_result)
except Exception as e:

View File

@@ -4,8 +4,7 @@ import logging
from fastapi import APIRouter
from fastapi import HTTPException
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, \
BATCH_PS_RABBITMQ_QUEUES, RABBITMQ_ENV
from app.core.config import settings, SR_RABBITMQ_QUEUES, GMV_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, BATCH_PS_RABBITMQ_QUEUES
from app.schemas.response_template import ResponseModel
logger = logging.getLogger()
@@ -15,9 +14,9 @@ router = APIRouter()
@router.get("{id}")
def test(id: int):
data = {
"RABBITMQ_ENV":RABBITMQ_ENV,
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
"RABBITMQ_ENV": settings.SERVE_ENV,
# "超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
# "多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
"pose transform PS_RABBITMQ_QUEUES": PS_RABBITMQ_QUEUES,
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
@@ -29,10 +28,9 @@ def test(id: int):
"batch relight BATCH_GRI_RABBITMQ_QUEUES": BATCH_GRI_RABBITMQ_QUEUES,
"batch pose transform BATCH_PS_RABBITMQ_QUEUES": BATCH_PS_RABBITMQ_QUEUES,
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
"local_oss_server": OSS
"JAVA_STREAM_API_URL": settings.JAVA_STREAM_API_URL,
}
logger.info(json.dumps(data))
logger.info(json.dumps(data, ensure_ascii=False, indent=4))
if id == 1:
raise HTTPException(status_code=404, detail="Item not found")

235
app/core/config.backup.py Normal file
View File

@@ -0,0 +1,235 @@
import os
import pika
from dotenv import load_dotenv
from pydantic import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings):
PROJECT_NAME: str = 'FASTAPI BASE'
SECRET_KEY: str = ''
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
CHROMADB_PATH = "./chromadb/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
CHROMADB_PATH = "/chromadb/"
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
# RABBITMQ_ENV = "-local" # 本地测试环境
if RABBITMQ_ENV == "-dev":
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
elif RABBITMQ_ENV == "-prod":
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
settings = Settings()
# minio 配置
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 配置
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
REDIS_DB = "2"
# rabbitmq config
RABBITMQ_PARAMS = {
"host": "18.167.251.121",
"port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/"
}
# milvus 配置
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置
DB_HOST = '18.167.251.121' # 数据库主机地址
# DB_PORT = int( 33006)
DB_PORT = 33008 # 数据库端口
DB_USERNAME = 'aida_con_python' # 数据库用户名
DB_PASSWORD = '123456' # 数据库密码
DB_NAME = 'aida' # 数据库库名
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# SR service config
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
# GenerateImage service config
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
# Generate Product service config
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
# GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}
# ollama config
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
# design batch
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN 预处理
IF_DEBUG_SHOW = False
# 优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
'hairstyle_front': 97,
'outwear_front': 20,
'tops_front': 19,
'dress_front': 18,
'blouse_front': 17,
'skirt_front': 16,
'trousers_front': 15,
'bottoms_front': 14,
'shoes_right': 1,
'shoes_left': 1,
'body': 0,
'bottoms_back': -14,
'trousers_back': -15,
'skirt_back': -16,
'blouse_back': -17,
'dress_back': -18,
'tops_back': -19,
'outwear_back': -20,
'hairstyle_back': -97,
'bag_back': -98,
'earring_back': -99,
}
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址

View File

@@ -1,188 +1,91 @@
import os
import pika
from dotenv import load_dotenv
from pydantic import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
load_dotenv(os.path.join(BASE_DIR, '.env'))
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
PROJECT_NAME: str = 'FASTAPI BASE'
SECRET_KEY: str = ''
API_PREFIX: str = ''
BACKEND_CORS_ORIGINS: list[str] = ['*']
DATABASE_URL: str = ''
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM: str = 'HS256'
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
"""
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
"""
model_config = SettingsConfigDict(
env_file='.env',
env_file_encoding='utf-8',
# extra='ignore' # 忽略环境变量中多余的键
)
# --- 服务端口配置信息 ---
PORT: int = Field(default=8001, description="")
# --- 服务环境 配置信息 ---
SERVE_ENV: str = Field(default='', description="")
# --- 开发状态 配置信息 ---
DEBUG: bool = Field(default=False, description="")
# --- 千问api 配置信息 ---
QWEN_API_KEY: str = Field(default="", description="")
# --- ComfyUI 配置信息 ---
COMFYUI_SERVER_ADDRESS: str = Field(default='', description="")
OSS = "minio"
DEBUG = False
if DEBUG:
LOGS_PATH = "logs/"
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "../seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
RECOMMEND_PATH_PREFIX = "service/recommend/"
CHROMADB_PATH = "./chromadb/"
else:
LOGS_PATH = "app/logs/"
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
SEG_CACHE_PATH = "/seg_cache/"
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
CHROMADB_PATH = "/chromadb/"
# --- minio 配置信息 ---
MINIO_URL: str = Field(default='', description="")
MINIO_ACCESS: str = Field(default='', description="")
MINIO_SECRET: str = Field(default='', description="")
MINIO_SECURE: bool = Field(default=True, description="")
# RABBITMQ_ENV = "" # 生产环境
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
# RABBITMQ_ENV = "-local" # 本地测试环境
# --- redis 配置信息 ---
REDIS_HOST: str = Field(default='', description="")
REDIS_PORT: str = Field(default='', description="")
REDIS_DB: int = Field(default=0, description="")
# --- mysql 配置信息 ---
MYSQL_HOST: str = Field(default='', description="")
MYSQL_PORT: int = Field(default='', description="")
MYSQL_USER: str = Field(default='', description="")
MYSQL_PASSWORD: str = Field(default='', description="")
MYSQL_DB: str = Field(default='', description="")
MYSQL_CHARSET: str = Field(default='utf8mb4', description="")
# --- rabbit-mq 配置信息 ---
MQ_HOST: str = Field(default='', description="")
MQ_PORT: str = Field(default='', description="")
MQ_USERNAME: str = Field(default='', description="")
MQ_PASSWORD: str = Field(default='', description="")
MQ_VIRTUAL_HOST: str = Field(default='/', description="")
MQ_ENV: str = Field(default='', description="")
# --- milvus 配置信息 ---
MILVUS_URL: str = Field(default='', description="")
MILVUS_TOKEN: str = Field(default='', description="")
MILVUS_ALIAS: str = Field(default='', description="")
# --- ollama 配置信息 ---
CHROMADB_PATH: str = Field(default='', description="")
# --- ollama 配置信息 ---
OLLAMA_URL: str = Field(default='', description="")
# --- Design Callback Java 接口 ---
JAVA_STREAM_API_URL: str = Field(default='', description="")
# --- 其他配置信息 以下均为Docker容器内配置---
LOGS_PATH: str = Field(default="/logs/", description="")
CATEGORY_PATH: str = Field(default="/app/service/attribute/config/descriptor/category/category_dis.csv", description="")
SEG_CACHE_PATH: str = Field(default="/seg_cache/", description="")
RECOMMEND_PATH_PREFIX: str = Field(default="/app/service/recommend/", description="")
if RABBITMQ_ENV == "-dev":
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
elif RABBITMQ_ENV == "-prod":
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
settings = Settings()
# minio 配置
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# S3 配置
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
S3_AWS_SECRET_ACCESS_KEY = "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"
S3_REGION_NAME = "ap-east-1"
# redis 配置
REDIS_HOST = "10.1.1.240"
REDIS_PORT = "6379"
REDIS_DB = "2"
# rabbitmq config
RABBITMQ_PARAMS = {
"host": "18.167.251.121",
"port": 5672,
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
"virtual_host": "/"
"""Design 服务"""
# 推荐服装类别映射
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# milvus 配置
MILVUS_URL = "http://10.1.1.240:19530"
MILVUS_TOKEN = "root:Milvus"
MILVUS_ALIAS = "default"
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
MILVUS_TABLE_SEG = "seg_cache"
# Mysql 配置
DB_HOST = '18.167.251.121' # 数据库主机地址
# DB_PORT = int( 33006)
DB_PORT = 33008 # 数据库端口
DB_USERNAME = 'aida_con_python' # 数据库用户名
DB_PASSWORD = '123456' # 数据库密码
DB_NAME = 'aida' # 数据库库名
# openai
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
OPENAI_STREAM = True
BUFFER_THRESHOLD = 6 # must be even number
SINGLE_TOKEN_THRESHOLD = 200
TOKEN_THRESHOLD = 600
OPENAI_TEMPERATURE = 0
# OPENAI_API_KEY = "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"
OPENAI_API_KEY = "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"
OPENAI_MODEL = "gpt-3.5-turbo-0613"
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-0613",
"gpt-4-32k-0613", }
# SR service config
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
SR_MINIO_BUCKET = "aida-users"
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
# GenerateImage service config
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
GI_MINIO_BUCKET = "aida-users"
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
# SLOGAN service config
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
# Generate Single Logo service config
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MINIO_BUCKET = "aida-users"
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
# Generate Product service config
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
# GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Product service config 旧版product img 模型
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
GPI_MODEL_URL = '10.1.1.243:10051'
# Generate Single Logo service config
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
GRI_MODEL_URL = '10.1.1.240:10051'
# Pose Transform service config
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
PT_MODEL_URL = '10.1.1.243:10061'
# SEG service config
SEGMENTATION = {
"new_model_name": "seg_knet",
"name": "seg_ocrnet_hr18",
"input": "seg_input__0",
"output": "seg_output__0",
}
# ollama config
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
# design batch
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
# DESIGN config
DESIGN_MODEL_URL = '10.1.1.240:10000'
AIDA_CLOTHING = "aida-clothing"
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# DESIGN 预处理
IF_DEBUG_SHOW = False
# 优先级
# Design前后排优先级
PRIORITY_DICT = {
'earring_front': 99,
'bag_front': 98,
@@ -208,25 +111,71 @@ PRIORITY_DICT = {
'bag_back': -98,
'earring_back': -99,
}
# Design 关键点字段
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
# milvus配置信息
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
# ollama 地址
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
DB_CONFIG = {
"host": "18.167.251.121",
"port": 3306,
"user": "root",
"password": "QWa998345",
"database": "aida",
"charset": "utf8mb4"
}
"""Triton Server Config"""
# Design
DESIGN_MODEL_URL = '10.1.1.240:10000'
DESIGN_MODEL_NAME = 'seg_knet'
# Generate Image
GI_MODEL_URL = '10.1.1.240:10061'
GI_MODEL_NAME = 'flux'
# Generate Single Logo
GSL_MODEL_URL = '10.1.1.243:10041'
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
# Generate Product (整套和单品)
GPI_MODEL_URL = '10.1.1.243:10051'
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
TABLE_CATEGORIES = {
"female_dress": "female/dress",
"female_outwear": "female/outwear",
"female_trousers": "female/trousers",
"female_skirt": "female/skirt",
"female_blouse": "female/blouse",
"male_tops": "male/tops",
"male_bottoms": "male/bottoms",
"male_outwear": "male/outwear"
}
# 以下停用中...*************
# 多视角生成
GMV_MODEL_URL = '10.1.1.243:10081'
GMV_MODEL_NAME = 'multi_view'
# 超分
SR_MODEL_NAME = "super_resolution"
SR_TRITON_URL = "10.1.1.240:10031"
# 打光
GRI_MODEL_URL = '10.1.1.240:10051'
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
# agent 图片生成
FAST_GI_MODEL_URL = '10.1.1.243:10011'
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
# 图转视频 triton版
PT_MODEL_URL = '10.1.1.243:10061'
# *************
"""MQ 队列信息"""
# 生成图片 moodboard printboard sketchboard
GI_RABBITMQ_QUEUES = f"GenerateImage-{settings.SERVE_ENV}"
# 生成slogan
SLOGAN_RABBITMQ_QUEUES = f"Slogan-{settings.SERVE_ENV}"
# 转产品图
GPI_RABBITMQ_QUEUES = f"ToProductImage-{settings.SERVE_ENV}"
# 产品图转视频
PS_RABBITMQ_QUEUES = f"PoseTransform-{settings.SERVE_ENV}"
# 以下停用中...*************
# 产品图打光
GRI_RABBITMQ_QUEUES = f"Relight-{settings.SERVE_ENV}"
# 超分
SR_RABBITMQ_QUEUES = f"SuperResolution-{settings.SERVE_ENV}"
# 生成多视图
GMV_RABBITMQ_QUEUES = f"GenerateMultiView-{settings.SERVE_ENV}"
# 批量转产品图
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage-{settings.SERVE_ENV}"
# 批量打光
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight-{settings.SERVE_ENV}"
# 批量图片转视频
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform-{settings.SERVE_ENV}"
# 批量design
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch-{settings.SERVE_ENV}"
# *************

10
app/core/mysql_config.py Normal file
View File

@@ -0,0 +1,10 @@
from app.core.config import settings
DB_CONFIG = {
"host": settings.MYSQL_HOST,
"port": settings.MYSQL_PORT,
"user": settings.MYSQL_USER,
"password": settings.MYSQL_PASSWORD,
"database": settings.MYSQL_DB,
"charset": settings.MYSQL_CHARSET,
}

View File

@@ -0,0 +1,10 @@
# rabbitmq config
import pika
from app.core.config import settings
RABBITMQ_PARAMS = {
"host": settings.MQ_HOST,
"port": settings.MQ_PORT,
"credentials": pika.credentials.PlainCredentials(username=settings.MQ_USERNAME, password=settings.MQ_PASSWORD),
"virtual_host": settings.MQ_VIRTUAL_HOST,
}

View File

@@ -79,12 +79,8 @@
}
]
}
],
"process_id": "87",
"tasks_id": ,
"tasks_id": ""
}
//用 openai jsonl
//

View File

@@ -1,31 +1,40 @@
# 1. 这里的顺序至关重要!必须在最顶端
import sys
try:
import asyncore
except ImportError:
import pyasyncore
sys.modules['asyncore'] = pyasyncore
import logging.config
import uvicorn
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.cron import CronTrigger
from fastapi import FastAPI
from fastapi import HTTPException, Request
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from app.api.api_route import router
from app.core.config import settings
from app.core.record_api_count import count_api_calls
from app.schemas.response_template import ResponseModel
from app.service.recommend.service import load_resources
from logging_env import LOGGER_CONFIG_DICT
from dotenv import load_dotenv
from starlette.middleware.cors import CORSMiddleware
logging.config.dictConfig(LOGGER_CONFIG_DICT)
logging.getLogger("pika").setLevel(logging.WARNING)
from starlette.middleware.cors import CORSMiddleware
logger = logging.getLogger(__name__)
load_dotenv()
def get_application() -> FastAPI:
application = FastAPI(
title=settings.PROJECT_NAME, docs_url="/docs", redoc_url='/re-docs',
openapi_url=f"{settings.API_PREFIX}/openapi.json",
docs_url="/docs",
redoc_url='/re-docs',
openapi_url=f"/openapi.json",
description='''
Base frame with FastAPI
- Super Resolution API
@@ -34,13 +43,13 @@ def get_application() -> FastAPI:
)
application.add_middleware(
CORSMiddleware,
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
application.middleware("http")(count_api_calls)
application.include_router(router=router, prefix=settings.API_PREFIX)
application.include_router(router=router)
return application
@@ -48,14 +57,12 @@ app = get_application()
@app.exception_handler(HTTPException)
async def http_exception_handler(request: Request, exc: HTTPException):
async def http_exception_handler(exc: HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
)
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)
uvicorn.run(app, host="0.0.0.0", port=settings.PORT)

View File

@@ -0,0 +1,23 @@
from pydantic import BaseModel
class ComfyuiPose2VModel(BaseModel):
# 骨架生成视频
image_url: str
tasks_id: str
pose_id: str
class ComfyuiI2VModel(BaseModel):
# 图生视频
image_url: str
prompt: str
tasks_id: str
class ComfyuiFLF2VModel(BaseModel):
# 首尾帧生视频
start_image_url: str
end_image_url: str
prompt: str
tasks_id: str

View File

@@ -1,4 +1,15 @@
from pydantic import BaseModel
from typing import List, Optional
from pydantic import BaseModel, Field
class SAMRequestModel(BaseModel):
user_id: int = Field(..., description="用户id, 必填字段")
image_path: str = Field(..., description="图片路径,必填字段")
type: str = Field(..., description="推理类型,必填字段")
points: Optional[List[List[float]]] = None
labels: Optional[List[int]] = None
box: Optional[List[int]] = None
class DesignModel(BaseModel):
@@ -10,6 +21,7 @@ class DesignStreamModel(BaseModel):
objects: list[dict]
process_id: str
requestId: str
callback_url: str
class DesignProgressModel(BaseModel):

View File

@@ -1,22 +1,24 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import logging
from pprint import pprint
import torch
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import torch
import tritonclient.http as httpclient
from app.core.config import *
from minio import Minio
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import AttributeRecognitionModel
from app.service.utils.oss_client import oss_get_image
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
class AttributeRecognition:
def __init__(self, const, request_data):
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.request_data = []
for i, sketch in enumerate(request_data):
self.request_data.append(
@@ -96,11 +98,12 @@ class AttributeRecognition:
res = {**dict1, **dict2}
return res
def get_image(self, url):
@staticmethod
def get_image(url):
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) #
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img

View File

@@ -7,24 +7,25 @@
@Date 2023/9/16 18:31:08
@detail
"""
from minio import Minio
from skimage import transform
import cv2
import mmcv
import numpy as np
import pandas as pd
from minio import Minio
import tritonclient.http as httpclient
import torch
from app.core.config import *
from app.core.config import settings, DESIGN_MODEL_URL
from app.schemas.attribute_retrieve import CategoryRecognitionModel
from app.service.utils.oss_client import oss_get_image
from app.service.utils.new_oss_client import oss_get_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
class CategoryRecognition:
def __init__(self, request_data):
self.attr_type = pd.read_csv(CATEGORY_PATH)
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
self.request_data = []
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
for sketch in request_data:
@@ -46,13 +47,14 @@ class CategoryRecognition:
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
return preprocessed_img
def get_image(self, url):
@staticmethod
def get_image(url):
# Get data of an object.
# Read data from response.
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
return img
@@ -68,7 +70,7 @@ class CategoryRecognition:
colattr = list(self.attr_type['labelName'])
task = self.attr_type['taskName'][0]
# self.attr_type['taskName'][0]
maxsc = np.max(scores[0][:5])
indexs = np.argwhere(scores == maxsc)[:, 1]

View File

@@ -9,15 +9,16 @@ import torch.nn.functional as F
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
from app.core.config import DESIGN_MODEL_URL
from app.core.config import settings
from app.schemas.brand_dna import BrandDnaModel
from app.service.attribute.config import local_debug_const, const
from app.service.attribute.config import const
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
logger = logging.getLogger()
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
logger = logging.getLogger()
class BrandDna:
@@ -25,7 +26,7 @@ class BrandDna:
self.sketch_bucket = "test"
self.image_url = request_item.image_url
self.is_brand_dna = request_item.is_brand_dna
self.attr_type = pd.read_csv(CATEGORY_PATH)
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')

View File

@@ -3,23 +3,25 @@ import logging
import cv2
import numpy as np
import tritonclient.grpc as grpcclient
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import PromptTemplate
# from langchain_openai import ChatOpenAI
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
from app.schemas.brand_dna import GenerateBrandModel
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_upload_image
from app.core.config import settings
class GenerateBrandInfo:
def __init__(self, request_data):
# minio client init
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
self.generate_logo_prompt = None
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
# user info init
self.user_id = request_data.user_id
@@ -55,7 +57,7 @@ class GenerateBrandInfo:
return self.result_data
def llm_generate_brand_info(self):
output = self.model(self._input.to_messages())
output = self.model.invoke(self._input.to_messages())
brand_data = self.output_parser.parse(output.content)
self.result_data = brand_data
self.generate_logo_prompt = brand_data['brand_logo_prompt']
@@ -87,8 +89,8 @@ class GenerateBrandInfo:
def upload_logo_image(self, image, object_name):
try:
_, img_byte_array = cv2.imencode('.jpg', image)
object_name = f'{self.user_id}/{self.category}/{object_name}'
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
image_url = f"aida-users/{object_name}"
return image_url
except Exception as e:

View File

@@ -1,32 +0,0 @@
from dotenv import load_dotenv
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
# 加载.env文件的环境变量
load_dotenv()
# 创建一个大语言模型model指定了大语言模型的种类
model = ChatOpenAI(model="qwen2.5-14b-instruct")
# 想要接收的响应模式
response_schemas = [
ResponseSchema(name="brand_name", description="Brand name."),
ResponseSchema(name="brand_slogan", description="Brand slogan."),
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
template="你是一个时装品牌的设计师。根据用户输入提取出brand namebrand slogan,brand logo 描述。如果没有以上内容需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt这个prompt用于生成模型.\n{format_instructions}\n{question}",
input_variables=["question"],
partial_variables={"format_instructions": format_instructions}
)
_input = prompt.format_prompt(question="brand name: cat home")
output = model(_input.to_messages())
brand_data = output_parser.parse(output.content)
def generate_logo(bucket_name, object_name, prompt):
pass

View File

@@ -3,27 +3,20 @@ import json
import logging
from typing import Any, Dict, List, Optional, Union, Tuple
from langchain.agents import AgentExecutor
from langchain.callbacks.manager import Callbacks, CallbackManager
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, RunInfo
from langchain_classic.agents import AgentExecutor
from langchain_classic.schema import RUN_KEY
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import Callbacks, CallbackManager
from langchain_core.load import dumpd
from langchain_core.outputs import RunInfo
class CustomAgentExecutor(AgentExecutor):
def __call__(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
session_key: str = "",
*,
tags: Optional[List[str]] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
def __call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, session_key: str = "", *, tags: Optional[List[str]] = None, include_run_info: bool = False, **kwargs) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
Args:
**kwargs:
inputs: Dictionary of inputs, or single input if chain expects
only one param.
return_only_outputs: boolean for whether to return only outputs in the
@@ -72,7 +65,7 @@ class CustomAgentExecutor(AgentExecutor):
"""Validate and prep outputs."""
self._validate_outputs(outputs)
if self.memory is not None and outputs['need_record']:
self.memory.save_context(inputs, outputs, session_key)
self.memory.save_context(inputs, outputs)
if return_only_outputs:
return outputs
else:
@@ -95,7 +88,7 @@ class CustomAgentExecutor(AgentExecutor):
)
inputs = {list(_input_keys)[0]: inputs}
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs, session_key)
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
return inputs
@@ -119,7 +112,8 @@ class CustomAgentExecutor(AgentExecutor):
{return_value_key: observation},
"",
)
except:
except Exception as e:
print(e)
pass
# Invalid tools won't be in the map, so we return False.

View File

@@ -1,26 +1,15 @@
import json
import re
from dataclasses import dataclass
from json import JSONDecodeError
from typing import List, Tuple, Any, Union
from dataclasses import dataclass
from langchain.callbacks.manager import Callbacks
from langchain.agents import (
OpenAIFunctionsAgent,
)
from langchain.schema import (
AgentAction,
AgentFinish,
BaseMessage,
OutputParserException
)
from langchain.schema.messages import (
AIMessage,
FunctionMessage
)
from langchain.tools import BaseTool, StructuredTool
# from langchain.tools.convert_to_openai import FunctionDescription
from langchain.utils.openai_functions import FunctionDescription
from langchain_classic.agents import OpenAIFunctionsAgent
from langchain_community.utils.ernie_functions import FunctionDescription
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks import Callbacks
from langchain_core.exceptions import OutputParserException
from langchain_core.messages import BaseMessage, AIMessage, FunctionMessage
from langchain_core.tools import BaseTool
@dataclass
@@ -76,7 +65,6 @@ def _create_function_message(
content = observation
return FunctionMessage(
name=agent_action.tool,
content=content,
)
@@ -177,6 +165,7 @@ class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
into it.
Args:
callbacks:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
**kwargs: Including user's input string

View File

@@ -2,18 +2,16 @@
from typing import Any, Dict
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain.schema import LLMResult
from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
get_openai_token_cost_for_model
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
from langchain_core.outputs import LLMResult
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
need_record: bool = True
response_type: str = "string"
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
"""Collect token usage."""
if response.llm_output is None:
@@ -39,6 +37,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
self.total_tokens += token_usage.get("total_tokens", 0)
self.prompt_tokens += prompt_tokens
self.completion_tokens += completion_tokens
return None
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
"""Write token usage to redis."""

View File

@@ -44,12 +44,17 @@ class CustomDatabase(SQLDatabase):
final_str = "\n\n".join(tables)
return final_str
def run(self, command: str, fetch: str = "all") -> str:
def run(self, command: str, fetch: str = "all", **kwargs) -> str:
"""Execute a SQL command and return a string representing the results.
If the statement returns rows, a string of the results is returned.
If the statement returns no rows, an empty string is returned.
Args:
command:
fetch:
**kwargs:
"""
with self._engine.begin() as connection:
if self._schema is not None:

View File

@@ -1,15 +1,15 @@
import json
import logging
from langchain.agents import Tool
from langchain.callbacks import FileCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain.schema import SystemMessage, AIMessage
from langchain.utilities import SerpAPIWrapper
from langchain_community.utilities import SerpAPIWrapper
from langchain_core.callbacks import FileCallbackHandler
from langchain_core.messages import SystemMessage, AIMessage
from langchain_core.prompts import MessagesPlaceholder, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain_core.tools import Tool
from langchain_community.chat_models import ChatTongyi
from loguru import logger
from app.core.config import *
from app.core.config import settings
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
@@ -30,10 +30,10 @@ log_handler = FileCallbackHandler(logfile)
# # callbacks=[OpenAICallbackHandler()]
# )
llm = ChatTongyi(api_key=QWEN_API_KEY)
llm = ChatTongyi(api_key=settings.QWEN_API_KEY)
search = SerpAPIWrapper()
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.DB_USERNAME}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
@@ -43,11 +43,11 @@ tools = [
description="Can be used to perform Internet searches",
func=search.run
),
QuerySQLDataBaseTool(db=db, return_direct=False),
QuerySQLDataBaseTool(db=db),
InfoSQLDatabaseTool(db=db),
ListSQLDatabaseTool(db=db),
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
QuerySQLCheckerTool(db=db, llm=ChatTongyi(temperature=0, api_key=QWEN_API_KEY)),
QuerySQLCheckerTool(db=db, llm=ChatTongyi(api_key=settings.QWEN_API_KEY)),
# Tool(
# name="tutorial_tool",
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
@@ -133,5 +133,5 @@ def chat(post_data):
'completion_tokens': final_outputs['completion_tokens'],
'response_type': final_outputs["response_type"]
}
logging.info(json.dumps(api_response))
logging.info(json.dumps(api_response, indent=4))
return api_response

View File

@@ -3,13 +3,12 @@ from typing import Any, Dict, List, Tuple
import json
import redis
from langchain_classic.memory.chat_memory import BaseChatMemory
from langchain_classic.memory.utils import get_prompt_input_key
from langchain_core.messages import messages_from_dict, get_buffer_string, BaseMessage, HumanMessage, AIMessage, message_to_dict
from redis import Redis
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
from langchain.schema.messages import _message_to_dict, messages_from_dict
from langchain.memory.utils import get_prompt_input_key
from app.core.config import *
from app.core.config import settings
class UserConversationBufferWindowMemory(BaseChatMemory):
@@ -24,8 +23,8 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
@classmethod
def from_redis(
cls,
host: str = REDIS_HOST,
port: int = REDIS_PORT,
host: str = settings.REDIS_HOST,
port: int = settings.REDIS_PORT,
db: int = 3,
**kwargs
):
@@ -79,7 +78,7 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
return inputs[prompt_input_key], outputs[output_key]
def add_message(self, key: str, message: BaseMessage) -> None:
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
self.redis_client.lpush(key, json.dumps(message_to_dict(message)))
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
"""Save context from this conversation to buffer."""

View File

@@ -5,10 +5,10 @@ from dashscope import Generation
from retry import retry
from urllib3.exceptions import NewConnectionError
from app.core.config import *
from app.core.config import settings
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
from app.service.chat_robot.script.database import CustomDatabase
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
from app.service.chat_robot.script.prompt import TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
from app.service.search_image_with_text.service import query
@@ -149,7 +149,7 @@ tools = [
}
]
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/attribute_retrieval_V3',
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
engine_args={"pool_recycle": 7200})
@@ -159,7 +159,7 @@ qwen = QWenCallbackHandler()
def search_from_internet(message):
response = Generation.call(
model='qwen-turbo',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=message,
prompt='The output must be in English.Keep the final result under 200 words.'
# tools=tools,
@@ -190,7 +190,7 @@ def get_image_from_vector_db(gender, content):
def get_response(messages):
response = Generation.call(
model='qwen-max',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=messages,
tools=tools,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
@@ -203,7 +203,7 @@ def get_response(messages):
def get_assistant_response(messages):
response = Generation.call(
model='qwen-max',
api_key=QWEN_API_KEY,
api_key=settings.QWEN_API_KEY,
messages=messages,
# seed=random.randint(1, 10000), # 设置随机数种子seed如果没有设置则随机数种子默认为1234
result_format='message', # 将输出设置为message形式
@@ -212,8 +212,10 @@ def get_assistant_response(messages):
return response
def call_with_messages(message):
global tool_info
def call_with_messages(message):
user_input = message
print('\n')
@@ -241,7 +243,7 @@ def call_with_messages(message):
response_type = "chat"
while flag and count <= 3:
first_response = get_response(messages)
first_response = get_response
assistant_output = first_response.output.choices[0].message
QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
@@ -260,7 +262,7 @@ def call_with_messages(message):
]
tool_info['content'] = search_from_internet(message)
flag = False
result_content = tool_info['content'].output.text
result_content = tool_info['content']
# 如果模型选择的工具是get_database_table
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}

View File

@@ -2,21 +2,15 @@
"""Tools for interacting with a SQL database."""
from typing import Any, Dict, Optional, Type
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import _QuerySQLCheckerToolInput
# from langchain.sql_database import SQLDatabase
from langchain_community.utilities import SQLDatabase
from langchain.tools.base import BaseTool
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts import PromptTemplate
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Extra, Field, root_validator
class BaseSQLDatabaseTool(BaseModel):
@@ -183,7 +177,7 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
@root_validator(pre=True)
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
def initialize_llm_chain(self, values: Dict[str, Any]) -> Dict[str, Any]:
if "llm_chain" not in values:
# from langchain.chains.llm import LLMChain

View File

@@ -1,6 +1,6 @@
from typing import Any
from langchain.tools.base import BaseTool
from langchain_core.tools import BaseTool
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN

View File

@@ -9,14 +9,14 @@ from PIL import Image
from minio import Minio
from tritonclient.utils import np_to_triton_dtype
from app.core.config import *
from app.core.config import settings
from app.schemas.clothing_seg import ClothingSegModel
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
class ClothingSeg:
@@ -64,9 +64,9 @@ class ClothingSeg:
if image_type == "sketch":
if len(image.shape) == 2:
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
seg_mask = get_seg_result(1, image[:, :, :3])
seg_mask = get_seg_result(image[:, :, :3])
else:
seg_mask = get_seg_result(1, image[:, :, :3])
seg_mask = get_seg_result(image[:, :, :3])
temp = seg_mask != 0.0
mask = (255 * (temp + 0).astype(np.uint8))
x_min, y_min, x_max, y_max = get_bounding_box(mask)

View File

@@ -0,0 +1,646 @@
import io
import json
import logging
import os
import random
import tempfile
import time
import uuid
import requests
from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import PS_RABBITMQ_QUEUES
from app.core.config import settings
from app.schemas.comfyui_i2v import ComfyuiFLF2VModel
from app.service.generate_image.utils.mq import publish_status
logger = logging.getLogger()
# 首尾帧 + 文字 = 视频 工作流
workflow_json = {
"6": {
"inputs": {
"text": "A bearded man with red facial hair wearing a yellow straw hat and dark coat in Van Gogh's self-portrait style, slowly and continuously transforms into a space astronaut. The transformation flows like liquid paint - his beard fades away strand by strand, the yellow hat melts and reforms smoothly into a silver space helmet, dark coat gradually lightens and restructures into a white spacesuit. The background swirling brushstrokes slowly organize and clarify into realistic stars and space, with Earth appearing gradually in the distance. Every change happens in seamless waves, maintaining visual continuity throughout the metamorphosis.\n\nConsistent soft lighting throughout, medium close-up maintaining same framing, central composition stays fixed, gentle color temperature shift from warm to cool, gradual contrast increase, smooth style transition from painterly to photorealistic. Static camera with subtle slow zoom, emphasizing the flowing transformation process without abrupt changes.",
"clip": [
"38",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"7": {
"inputs": {
"text": "色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
"clip": [
"38",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"58",
0
],
"vae": [
"39",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE解码"
}
},
"37": {
"inputs": {
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "UNet加载器"
}
},
"38": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "加载CLIP"
}
},
"39": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "加载VAE"
}
},
"54": {
"inputs": {
"shift": 5,
"model": [
"91",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "采样算法SD3"
}
},
"55": {
"inputs": {
"shift": 5,
"model": [
"92",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "采样算法SD3"
}
},
"56": {
"inputs": {
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "UNet加载器"
}
},
"57": {
"inputs": {
"add_noise": "enable",
"noise_seed": 984937593540091,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 2,
"return_with_leftover_noise": "enable",
"model": [
"54",
0
],
"positive": [
"67",
0
],
"negative": [
"67",
1
],
"latent_image": [
"67",
2
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "K采样器高级"
}
},
"58": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 2,
"end_at_step": 10000,
"return_with_leftover_noise": "disable",
"model": [
"55",
0
],
"positive": [
"67",
0
],
"negative": [
"67",
1
],
"latent_image": [
"57",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "K采样器高级"
}
},
"60": {
"inputs": {
"fps": 16,
"images": [
"8",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "创建视频"
}
},
"61": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video": [
"60",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "保存视频"
}
},
"62": {
"inputs": {
"image": "video_wan2_2_14B_flf2v_start_image.png"
},
"class_type": "LoadImage",
"_meta": {
"title": "加载end图像"
}
},
"67": {
"inputs": {
"width": 640,
"height": 640,
"length": 81,
"batch_size": 1,
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"vae": [
"39",
0
],
"start_image": [
"68",
0
],
"end_image": [
"62",
0
]
},
"class_type": "WanFirstLastFrameToVideo",
"_meta": {
"title": "WanFirstLastFrameToVideo"
}
},
"68": {
"inputs": {
"image": "video_wan2_2_14B_flf2v_end_image.png"
},
"class_type": "LoadImage",
"_meta": {
"title": "加载start图像"
}
},
"91": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
"strength_model": 1,
"model": [
"37",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoRA加载器仅模型"
}
},
"92": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
"strength_model": 1,
"model": [
"56",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoRA加载器仅模型"
}
}
}
class ComfyUIServerFLF2V:
    """Drive a ComfyUI first/last-frame-to-video (FLF2V) workflow.

    Pipeline: download the start/end frames from MinIO -> upload them to the
    ComfyUI server -> submit the module-level ``workflow_json`` graph -> poll
    until done -> fetch the rendered MP4, derive a GIF and a first-frame JPEG
    -> upload all three artifacts back to MinIO and publish a status message
    to RabbitMQ (unless ``settings.DEBUG``).

    NOTE(review): this class mutates the module-level ``workflow_json`` dict
    in place, so concurrent ``get_result`` calls in one process would race —
    presumably tasks are processed one at a time; confirm with the caller.
    """

    def __init__(self, request_data):
        """Capture the task parameters and build the MinIO client.

        ``request_data`` is a ``ComfyuiFLF2VModel`` carrying ``tasks_id``,
        ``start_image_url``, ``end_image_url`` and ``prompt``.
        """
        self.pose_transform_data = None
        self.start_image_url = request_data.start_image_url
        self.end_image_url = request_data.end_image_url
        self.prompt = request_data.prompt
        self.tasks_id = request_data.tasks_id
        # The user id is the suffix after the last '-' of the task id.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
        self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)

    def get_result(self):
        """Run the FLF2V workflow end to end.

        Returns ``(upload_summary, prompt_id)`` on success, ``None`` when the
        job could not be submitted or produced no image outputs.

        Raises:
            ValueError: if either frame URL is missing.  BUG FIX: the original
                used ``assert "start_image_url is None"`` — asserting a truthy
                string literal never fails, so missing URLs were silently
                ignored and the workflow ran with placeholder images.
        """
        if not self.start_image_url:
            raise ValueError("start_image_url is None")
        if not self.end_image_url:
            raise ValueError("end_image_url is None")
        workflow_json['6']['inputs']['text'] = self.prompt
        # Fresh seed per task so repeated submissions differ.
        workflow_json['57']['inputs']["noise_seed"] = random.randint(0, 10 ** 18)
        # TODO: width fixed at 480 with a hard-coded 848 height; the original
        # intent was height auto-fit — confirm before changing.
        workflow_json['67']['inputs']["width"] = 480
        workflow_json['67']['inputs']["height"] = 848
        # Download both frames from MinIO and re-upload them to ComfyUI.
        # (The original had a redundant nested `if self.start_image_url:`
        # whose else-branch was unreachable; flattened here.)
        start_in_memory_file, start_object_name = self.download_from_minio_in_memory(self.start_image_url)
        workflow_json['68']['inputs']['image'] = self.upload_in_memory_file_to_comfyui(start_in_memory_file, start_object_name)
        end_in_memory_file, end_object_name = self.download_from_minio_in_memory(self.end_image_url)
        workflow_json['62']['inputs']['image'] = self.upload_in_memory_file_to_comfyui(end_in_memory_file, end_object_name)
        # 1. Submit the job.
        prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
        if not prompt_response:
            return None
        prompt_id = prompt_response.get("prompt_id")
        logger.info(f" 任务已提交Prompt ID: {prompt_id}")
        # 2. Block until the job finishes, then pick the first complete
        # file record out of the node outputs.
        outputs = self.poll_history(prompt_id)
        for node_output in outputs.values():
            if isinstance(node_output.get('images'), list):
                for file_info in node_output['images']:
                    if all(key in file_info for key in ('filename', 'subfolder', 'type')):
                        logger.info(file_info)
                        return self.process_and_upload_comfyui_video(filename=file_info['filename'], subfolder=file_info['subfolder'], prompt_id=prompt_id), prompt_id
        return None

    def download_from_minio_in_memory(self, image_url):
        """Download ``bucket/key`` style ``image_url`` from MinIO into memory.

        Returns ``(BytesIO, basename)`` or ``(None, None)`` on any failure.
        """
        bucket = image_url.split('/')[0]
        object_name = image_url[image_url.find('/') + 1:]
        try:
            response_stream = self.minio_client.get_object(
                bucket,
                object_name,
            )
            # Read the whole object into memory to avoid a temp file.
            image_bytes = response_stream.read()
            response_stream.close()
            response_stream.release_conn()
            return io.BytesIO(image_bytes), object_name.rsplit('/')[-1]
        except S3Error as e:
            logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
            return None, None
        except Exception as e:
            logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
            return None, None

    @staticmethod
    def upload_in_memory_file_to_comfyui(in_memory_file, filename):
        """Upload an in-memory image to ComfyUI's ``/upload/image`` endpoint.

        Returns the server-side filename, or ``None`` on failure.
        """
        upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
        data = {
            "overwrite": "true",
            "type": "input"
        }
        # MIME type is guessed from the extension: png, else jpeg.
        mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
        files = {
            'image': (filename, in_memory_file, mime_type)
        }
        comfyui_response = None
        try:
            comfyui_response = requests.post(upload_url, data=data, files=files)
            comfyui_response.raise_for_status()
            return comfyui_response.json().get('name')
        except requests.exceptions.RequestException as e:
            logger.error(f"❌ ComfyUI 上传失败: {e}")
            # BUG FIX: guard the response dump — `comfyui_response` is unbound
            # when requests.post itself raises (e.g. connection refused).
            if comfyui_response is not None:
                logger.error(f" 响应内容: {comfyui_response.text}")
            return None

    def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str):
        """Fetch the rendered MP4, derive GIF + first-frame JPEG, upload all.

        Publishes the SUCCESS status to RabbitMQ unless ``settings.DEBUG``.
        Returns a summary string on success, ``None`` on any failure.
        """
        # 1. Pull the video bytes from ComfyUI.
        mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
        if not mp4_bytes:
            return None
        # Object keys for the three artifacts.
        output_base_name = uuid.uuid4().hex
        MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
        GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
        FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
        # BUG FIX: pre-initialise everything the finally-block and the upload
        # step read, so a failure inside `try` cannot raise NameError.
        temp_mp4_path = ""
        temp_gif_path = ""
        frame_bytes = None
        gif_bytes = None
        clip = None
        try:
            # moviepy needs a seekable path for FFmpeg, so spill the bytes to
            # a temp file (delete=False: we remove it ourselves in `finally`).
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
                tmp_file.write(mp4_bytes)
                temp_mp4_path = tmp_file.name
            clip = VideoFileClip(temp_mp4_path)
            # First frame -> JPEG bytes (original size).
            frame_stream = io.BytesIO()
            Image.fromarray(clip.get_frame(t=0.0)).save(frame_stream, 'JPEG')
            frame_bytes = frame_stream.getvalue()
            logger.info("✅ 成功提取第一帧图片。")
            # MP4 -> GIF via a second temp file.
            with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
                temp_gif_path = tmp_file.name
            target_fps = int(round(clip.fps)) if clip.fps else 24
            clip.write_gif(temp_gif_path, fps=target_fps)
            with open(temp_gif_path, 'rb') as f:
                gif_bytes = f.read()
            logger.info("✅ 成功生成 GIF。")
        except Exception as e:
            logger.error(f"❌ 视频处理或文件操作失败: {e}")
        finally:
            # BUG FIX: close the clip — VideoFileClip keeps an FFmpeg reader
            # open on temp_mp4_path otherwise.
            if clip is not None:
                clip.close()
            if temp_mp4_path and os.path.exists(temp_mp4_path):
                os.remove(temp_mp4_path)
                logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
            if temp_gif_path and os.path.exists(temp_gif_path):
                os.remove(temp_gif_path)
                logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
        # BUG FIX: bail out if processing failed instead of uploading unbound
        # artifacts (previously a NameError swallowed by the except below).
        if frame_bytes is None or gif_bytes is None:
            return None
        # 2. Upload the three artifacts to MinIO and report success.
        try:
            self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
            self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
            self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
            self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
            if not settings.DEBUG:
                publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
                logger.info(
                    f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
            return "\n🎉 所有任务完成!"
        except Exception as e:
            logger.error(e)
            return None

    @staticmethod
    def queue_prompt(prompt, client_id):
        """Submit the workflow to ComfyUI's ``/prompt`` endpoint.

        Returns the parsed JSON response, or ``None`` on non-200 status.
        """
        p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
        data = json.dumps(p).encode('utf-8')
        response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
        if response.status_code == 200:
            return response.json()
        logger.warning(f"提交任务失败,状态码: {response.status_code}")
        logger.warning(response.text)
        return None

    @staticmethod
    def poll_history(prompt_id, interval_seconds=5):
        """Poll ``/history/{prompt_id}`` until the job appears, return outputs.

        NOTE(review): loops forever if the job never completes — consider a
        max-wait bound; kept as-is to preserve caller behavior.
        """
        url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
        logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
        while True:
            time.sleep(interval_seconds)
            try:
                response = requests.get(url)
                # ComfyUI may 404 or answer without our id while still running.
                if response.status_code == 200:
                    history_data = response.json()
                    # History shape: {prompt_id: {"outputs": ...}}.
                    if prompt_id in history_data:
                        logger.info("🎉 任务已完成!")
                        return history_data[prompt_id]['outputs']
                logger.info("⏳ 任务仍在执行或等待中...")
            except requests.exceptions.RequestException as e:
                logger.info(f"⚠️ 轮询时发生错误: {e}")

    @staticmethod
    def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
        """Download a produced file from ComfyUI's ``/view`` endpoint.

        Args:
            filename: output filename, e.g. ``ComfyUI_00002_.mp4``.
            subfolder: output subfolder on the ComfyUI server.
            file_type: ComfyUI storage class, normally ``"output"``.

        Returns:
            The file's bytes, or ``None`` on failure.
        """
        url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
        params = {
            "filename": filename,
            "subfolder": subfolder,
            "type": file_type
        }
        # BUG FIX: the original log line was an f-string with no placeholders.
        logger.info(f"📡 正在从 ComfyUI 下载视频: {filename} (subfolder={subfolder})")
        try:
            response = requests.get(url, params=params, stream=True)
            response.raise_for_status()
            return response.content
        except requests.exceptions.RequestException as e:
            logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
            return None

    def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
        """Upload a bytes payload to the ``aida-users`` bucket on MinIO."""
        logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
        try:
            data_stream = io.BytesIO(video_bytes)
            result = self.minio_client.put_object(
                bucket_name='aida-users',
                object_name=object_name,
                data=data_stream,
                length=len(video_bytes),
                content_type=content_type
            )
            logger.info(f"✅ MinIO 上传成功: {result.object_name}")
            return True
        except S3Error as e:
            logger.error(f"❌ MinIO 上传失败: {e}")
            return False
if __name__ == '__main__':
    # Manual smoke test: run the first/last-frame workflow on the test assets
    # and print whatever get_result() returns.
    demo_request = ComfyuiFLF2VModel(
        tasks_id="202511051619-89111",
        start_image_url="test/start.png",
        end_image_url="test/end.png",
        prompt="Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
    )
    print(ComfyUIServerFLF2V(demo_request).get_result())

View File

@@ -0,0 +1,622 @@
import io
import json
import logging
import os
import random
import tempfile
import time
import uuid
import requests
from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import PS_RABBITMQ_QUEUES, settings
from app.schemas.comfyui_i2v import ComfyuiI2VModel
from app.service.generate_image.utils.mq import publish_status
logger = logging.getLogger()
# 图 + 文字 = 视频 工作流
workflow_json = {
"84": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "加载CLIP"
}
},
"85": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 2,
"end_at_step": 4,
"return_with_leftover_noise": "disable",
"model": [
"103",
0
],
"positive": [
"98",
0
],
"negative": [
"98",
1
],
"latent_image": [
"86",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "K采样器高级"
}
},
"86": {
"inputs": {
"add_noise": "enable",
"noise_seed": 823962998672127,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 2,
"return_with_leftover_noise": "enable",
"model": [
"104",
0
],
"positive": [
"98",
0
],
"negative": [
"98",
1
],
"latent_image": [
"98",
2
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "K采样器高级"
}
},
"87": {
"inputs": {
"samples": [
"85",
0
],
"vae": [
"90",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE解码"
}
},
"89": {
"inputs": {
"text": "色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
"clip": [
"84",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"90": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "加载VAE"
}
},
"93": {
"inputs": {
"text": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots.",
"clip": [
"84",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"94": {
"inputs": {
"fps": 16,
"images": [
"87",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "创建视频"
}
},
"95": {
"inputs": {
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "UNet加载器"
}
},
"96": {
"inputs": {
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "UNet加载器"
}
},
"97": {
"inputs": {
"image": "start (1).png"
},
"class_type": "LoadImage",
"_meta": {
"title": "加载图像"
}
},
"98": {
"inputs": {
"width": 480,
"height": 848,
"length": 81,
"batch_size": 1,
"positive": [
"93",
0
],
"negative": [
"89",
0
],
"vae": [
"90",
0
],
"start_image": [
"97",
0
]
},
"class_type": "WanImageToVideo",
"_meta": {
"title": "Wan图像到视频"
}
},
"101": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
"strength_model": 1.0000000000000002,
"model": [
"95",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoRA加载器仅模型"
}
},
"102": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
"strength_model": 1.0000000000000002,
"model": [
"96",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoRA加载器仅模型"
}
},
"103": {
"inputs": {
"shift": 5.000000000000001,
"model": [
"102",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "采样算法SD3"
}
},
"104": {
"inputs": {
"shift": 5.000000000000001,
"model": [
"101",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "采样算法SD3"
}
},
"108": {
"inputs": {
"filename_prefix": "video/ComfyUI",
"format": "auto",
"codec": "auto",
"video-preview": "",
"video": [
"94",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "保存视频"
}
}
}
class ComfyUIServerI2V:
    """Drive a ComfyUI image-to-video (I2V) workflow.

    Pipeline: download the source image from MinIO -> upload it to the ComfyUI
    server -> submit the module-level ``workflow_json`` graph -> poll until
    done -> fetch the rendered MP4, derive a GIF and a first-frame JPEG ->
    upload all three artifacts back to MinIO and publish a status message to
    RabbitMQ (unless ``settings.DEBUG``).

    NOTE(review): near-duplicate of ``ComfyUIServerFLF2V``; a shared base
    class would remove the copy-pasted helpers.  Also mutates the module-level
    ``workflow_json`` in place, so concurrent ``get_result`` calls would race.
    """

    def __init__(self, request_data):
        """Capture the task parameters and build the MinIO client.

        ``request_data`` is a ``ComfyuiI2VModel`` carrying ``tasks_id``,
        ``image_url`` and ``prompt``.
        """
        self.pose_transform_data = None
        self.image_url = request_data.image_url
        self.prompt = request_data.prompt
        self.tasks_id = request_data.tasks_id
        # The user id is the suffix after the last '-' of the task id.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
        self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)

    def get_result(self):
        """Run the I2V workflow end to end.

        Returns ``(upload_summary, prompt_id)`` on success, ``None`` when the
        job could not be submitted or produced no image outputs.
        """
        workflow_json['93']['inputs']['text'] = self.prompt
        # Fresh seed per task so repeated submissions differ.
        workflow_json['86']['inputs']["noise_seed"] = random.randint(0, 10 ** 18)
        if self.image_url:
            # Pull the source image from MinIO, push it to the ComfyUI server.
            in_memory_file, object_name = self.download_from_minio_in_memory(self.image_url)
            # TODO: width fixed at 480 with a hard-coded 848 height; the
            # original intent was height auto-fit — confirm before changing.
            workflow_json['98']['inputs']["width"] = 480
            workflow_json['98']['inputs']["height"] = 848
            if in_memory_file and object_name:
                workflow_json['97']['inputs']['image'] = self.upload_in_memory_file_to_comfyui(in_memory_file, object_name)
        # 1. Submit the job.
        prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
        if not prompt_response:
            return None
        prompt_id = prompt_response.get("prompt_id")
        logger.info(f" 任务已提交Prompt ID: {prompt_id}")
        # 2. Block until the job finishes, then pick the first complete file
        # record out of the node outputs.
        outputs = self.poll_history(prompt_id)
        for node_output in outputs.values():
            if isinstance(node_output.get('images'), list):
                for file_info in node_output['images']:
                    if all(key in file_info for key in ('filename', 'subfolder', 'type')):
                        logger.info(file_info)
                        return self.process_and_upload_comfyui_video(filename=file_info['filename'], subfolder=file_info['subfolder'], prompt_id=prompt_id), prompt_id
        return None

    def download_from_minio_in_memory(self, image_url):
        """Download ``bucket/key`` style ``image_url`` from MinIO into memory.

        Returns ``(BytesIO, basename)`` or ``(None, None)`` on any failure.
        """
        bucket = image_url.split('/')[0]
        object_name = image_url[image_url.find('/') + 1:]
        try:
            response_stream = self.minio_client.get_object(
                bucket,
                object_name,
            )
            # Read the whole object into memory to avoid a temp file.
            image_bytes = response_stream.read()
            response_stream.close()
            response_stream.release_conn()
            return io.BytesIO(image_bytes), object_name.rsplit('/')[-1]
        except S3Error as e:
            logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
            return None, None
        except Exception as e:
            logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
            return None, None

    @staticmethod
    def upload_in_memory_file_to_comfyui(in_memory_file, filename):
        """Upload an in-memory image to ComfyUI's ``/upload/image`` endpoint.

        Returns the server-side filename, or ``None`` on failure.
        """
        upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
        data = {
            "overwrite": "true",
            "type": "input"
        }
        # MIME type is guessed from the extension: png, else jpeg.
        mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
        files = {
            'image': (filename, in_memory_file, mime_type)
        }
        comfyui_response = None
        try:
            comfyui_response = requests.post(upload_url, data=data, files=files)
            comfyui_response.raise_for_status()
            return comfyui_response.json().get('name')
        except requests.exceptions.RequestException as e:
            logger.error(f"❌ ComfyUI 上传失败: {e}")
            # BUG FIX: guard the response dump — `comfyui_response` is unbound
            # when requests.post itself raises (e.g. connection refused).
            if comfyui_response is not None:
                logger.error(f" 响应内容: {comfyui_response.text}")
            return None

    def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str):
        """Fetch the rendered MP4, derive GIF + first-frame JPEG, upload all.

        Publishes the SUCCESS status to RabbitMQ unless ``settings.DEBUG``.
        Returns a summary string on success, ``None`` on any failure.
        """
        # 1. Pull the video bytes from ComfyUI.
        mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
        if not mp4_bytes:
            return None
        # Object keys for the three artifacts.
        output_base_name = uuid.uuid4().hex
        MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
        GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
        FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
        # BUG FIX: pre-initialise everything the finally-block and the upload
        # step read, so a failure inside `try` cannot raise NameError.
        temp_mp4_path = ""
        temp_gif_path = ""
        frame_bytes = None
        gif_bytes = None
        clip = None
        try:
            # moviepy needs a seekable path for FFmpeg, so spill the bytes to
            # a temp file (delete=False: we remove it ourselves in `finally`).
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
                tmp_file.write(mp4_bytes)
                temp_mp4_path = tmp_file.name
            clip = VideoFileClip(temp_mp4_path)
            # First frame -> JPEG bytes (original size).
            frame_stream = io.BytesIO()
            Image.fromarray(clip.get_frame(t=0.0)).save(frame_stream, 'JPEG')
            frame_bytes = frame_stream.getvalue()
            logger.info("✅ 成功提取第一帧图片。")
            # MP4 -> GIF via a second temp file.
            with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
                temp_gif_path = tmp_file.name
            target_fps = int(round(clip.fps)) if clip.fps else 24
            clip.write_gif(temp_gif_path, fps=target_fps)
            with open(temp_gif_path, 'rb') as f:
                gif_bytes = f.read()
            logger.info("✅ 成功生成 GIF。")
        except Exception as e:
            logger.error(f"❌ 视频处理或文件操作失败: {e}")
        finally:
            # BUG FIX: close the clip — VideoFileClip keeps an FFmpeg reader
            # open on temp_mp4_path otherwise.
            if clip is not None:
                clip.close()
            if temp_mp4_path and os.path.exists(temp_mp4_path):
                os.remove(temp_mp4_path)
                logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
            if temp_gif_path and os.path.exists(temp_gif_path):
                os.remove(temp_gif_path)
                logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
        # BUG FIX: bail out if processing failed instead of uploading unbound
        # artifacts (previously a NameError swallowed by the except below).
        if frame_bytes is None or gif_bytes is None:
            return None
        # 2. Upload the three artifacts to MinIO and report success.
        try:
            self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
            self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
            self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
            self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
            if not settings.DEBUG:
                publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
                logger.info(
                    f" [x] Sent to {PS_RABBITMQ_QUEUES} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
            return "\n🎉 所有任务完成!"
        except Exception as e:
            logger.error(e)
            return None

    @staticmethod
    def queue_prompt(prompt, client_id):
        """Submit the workflow to ComfyUI's ``/prompt`` endpoint.

        Returns the parsed JSON response, or ``None`` on non-200 status.
        """
        p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
        data = json.dumps(p).encode('utf-8')
        response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
        if response.status_code == 200:
            return response.json()
        logger.warning(f"提交任务失败,状态码: {response.status_code}")
        logger.warning(response.text)
        return None

    @staticmethod
    def poll_history(prompt_id, interval_seconds=5):
        """Poll ``/history/{prompt_id}`` until the job appears, return outputs.

        NOTE(review): loops forever if the job never completes — consider a
        max-wait bound; kept as-is to preserve caller behavior.
        """
        url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
        logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
        while True:
            time.sleep(interval_seconds)
            try:
                response = requests.get(url)
                # ComfyUI may 404 or answer without our id while still running.
                if response.status_code == 200:
                    history_data = response.json()
                    # History shape: {prompt_id: {"outputs": ...}}.
                    if prompt_id in history_data:
                        logger.info("🎉 任务已完成!")
                        return history_data[prompt_id]['outputs']
                logger.info("⏳ 任务仍在执行或等待中...")
            except requests.exceptions.RequestException as e:
                logger.info(f"⚠️ 轮询时发生错误: {e}")

    @staticmethod
    def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
        """Download a produced file from ComfyUI's ``/view`` endpoint.

        Args:
            filename: output filename, e.g. ``ComfyUI_00002_.mp4``.
            subfolder: output subfolder on the ComfyUI server.
            file_type: ComfyUI storage class, normally ``"output"``.

        Returns:
            The file's bytes, or ``None`` on failure.
        """
        url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
        params = {
            "filename": filename,
            "subfolder": subfolder,
            "type": file_type
        }
        # BUG FIX: the original log line was an f-string with no placeholders.
        logger.info(f"📡 正在从 ComfyUI 下载视频: {filename} (subfolder={subfolder})")
        try:
            response = requests.get(url, params=params, stream=True)
            response.raise_for_status()
            return response.content
        except requests.exceptions.RequestException as e:
            logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
            return None

    def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
        """Upload a bytes payload to the ``aida-users`` bucket on MinIO."""
        logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
        try:
            data_stream = io.BytesIO(video_bytes)
            result = self.minio_client.put_object(
                bucket_name='aida-users',
                object_name=object_name,
                data=data_stream,
                length=len(video_bytes),
                content_type=content_type
            )
            logger.info(f"✅ MinIO 上传成功: {result.object_name}")
            return True
        except S3Error as e:
            logger.error(f"❌ MinIO 上传失败: {e}")
            return False
if __name__ == '__main__':
    # Manual smoke test: run the image-to-video workflow on a sample product
    # image and print whatever get_result() returns.
    demo_request = ComfyuiI2VModel(
        tasks_id="12222515151123-89111",
        image_url="aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
        prompt="Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
    )
    print(ComfyUIServerI2V(demo_request).get_result())

View File

@@ -0,0 +1,745 @@
import io
import json
import logging
import os
import random
import tempfile
import time
import uuid
import redis
import requests
from PIL import Image
from minio import Minio, S3Error
from moviepy.video.io.VideoFileClip import VideoFileClip
from app.core.config import settings
from app.schemas.comfyui_i2v import ComfyuiPose2VModel
from app.service.generate_image.utils.mq import publish_status
logger = logging.getLogger()
# 图 + 骨架 = 视频 工作流
workflow_json = {
"162": {
"inputs": {
"text": "色调艳丽过曝静态细节模糊不清字幕风格作品画作画面静止整体发灰最差质量低质量JPEG压缩残留丑陋的残缺的多余的手指画得不好的手部画得不好的脸部畸形的毁容的形态畸形的肢体手指融合静止不动的画面杂乱的背景三条腿背景人很多倒着走",
"clip": [
"167",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Negative Prompt)"
}
},
"163": {
"inputs": {
"fps": 24,
"images": [
"192",
0
]
},
"class_type": "CreateVideo",
"_meta": {
"title": "创建视频"
}
},
"164": {
"inputs": {
"samples": [
"175",
0
],
"vae": [
"168",
0
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE解码"
}
},
"165": {
"inputs": {
"unet_name": "wan2.2_fun_control_high_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "UNet加载器"
}
},
"166": {
"inputs": {
"unet_name": "wan2.2_fun_control_low_noise_14B_fp8_scaled.safetensors",
"weight_dtype": "default"
},
"class_type": "UNETLoader",
"_meta": {
"title": "UNet加载器"
}
},
"167": {
"inputs": {
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
"type": "wan",
"device": "default"
},
"class_type": "CLIPLoader",
"_meta": {
"title": "加载CLIP"
}
},
"168": {
"inputs": {
"vae_name": "wan_2.1_vae.safetensors"
},
"class_type": "VAELoader",
"_meta": {
"title": "加载VAE"
}
},
"169": {
"inputs": {
"add_noise": "enable",
"noise_seed": 8860422635573,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 0,
"end_at_step": 2,
"return_with_leftover_noise": "enable",
"model": [
"176",
0
],
"positive": [
"180",
0
],
"negative": [
"180",
1
],
"latent_image": [
"180",
2
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "K采样器高级"
}
},
"170": {
"inputs": {
"filename_prefix": "video/wan2.2_fun_control",
"format": "auto",
"codec": "auto",
"video-preview": "",
"video": [
"163",
0
]
},
"class_type": "SaveVideo",
"_meta": {
"title": "保存视频"
}
},
"171": {
"inputs": {
"video": [
"174",
0
]
},
"class_type": "GetVideoComponents",
"_meta": {
"title": "获取视频组件"
}
},
"174": {
"inputs": {
"file": "skeleton_3.mp4"
},
"class_type": "LoadVideo",
"_meta": {
"title": "加载视频"
}
},
"175": {
"inputs": {
"add_noise": "disable",
"noise_seed": 0,
"steps": 4,
"cfg": 1,
"sampler_name": "euler",
"scheduler": "simple",
"start_at_step": 2,
"end_at_step": 4,
"return_with_leftover_noise": "disable",
"model": [
"177",
0
],
"positive": [
"180",
0
],
"negative": [
"180",
1
],
"latent_image": [
"169",
0
]
},
"class_type": "KSamplerAdvanced",
"_meta": {
"title": "K采样器高级"
}
},
"176": {
"inputs": {
"shift": 8.000000000000002,
"model": [
"181",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "采样算法SD3"
}
},
"177": {
"inputs": {
"shift": 8.000000000000002,
"model": [
"182",
0
]
},
"class_type": "ModelSamplingSD3",
"_meta": {
"title": "采样算法SD3"
}
},
"178": {
"inputs": {
"image": "296f5fd6-c5e4-4003-9798-f378a4f08411-0-89.png"
},
"class_type": "LoadImage",
"_meta": {
"title": "加载图像"
}
},
"179": {
"inputs": {
"text": "The model is catwalking at the fashion show.",
"clip": [
"167",
0
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Positive Prompt)"
}
},
"180": {
"inputs": {
"width": 480,
"height": 720,
"length": 121,
"batch_size": 1,
"positive": [
"179",
0
],
"negative": [
"162",
0
],
"vae": [
"168",
0
],
"ref_image": [
"178",
0
],
"control_video": [
"171",
0
]
},
"class_type": "Wan22FunControlToVideo",
"_meta": {
"title": "Wan22FunControlToVideo"
}
},
"181": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
"strength_model": 1,
"model": [
"165",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoRA加载器仅模型"
}
},
"182": {
"inputs": {
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
"strength_model": 1,
"model": [
"166",
0
]
},
"class_type": "LoraLoaderModelOnly",
"_meta": {
"title": "LoRA加载器仅模型"
}
},
"189": {
"inputs": {
"images": [
"171",
0
]
},
"class_type": "PreviewImage",
"_meta": {
"title": "预览图像"
}
},
"190": {
"inputs": {
"images": [
"192",
0
]
},
"class_type": "PreviewImage",
"_meta": {
"title": "预览图像"
}
},
"192": {
"inputs": {
"batch_index": 4,
"length": 117,
"image": [
"164",
0
]
},
"class_type": "ImageFromBatch",
"_meta": {
"title": "从批次获取图像"
}
}
}
# Skeleton mapping: pose id (as sent by the client) -> pre-rendered pose
# (skeleton) control video, addressed relative to the ComfyUI server's
# input directory; consumed by workflow node "174" (LoadVideo).
video_map = {
    "1": "input_pose_video/1.mp4",
    "2": "input_pose_video/2.mp4",
    "3": "input_pose_video/3.mp4",
    "4": "input_pose_video/4.mp4",
    "5": "input_pose_video/5.mp4",
    "6": "input_pose_video/6.mp4"
}
class ComfyUIServerPose2V:
    """Runs a pose-driven image-to-video job on a ComfyUI server.

    Workflow: patch the shared ``workflow_json`` template, submit it to ComfyUI,
    poll until completion, then derive MP4 / GIF / first-frame JPEG artifacts and
    upload them to MinIO, tracking job status in Redis.
    """

    def __init__(self, request_data):
        # request_data: object exposing image_url, pose_id and tasks_id attributes.
        self.image_url = request_data.image_url
        self.pose_num = request_data.pose_id
        self.tasks_id = request_data.tasks_id
        # The user id is encoded as the suffix after the last '-' in the task id.
        self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
        self.redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
        self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
        # Publish the initial PENDING state; the status key expires after 10 minutes.
        self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
        self.redis_client.expire(self.tasks_id, 600)
        self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)

    def get_result(self):
        """Run the full job end to end.

        Returns:
            (summary_message, prompt_id) on success, otherwise None.
        """
        # Patch the shared workflow template with this request's inputs.
        workflow_json['174']['inputs']['file'] = video_map[self.pose_num]
        workflow_json['169']['inputs']['noise_seed'] = random.randint(0, 10 ** 18)
        # Pull the product image from MinIO and push it to the ComfyUI server.
        in_memory_file, object_name = self.download_from_minio_in_memory()
        if in_memory_file and object_name:
            uploaded_filename = self.upload_in_memory_file_to_comfyui(in_memory_file, object_name)
            workflow_json['178']['inputs']['image'] = uploaded_filename
        # 1. Submit the workflow.
        prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
        if not prompt_response:
            return None
        prompt_id = prompt_response.get("prompt_id")
        logger.info(f" 任务已提交Prompt ID: {prompt_id}")
        outputs = self.poll_history(prompt_id)
        # Return on the first node output that describes a complete image file.
        for node_id, node_output in outputs.items():
            if 'images' in node_output and isinstance(node_output['images'], list):
                for file_info in node_output['images']:
                    # Only act on entries that carry all required fields.
                    if all(key in file_info for key in ['filename', 'subfolder', 'type']):
                        file_list = {
                            'filename': file_info['filename'],
                            'subfolder': file_info['subfolder'],
                            'type': file_info['type']
                        }
                        logger.info(file_list)
                        return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
        # No usable image output was produced by any node.
        return None

    def read_tasks_status(self):
        """Return (parsed_status_dict, raw_json) for this task.

        Fix: returns (None, None) instead of raising TypeError when the Redis
        key has expired or was never written.
        """
        status_data = self.redis_client.get(self.tasks_id)
        if status_data is None:
            return None, None
        return json.loads(status_data), status_data

    def download_from_minio_in_memory(self):
        """Download ``self.image_url`` ('bucket/path/to/object') into memory.

        Returns:
            (BytesIO, basename) on success, (None, None) on failure.
        """
        bucket = self.image_url.split('/')[0]
        object_name = self.image_url[self.image_url.find('/') + 1:]
        try:
            # get_object returns a response stream that must be released.
            response_stream = self.minio_client.get_object(
                bucket,
                object_name,
            )
            try:
                # Read the whole stream into memory; no local file is written.
                image_bytes = response_stream.read()
            finally:
                # Fix: always release the connection, even if read() raises.
                response_stream.close()
                response_stream.release_conn()
            in_memory_file = io.BytesIO(image_bytes)
            return in_memory_file, object_name.rsplit('/')[-1]
        except S3Error as e:
            logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
            return None, None
        except Exception as e:
            logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
            return None, None

    def _fput_object_to_minio(self, bucket_name, object_name, local_file_path, content_type):
        """Shared helper: upload a local file to MinIO with the given content type."""
        try:
            self.minio_client.fput_object(
                bucket_name=bucket_name,
                object_name=object_name,
                file_path=local_file_path,
                content_type=content_type
            )
        except S3Error as e:
            logger.error(f"❌ MinIO 操作失败: {e}")
        except FileNotFoundError:
            logger.error(f"❌ 找不到本地文件: {local_file_path}")
        except Exception as e:
            logger.error(f"❌ 发生未知错误: {e}")

    def upload_video_to_minio(self, BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH):
        """Upload an MP4 file from a local path (content type video/mp4 for streaming)."""
        self._fput_object_to_minio(BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH, "video/mp4")

    def upload_gif_to_minio(self, BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH):
        """Upload a GIF file from a local path.

        Fix: content type is now image/gif — the original was copy-pasted from
        the MP4 uploader and tagged GIFs as video/mp4, which breaks inline
        rendering in browsers.
        """
        self._fput_object_to_minio(BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH, "image/gif")

    @staticmethod
    def upload_in_memory_file_to_comfyui(in_memory_file, filename):
        """POST an in-memory image to ComfyUI's /upload/image endpoint.

        Returns:
            The server-side file name on success, otherwise None.
        """
        upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
        data = {
            "overwrite": "true",
            "type": "input"
        }
        # MIME type is only a hint for the multipart upload; guess from the extension.
        mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
        files = {
            'image': (filename, in_memory_file, mime_type)
        }
        comfyui_response = None
        try:
            comfyui_response = requests.post(upload_url, data=data, files=files)
            comfyui_response.raise_for_status()
            result = comfyui_response.json()
            return result.get('name')
        except requests.exceptions.RequestException as e:
            logger.error(f"❌ ComfyUI 上传失败: {e}")
            # Fix: comfyui_response was unbound (NameError) when the POST itself
            # raised before assignment.
            if comfyui_response is not None:
                logger.error(f" 响应内容: {comfyui_response.text}")
            return None

    def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str, ):
        """
        Full post-processing flow: fetch the rendered video from ComfyUI,
        derive a GIF and a first-frame JPEG, upload all three artifacts to
        MinIO and publish the SUCCESS status.

        Returns:
            A summary string on success, None on any failure.
        """
        # 1. Fetch the video bytes from ComfyUI.
        mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
        if not mp4_bytes:
            return None
        # Object names for the three artifacts, grouped per user and prompt.
        output_base_name = uuid.uuid4().hex
        MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
        GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
        FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
        # 2. Video processing. MoviePy's FFmpeg backend needs a seekable local
        # path, so the bytes go through temporary files that are removed below.
        temp_mp4_path = ""
        temp_gif_path = ""
        clip = None
        frame_bytes = None
        gif_bytes = None
        try:
            # delete=False keeps the file alive after the with-block until we
            # remove it ourselves in the finally clause.
            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
                tmp_file.write(mp4_bytes)
                temp_mp4_path = tmp_file.name
            clip = VideoFileClip(temp_mp4_path)
            # Extract the first frame (original size) and encode it as JPEG.
            frame_array = clip.get_frame(t=0.0)
            image = Image.fromarray(frame_array)
            frame_stream = io.BytesIO()
            image.save(frame_stream, 'JPEG')
            frame_bytes = frame_stream.getvalue()
            logger.info("✅ 成功提取第一帧图片。")
            # Render the GIF at (roughly) the source frame rate.
            with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
                temp_gif_path = tmp_file.name
            target_fps = int(round(clip.fps)) if clip.fps else 24
            clip.write_gif(temp_gif_path, fps=target_fps)
            with open(temp_gif_path, 'rb') as f:
                gif_bytes = f.read()
            logger.info("✅ 成功生成 GIF。")
        except Exception as e:
            logger.error(f"❌ 视频处理或文件操作失败: {e}")
        finally:
            # Fix: close the clip (releases the FFmpeg reader process) and
            # always remove the temporary files. Paths are pre-initialised so
            # this block can no longer raise NameError when NamedTemporaryFile
            # itself failed.
            if clip is not None:
                clip.close()
            if temp_mp4_path and os.path.exists(temp_mp4_path):
                os.remove(temp_mp4_path)
                logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
            if temp_gif_path and os.path.exists(temp_gif_path):
                os.remove(temp_gif_path)
                logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
        if frame_bytes is None or gif_bytes is None:
            # Fix: previously a processing failure still uploaded the MP4 and
            # then crashed on the undefined gif_bytes; now we fail before any
            # upload so no partial artifact set is published.
            return None
        # 3. Upload all artifacts to MinIO and publish the final status.
        try:
            self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
            self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
            self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
            self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
            # Push the completion message (skipped in DEBUG mode).
            if not settings.DEBUG:
                publish_status(json.dumps(self.pose_transform_data), settings.COMFYUI_SERVER_ADDRESS)
            logger.info(
                f" [x] Sent to {settings.COMFYUI_SERVER_ADDRESS} data@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
            return "\n🎉 所有任务完成!"
        except Exception as e:
            logger.error(e)
            return None

    # --- helper: submit the workflow to the queue ---
    @staticmethod
    def queue_prompt(prompt, client_id):
        """Submit a workflow prompt to ComfyUI's /prompt endpoint.

        Returns:
            The parsed JSON response on HTTP 200, otherwise None.
        """
        p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
        data = json.dumps(p).encode('utf-8')
        # noinspection HttpUrlsUsage
        response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
        if response.status_code == 200:
            return response.json()
        logger.warning(f"提交任务失败,状态码: {response.status_code}")
        logger.warning(response.text)
        return None

    @staticmethod
    def poll_history(prompt_id, interval_seconds=5):
        """Step 2: poll /history/{prompt_id} until the job is finished.

        NOTE(review): blocks indefinitely — there is no timeout; confirm a
        watchdog exists upstream before relying on this in a request handler.
        """
        url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
        logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
        while True:
            time.sleep(interval_seconds)
            try:
                response = requests.get(url)
                # Before completion ComfyUI may answer 404 or an empty body;
                # only a 200 response containing our prompt_id means "done".
                if response.status_code == 200:
                    history_data = response.json()
                    # The history payload is keyed by prompt_id: {prompt_id: {outputs: ...}}
                    if prompt_id in history_data:
                        logger.info("🎉 任务已完成!")
                        return history_data[prompt_id]['outputs']
                logger.info("⏳ 任务仍在执行或等待中...")
            except requests.exceptions.RequestException as e:
                # Transient connection errors: log and keep polling.
                logger.info(f"⚠️ 轮询时发生错误: {e}")

    @staticmethod
    def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
        """
        Fetch a rendered file's binary content from ComfyUI's /view endpoint.

        Args:
            filename: video file name (e.g. 'ComfyUI_00002_.mp4')
            subfolder: storage subfolder (e.g. 'ComfyUI_2025-10-31')
            file_type: file category (usually 'output')

        Returns:
            The file's binary content (bytes) or None on failure.
        """
        url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
        params = {
            "filename": filename,
            "subfolder": subfolder,
            "type": file_type
        }
        # Fix: the log line previously contained a broken "(unknown)" placeholder.
        logger.info(f"📡 正在从 ComfyUI 下载视频: {filename}")
        try:
            response = requests.get(url, params=params, stream=True)
            response.raise_for_status()  # raise on HTTP errors
            return response.content
        except requests.exceptions.RequestException as e:
            logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
            return None

    def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
        """Upload raw bytes to the 'aida-users' bucket; returns True on success."""
        logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
        try:
            data_stream = io.BytesIO(video_bytes)
            result = self.minio_client.put_object(
                bucket_name='aida-users',
                object_name=object_name,
                data=data_stream,
                length=len(video_bytes),
                content_type=content_type
            )
            logger.info(f"✅ MinIO 上传成功: {result.object_name}")
            return True
        except S3Error as e:
            logger.error(f"❌ MinIO 上传失败: {e}")
            return False
if __name__ == '__main__':
    # Manual smoke test: submit a single pose-to-video job and print the outcome.
    demo_request = ComfyuiPose2VModel(
        tasks_id="122522251123-89111",
        image_url="aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
        pose_id="6",
    )
    print(ComfyUIServerPose2V(demo_request).get_result())

View File

@@ -1,116 +0,0 @@
import logging
import numpy as np
import cv2
from matplotlib import pyplot as plt
from PIL import Image
def show(img, win_name="temp"):
    """Debug helper: display *img* in an OpenCV window titled *win_name* and block until a key press."""
    cv2.imshow(win_name, img)
    cv2.waitKey(0)
def crop(img):
    """Return the central 1040x680 region of *img*, shifted 30 px downward.

    The window is centred horizontally and 30 px below the vertical midpoint;
    slicing clamps at the image borders, so smaller inputs are returned whole.
    """
    center_y = int(img.shape[0] / 2 + 30)
    center_x = int(img.shape[1] / 2)
    return img[center_y - 520: center_y + 520, center_x - 340: center_x + 340]
class Layer(object):
    """Ordered collection of render layers (dicts carrying at least a 'name' key)."""

    def __init__(self):
        # Layers are kept in insertion order until sort() is called.
        self._layer = []

    @property
    def layer(self):
        """The current list of layer dicts."""
        return self._layer

    def insert(self, layer_instance):
        """Append a layer; the 'body' layer is additionally remembered on its own."""
        if layer_instance['name'] == 'body':
            self._body = layer_instance
        self._layer.append(layer_instance)

    def sort(self, priority):
        """Stable-sort layers by the priority value looked up via their name."""
        def rank(entry):
            return priority[entry['name']]

        self._layer.sort(key=rank)
# def merge(self, cfg):
# """
# opencv shape order (height, width, channel)
# image coordinate system:
# |------------->x (width)
# |
# |
# |
# y (height)
# Returns:
#
#
# """
# base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0))
# for layer in self._layer:
# y, x = layer['position']
# base_image.paste(layer['image'], (x, y), layer['image'])
# # base_image.show()
#
# for x in self._layer:
# if np.all(x['mask'] == 0):
# continue
# # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask)
# roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x)
# temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2)
# tmp1 = (roi * (temp_bg / 255)).astype(np.uint8)
# temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2)
# tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8)
#
# roi[:] = cv2.add(tmp1, tmp2)
# # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA),
# # win_name=x.get('name'))
# # crop image and get the central part
# if cfg.get('basic')['self_template'] == False:
# dst_roi = crop(dst)
# else:
# dst_roi = dst
# return dst_roi, signal
#
# @staticmethod
# def get_roi(dst, image):
# signal = False
# dst_y, dst_x = dst.shape[:2]
# roi_height, roi_width = image['mask'].shape
# roi_y0, roi_x0 = image['position']
#
# if roi_y0 < 0:
# roi_yin = 0
# mask_yin = -roi_y0
# signal = True
# else:
# roi_yin = roi_y0
# mask_yin = 0
# if roi_y0 + roi_height > dst_y:
# roi_yout = dst_y
# mask_yout = dst_y - roi_y0
# signal = True
# else:
# roi_yout = roi_height + roi_y0
# mask_yout = roi_height
# # x part
# if roi_x0 < 0:
# roi_xin = 0
# mask_xin = -roi_x0
# signal = True
# else:
# roi_xin = roi_x0
# mask_xin = 0
# if roi_x0 + roi_width > dst_x:
# roi_xout = dst_x
# mask_xout = dst_x - roi_x0
# signal = True
# else:
# roi_xout = roi_width + roi_x0
# mask_xout = roi_width
#
# roi = dst[roi_yin: roi_yout, roi_xin: roi_xout]
# roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout]
# roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout]
# return roi, roi_mask, roi_image, signal

View File

@@ -1,45 +0,0 @@
class Priority(object):
    """Item layer priority levels.

    Builds a name -> priority map for render-layer sorting: higher values draw
    in front of the body (priority 0), negative values behind it. Garments in
    *item_list* are re-prioritised starting at 10, in list order.
    """

    #: Categories that may appear in item_list.
    _CLOTHING_CATEGORIES = ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms')

    def __init__(self, item_list):
        self._priority = dict(
            earring_front=99,
            bag_front=98,
            hairstyle_front=97,
            outwear_front=20,
            bottoms_front=19,
            dress_front=18,
            blouse_front=17,
            skirt_front=16,
            trousers_front=15,
            tops_front=14,
            shoes_right=1,
            shoes_left=1,
            body=0,
            tops_back=-14,
            trousers_back=-15,
            skirt_back=-16,
            blouse_back=-17,
            dress_back=-18,
            bottoms_back=-19,
            outwear_back=-20,
            hairstyle_back=-97,
            bag_back=-98,
            earring_back=-99,
        )
        self.clothing_start_num = 10
        if not isinstance(item_list, list):
            raise ValueError('item_list must be a list!')
        # Validate everything first so no partial re-prioritisation happens.
        normalized = []
        for cate in item_list:
            lowered = cate.lower()
            if lowered not in self._CLOTHING_CATEGORIES:
                raise ValueError(f'Item type error. Cannot recognize {lowered}')
            normalized.append(lowered)
        # Re-prioritise listed garments: first item drawn front-most at 10.
        for offset, lowered in enumerate(normalized):
            self._priority[f'{lowered}_front'] = self.clothing_start_num - offset
            self._priority[f'{lowered}_back'] = -(self.clothing_start_num - offset)

    @property
    def priority(self):
        """The computed name -> priority mapping."""
        return self._priority

View File

@@ -1,16 +0,0 @@
"""Item registry package: re-exports all buildable item classes."""
from .builder import ITEMS, build_item
from .clothing import Clothing  # 4.0 sec
from .body import Body
from .top import Top, Blouse, Outwear, Dress
from .bottom import Bottom, Trousers, Skirt
from .shoes import Shoes
from .bag import Bag
from .accessories import Hairstyle, Earring

# Public API of the items package.
__all__ = [
    'ITEMS', 'build_item',
    'Clothing', 'Body',
    'Top', 'Blouse', 'Outwear', 'Dress',
    'Bottom', 'Trousers', 'Skirt',
    'Shoes', 'Bag', 'Hairstyle', 'Earring'
]

View File

@@ -1,59 +0,0 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Hairstyle(Clothing):
    """Hairstyle item: standard clothing pipeline, anchored to the head keypoint."""

    def __init__(self, **kwargs):
        pipeline = [
            dict(type='LoadImageFromFile', path=kwargs['path']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
            # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
        ]
        kwargs.update(pipeline=pipeline)
        super(Hairstyle, self).__init__(**kwargs)

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
        """
        align up
        Args:
            keypoint_type: string, "head_point"
            scale: float
            clothes_point: dict mapping keypoint names to "x_y" strings
            body_point: dict, containing keypoint data of body figure
        Returns:
            start_point: tuple (y', x')
                y' = y_body - y_clothes * scale
                x' = x_body - x_clothes * scale
        """
        side_indicator = f'{keypoint_type}_up'
        # Fix: the split()[i] results are strings, so multiplying them by the
        # float scale raised TypeError — convert to int first (same convention
        # as Bag.calculate_start_point).
        start_point = (
            int(body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale)),
            int(body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
        )
        return start_point
@ITEMS.register_module()
class Earring(Clothing):
    """Earring item: runs the standard clothing pipeline (load → keypoints → contour → paint → scale → split)."""

    def __init__(self, **kwargs):
        stages = [
            dict(type='LoadImageFromFile', path=kwargs['path']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
            # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
        ]
        kwargs.update(pipeline=stages)
        super(Earring, self).__init__(**kwargs)

View File

@@ -1,45 +0,0 @@
import random
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Bag(Clothing):
    """Bag item: standard clothing pipeline, anchored to a randomly chosen hand."""

    def __init__(self, **kwargs):
        stages = [
            dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
            # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
        ]
        kwargs.update(pipeline=stages)
        super(Bag, self).__init__(**kwargs)

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
        """
        align left
        Args:
            keypoint_type: string, "hand_point"
            scale: float
            clothes_point: dict mapping keypoint names to "x_y" strings
            body_point: dict, containing keypoint data of body figure
        Returns:
            start_point: tuple (y', x')
                y' = y_body - y_clothes * scale
                x' = x_body - x_clothes * scale
        """
        # Randomly hang the bag on the left or right hand.
        location = random.choice(['left', 'right'])
        side_indicator = f'{keypoint_type}_{location}'
        # Keypoints are stored as "x_y" strings; scale each coordinate.
        anchor_x = int(int(clothes_point[keypoint_type].split("_")[0]) * scale)
        anchor_y = int(int(clothes_point[keypoint_type].split("_")[1]) * scale)
        return (body_point[side_indicator][1] - anchor_y,
                body_point[side_indicator][0] - anchor_x)

View File

@@ -1,36 +0,0 @@
import cv2
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Body(object):
    """Body figure item: only loads the base body image (no clothing pipeline)."""

    def __init__(self, **kwargs):
        stages = [
            dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']),
            # dict(type='ImageShow', key=['body_image', "body_mask"])
        ]
        self.pipeline = Compose(stages)
        self.result = dict()

    def process(self):
        """Run the loading pipeline; stages populate self.result in place."""
        self.pipeline(self.result)

    def organize(self, layer):
        """Insert the body as the base layer (priority 0, anchored at the origin)."""
        body_layer = dict(
            priority=0,
            name=type(self).__name__.lower(),
            image=self.result['body_image'],
            image_url=self.result['image_url'],
            mask_image=None,
            mask_url=None,
            # NOTE(review): 'sacle' looks like a typo for 'scale', but downstream
            # consumers may read this exact key — left unchanged.
            sacle=1,
            # mask=self.result['body_mask'],
            position=(0, 0),
        )
        layer.insert(body_layer)

    @staticmethod
    def show(img):
        """Debug helper: display *img* in an OpenCV window until a key press."""
        cv2.imshow('', img)
        cv2.waitKey(0)

View File

@@ -1,39 +0,0 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Bottom(Clothing):
    """Base class for lower-body garments (trousers, skirts, generic bottoms)."""

    def __init__(self, pipeline, **kwargs):
        if pipeline is None:
            # Default processing chain used when the caller supplies no pipeline.
            pipeline = [
                dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
                dict(type='KeypointDetection'),
                dict(type='ContourDetection'),
                # dict(type='Segmentation'),
                dict(type='Painting', painting_flag=True),
                dict(type='PrintPainting', print_flag=True),
                dict(type='Scaling'),
                dict(type='Split'),
                # dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']),
            ]
        kwargs.update(pipeline=pipeline)
        super().__init__(**kwargs)
@ITEMS.register_module()
class Trousers(Bottom):
    """Trousers item; identical behaviour to Bottom with the default pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Skirt(Bottom):
    """Skirt item; identical behaviour to Bottom with the default pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Bottoms(Bottom):
    """Generic bottoms item; identical behaviour to Bottom with the default pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)

View File

@@ -1,9 +0,0 @@
from mmcv.utils import Registry, build_from_cfg
# Registries for buildable items and pipeline transform stages.
ITEMS = Registry('item')
PIPELINES = Registry('pipeline')


def build_item(cfg, default_args=None):
    """Instantiate an item from a config dict via the ITEMS registry."""
    return build_from_cfg(cfg, ITEMS, default_args)

View File

@@ -1,100 +0,0 @@
import cv2
from app.core.config import PRIORITY_DICT
from .builder import ITEMS
from .pipelines import Compose
@ITEMS.register_module()
class Clothing(object):
    """Base class for wearable items: runs the configured pipeline over the
    item image and organizes the processed output into front/back render layers."""

    def __init__(self, pipeline, **kwargs):
        self.pipeline = Compose(pipeline)
        # Pipeline stages read from and write into this shared result dict.
        self.result = dict(name=type(self).__name__.lower(), **kwargs)

    def process(self):
        """Run every configured pipeline stage over self.result."""
        self.pipeline(self.result)

    def apply_scale(self, img):
        """Resize *img* (grayscale or color) by the pipeline-computed scale factor."""
        scale = self.result['scale']
        # shape[:2] is (height, width) for both 2-D and 3-D arrays; the original
        # re-read it inside a redundant `if len(img.shape) > 2` branch (dead code).
        height, width = img.shape[0: 2]
        scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
        return scaled_img

    def organize(self, layer):
        """Build the front and back layer dicts for this item and insert them.

        Both layers share one anchor position computed from garment and body
        keypoints; priority comes from the pipeline result when 'layer_order'
        is set, otherwise from the global PRIORITY_DICT.
        """
        start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'], self.result['clothes_keypoint'], self.result['body_point_test'], self.result["offset"], self.result["resize_scale"])
        front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
                           name=f'{type(self).__name__.lower()}_front',
                           image=self.result["front_image"],
                           # mask_image=self.result['front_mask_image'],
                           image_url=self.result['front_image_url'],
                           mask_url=self.result['mask_url'],
                           # NOTE(review): 'sacle' looks like a typo for 'scale',
                           # but downstream consumers may rely on this exact key.
                           sacle=self.result['scale'],
                           clothes_keypoint=self.result['clothes_keypoint'],
                           position=start_point,
                           resize_scale=self.result["resize_scale"],
                           # assumes front_image is a PIL image (.size) — TODO confirm
                           mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                           gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
                           pattern_image_url=self.result['pattern_image_url'],
                           pattern_image=self.result['pattern_image']
                           )
        layer.insert(front_layer)
        back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
                          name=f'{type(self).__name__.lower()}_back',
                          image=self.result["back_image"],
                          # mask_image=self.result['back_mask_image'],
                          image_url=self.result['back_image_url'],
                          mask_url=self.result['mask_url'],
                          sacle=self.result['scale'],
                          clothes_keypoint=self.result['clothes_keypoint'],
                          position=start_point,
                          resize_scale=self.result["resize_scale"],
                          mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
                          gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
                          pattern_image_url=self.result['pattern_image_url'],
                          )
        layer.insert(back_layer)

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
        """
        Align left.

        Args:
            keypoint_type: string, "waistband" | "shoulder" | "ear_point"
            scale: float
            clothes_point: dict mapping keypoint names to [x, y] coordinate lists
            body_point: dict, containing keypoint data of body figure
            offset: (dx, dy) pixel offset applied to the body anchor
            resize_scale: unused here; kept for signature compatibility

        Returns:
            start_point: tuple (y', x') in image coordinates:
                y' = y_body + offset_y - x_clothes * scale
                x' = x_body + offset_x - y_clothes * scale
        """
        side_indicator = f'{keypoint_type}_left'
        # Garment keypoints come from the Milvus keypoint cache as [x, y] lists.
        start_point = (
            int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale),  # y
            int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale)  # x
        )
        return start_point

View File

@@ -1,19 +0,0 @@
"""Pipeline stage package: re-exports all registered transforms."""
from .compose import Compose
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
from .keypoints import KeypointDetection
from .segmentation import Segmentation
from .painting import Painting, PrintPainting
from .scale import Scaling
from .contour_detection import ContourDetection
from .split import Split

__all__ = [
    'Compose',
    'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
    'KeypointDetection',
    'Segmentation',
    'Painting', 'PrintPainting',
    'Scaling',
    'ContourDetection',
    # Fix: was lowercase 'split' — that name does not exist in this package,
    # so `from app...pipelines import *` raised AttributeError.
    'Split',
]

View File

@@ -1,36 +0,0 @@
import collections
from mmcv.utils import build_from_cfg
from ..builder import PIPELINES
@PIPELINES.register_module()
class Compose(object):
    """Chain of transforms applied sequentially to a result dict."""

    def __init__(self, transforms):
        assert isinstance(transforms, collections.abc.Sequence)
        self.transforms = []
        for item in transforms:
            if isinstance(item, dict):
                # Config dicts are instantiated through the PIPELINES registry.
                self.transforms.append(build_from_cfg(item, PIPELINES))
            elif callable(item):
                self.transforms.append(item)
            else:
                raise TypeError('transform must be callable or a dict')

    def __call__(self, data):
        """Call function to apply transforms sequentially.

        Args:
            data (dict): A result dict contains the data to transform.

        Returns:
            dict: Transformed data, or None if any transform returned None.
        """
        for stage in self.transforms:
            data = stage(data)
            if data is None:
                return None
        return data

View File

@@ -1,59 +0,0 @@
import cv2
import numpy as np
from ..builder import PIPELINES
@PIPELINES.register_module()
class ContourDetection(object):
    """Pipeline stage: derive a binary garment mask from the item image contours."""

    def __init__(self):
        pass

    # @ RunTime
    def __call__(self, result):
        """Compute result['mask'] (plus front/back copies) for the item image.

        Shoes use the two largest contours (one per shoe); every other item
        uses only the largest contour, falling back to a full-frame mask when
        no contour is found. Any pre-existing result['pre_mask'] is AND-ed in.
        """
        contours = self.get_contours(result['image'])
        mask = np.zeros(result['image'].shape[:2], np.uint8)
        if result['name'] == 'shoes':
            # Fix: the original indexed Contour[0] and Contour[1] unconditionally,
            # raising IndexError when fewer than two contours were detected.
            for contour in contours[:2]:
                epsilon = 0.001 * cv2.arcLength(contour, True)
                approx = cv2.approxPolyDP(contour, epsilon, True)
                cv2.drawContours(mask, [approx], -1, 255, -1)
        else:
            if len(contours):
                largest = contours[0]
                epsilon = 0.001 * cv2.arcLength(largest, True)
                approx = cv2.approxPolyDP(largest, epsilon, True)
                cv2.drawContours(mask, [approx], -1, 255, -1)
            else:
                # No contour detected: fall back to a full-frame mask.
                mask = np.ones(result['image'].shape[:2], np.uint8) * 255
            # TODO: fix images that come out partially transparent (next release):
            # img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
            # ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
            # Mask = cv2.bitwise_not(Mask)
        # Combine with any mask computed by earlier stages.
        if result['pre_mask'] is None:
            result['mask'] = mask
        else:
            result['mask'] = cv2.bitwise_and(mask, result['pre_mask'])
        result['front_mask'] = result['mask']
        result['back_mask'] = result['mask']
        return result

    @staticmethod
    def get_contours(image):
        """Return the external contours of *image*, largest area first.

        Canny edges are dilated then eroded (morphological close) to bridge
        small gaps before contour extraction.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        edges = cv2.Canny(gray, 10, 150)
        kernel = np.ones((5, 5), np.uint8)
        edges = cv2.dilate(edges, kernel=kernel, iterations=1)
        edges = cv2.erode(edges, kernel=kernel, iterations=1)
        contours, _ = cv2.findContours(edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        return sorted(contours, key=cv2.contourArea, reverse=True)

View File

@@ -1,140 +0,0 @@
import logging
import time
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.service.utils.decorator import RunTime, ClassCallRunTime
from ..builder import PIPELINES
from ...utils.design_ensemble import get_keypoint_result
@PIPELINES.register_module()
class KeypointDetection(object):
    """Pipeline stage that attaches garment keypoints to the result dict.

    Keypoints are looked up in a Milvus cache keyed by image_id; on a cache
    miss (or when the cached 'site' does not match) the keypoint model is run
    and the cache entry is created or upgraded.

    path here: abstract path
    """
    # def __init__(self):
    #     self.client = MilvusClient(
    #         uri="http://10.1.1.240:19530",
    #         token="root:Milvus",
    #         db_name=MILVUS_ALIAS
    #     )
    # def __del__(self):
    #     start_time = time.time()
    #     self.client.close()
    #     print(f"client close time : {time.time() - start_time}")
    # @ClassCallRunTime
    def __call__(self, result):
        # logging.info("KeypointDetection run ")
        # Check whether cached data of the same category exists: read it directly
        # when it matches, otherwise infer with the model and update the cache.
        if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']:
            # result['clothes_keypoint'] = self.infer_keypoint_result(result)
            site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
            # keypoint_cache = search_keypoint_cache(result["image_id"], site)
            keypoint_cache = self.keypoint_cache(result, site)
            # To bypass the vector lookup and always run model inference:
            # keypoint_cache = False
            if keypoint_cache is False:
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
            else:
                result['clothes_keypoint'] = keypoint_cache
        return result

    @staticmethod
    def infer_keypoint_result(result):
        """Run the keypoint model on result['image'] and return (keypoints, site)."""
        site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
        start_time = time.time()
        keypoint_infer_result = get_keypoint_result(result["image"], site)  # inference result
        # logging.info(f"infer keypoint time : {time.time() - start_time}")
        return keypoint_infer_result, site

    @staticmethod
    # @ RunTime
    def save_keypoint_cache(keypoint_id, cache, site):
        """Persist inferred keypoints to Milvus as a 24-element vector.

        Layout is [20 'up' values | 4 'down' values]; the half not covered by
        *site* is zero-padded. Returns the keypoints as a dict of 12 (x, y)
        pairs keyed by KEYPOINT_RESULT_TABLE_FIELD_SET; the same dict is
        returned even when the Milvus upsert fails.
        """
        if site == "down":
            zeros = np.zeros(20, dtype=int)
            result = np.concatenate([zeros, cache.flatten()])
        else:
            zeros = np.zeros(4, dtype=int)
            result = np.concatenate([cache.flatten(), zeros])
        # Vector persistence is best-effort; the reshaped result is returned regardless.
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": site,
             "keypoint_vector": result.tolist()
             }
        ]
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            # start_time = time.time()
            res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
            # logging.info(f"save keypoint time : {time.time() - start_time}")
            client.close()
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        except Exception as e:
            logging.info(f"save keypoint cache milvus error : {e}")
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))

    @staticmethod
    def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
        """Merge freshly inferred keypoints with a cached half and mark the entry "all".

        *site* is the side that was just inferred; the complementary side comes
        from the cached *search_result* vector.
        """
        if site == "up":
            # 'up' was inferred, so the cached vector supplies the 'down' tail.
            result = np.concatenate([infer_result.flatten(), search_result[-4:]])
        else:
            # 'down' was inferred, so the cached vector supplies the 'up' head.
            result = np.concatenate([search_result[:20], infer_result.flatten()])
        data = [
            {"keypoint_id": keypoint_id,
             "keypoint_site": "all",
             "keypoint_vector": result.tolist()
             }
        ]
        try:
            # NOTE(review): this client is never closed — potential connection
            # leak; confirm whether MilvusClient requires explicit close().
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            # connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
            start_time = time.time()
            # collection = Collection(MILVUS_TABLE_KEYPOINT)  # Get an existing collection.
            # mr = collection.upsert(data)
            client.upsert(
                collection_name=MILVUS_TABLE_KEYPOINT,
                data=data
            )
            # logging.info(f"save keypoint time : {time.time() - start_time}")
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
        except Exception as e:
            logging.info(f"save keypoint cache milvus error : {e}")
            return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))

    # @ RunTime
    def keypoint_cache(self, result, site):
        """Look up cached keypoints for result['image_id'] in Milvus.

        Returns the keypoint dict on a usable hit (inferring/saving or
        upgrading the cache as needed), or False when the Milvus query itself
        fails — the caller then falls back to direct inference.
        """
        try:
            client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
            keypoint_id = result['image_id']
            res = client.query(
                collection_name=MILVUS_TABLE_KEYPOINT,
                # ids=[keypoint_id],
                filter=f"keypoint_id == {keypoint_id}",
                output_fields=['keypoint_vector', 'keypoint_site']
            )
            if len(res) == 0:
                # No cached entry: infer directly and persist the result.
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
            elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
                # Cached site matches the requested one (or covers both): return as-is.
                return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
            elif res[0]["keypoint_site"] != site:
                # Cached site differs: infer the missing half and upgrade the entry to "all".
                keypoint_infer_result, site = self.infer_keypoint_result(result)
                return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
        except Exception as e:
            logging.info(f"search keypoint cache milvus error {e}")
            return False

View File

@@ -1,134 +0,0 @@
import cv2
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
    """Pipeline step: fetch the garment image from OSS and seed the shared ``result`` dict.

    Populates ``image`` / ``pre_mask`` / ``gray`` / ``keypoint`` / shape fields and
    forwards the configured ``color`` and ``print_dict``.
    """
    def __init__(self, path, color=None, print_dict=None):
        # path: "<bucket>/<object_name>" OSS location of the source image.
        self.path = path
        self.color = color
        self.print_dict = print_dict

    # @ClassCallRunTime
    def __call__(self, result):
        result['image'], result['pre_mask'] = self.read_image(self.path)
        result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
        result['keypoint'] = self.get_keypoint(result['name'])
        result['path'] = self.path
        result['img_shape'] = result['image'].shape
        result['ori_shape'] = result['image'].shape
        result['color'] = self.color if self.color is not None else None
        result['print_dict'] = self.print_dict
        return result

    @staticmethod
    def get_keypoint(name):
        """Map an item category to the keypoint type used for alignment.

        Raises:
            KeyError: when ``name`` is not a known category.
        """
        # Mapping table replaces the original if/elif chain; same categories accepted.
        mapping = {
            'blouse': 'shoulder', 'outwear': 'shoulder', 'dress': 'shoulder', 'tops': 'shoulder',
            'trousers': 'waistband', 'skirt': 'waistband', 'bottoms': 'waistband',
            'bag': 'hand_point',
            'shoes': 'toe',
            'hairstyle': 'head_point',
            'earring': 'ear_point',
        }
        try:
            return mapping[name]
        except KeyError:
            raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
                           f"bag, shoes, hairstyle, earring.")

    @staticmethod
    def read_image(image_path):
        """Fetch "<bucket>/<object>" from OSS as BGR; return (image, alpha_mask_or_None)."""
        image_mask = None
        image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
        if len(image.shape) == 2:
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        if image.shape[2] == 4:  # 4-channel image: 4th channel is the mask
            image_mask = image[:, :, 3]
            image = image[:, :, :3]
        # Bug fix: `image.shape[:2] <= (50, 50)` compared tuples lexicographically,
        # so e.g. a 40x500 image was wrongly upscaled. Require BOTH dims to be small.
        if image.shape[0] <= 50 and image.shape[1] <= 50:
            # Double the size of tiny images so downstream processing has pixels to work with.
            new_size = (image.shape[1] * 2, image.shape[0] * 2)
            image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
        return image, image_mask
@PIPELINES.register_module()
class LoadBodyImageFromFile(object):
    """Pipeline step: load the mannequin/body image from OSS as a PIL image."""

    def __init__(self, body_path):
        # body_path: "<bucket>/<object_name>" OSS location of the mannequin image.
        self.body_path = body_path

    # @ RunTime
    def __call__(self, result):
        result["image_url"] = result['body_path'] = self.body_path
        result["name"] = "mannequin"
        bucket, object_name = self.body_path.split("/", 1)
        result['body_image'] = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
        return result
@PIPELINES.register_module()
class ImageShow(object):
    """Debug pipeline step: display images stored in ``result``.

    A list of keys is shown with matplotlib; a single string key is shown
    with OpenCV (downscaled so neither side exceeds 400 px).
    """

    def __init__(self, key):
        # key: result-dict key (str) or list of keys to display.
        self.key = key

    # @ RunTime
    def __call__(self, result):
        import matplotlib.pyplot as plt
        if isinstance(self.key, list):
            for name in self.key:
                plt.imshow(result[name])
                plt.title(name)
                plt.show()
        elif isinstance(self.key, str):
            cv2.imshow(self.key, self._resize_img(result[self.key]))
            cv2.waitKey(0)
        else:
            raise TypeError(f'key should be string but got type {type(self.key)}.')
        return result

    @staticmethod
    def _resize_img(img):
        # Shrink so the longest side fits in 400 px; small images pass through untouched.
        height, width = img.shape[0], img.shape[1]
        if height > 400 or width > 400:
            scale = min(400 / height, 400 / width)
            img = cv2.resize(img, (int(scale * width), int(scale * height)))
        return img

View File

@@ -1,605 +0,0 @@
import logging
import random
import cv2
import numpy as np
from PIL import Image
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
logger = logging.getLogger()
@PIPELINES.register_module()
class Painting(object):
    """Pipeline step: recolor the masked garment with a flat color or a gradient image."""

    def __init__(self, painting_flag=True):
        # painting_flag: allows disabling recoloring while keeping the step in the pipeline.
        self.painting_flag = painting_flag

    # @ClassCallRunTime
    def __call__(self, result):
        if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
            dim_image_h, dim_image_w = result['image'].shape[0:2]
            if "gradient" in result.keys() and result['gradient'] != "":
                # Gradient fill: the pattern is an image pulled from OSS.
                bucket_name = result['gradient'].split('/')[0]
                object_name = result['gradient'][result['gradient'].find('/') + 1:]
                pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
                resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
            else:
                # Flat fill: the pattern is a 1x1 pixel of the requested color.
                pattern = self.get_pattern(result['color'])
                resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
            # Modulate the pattern by the garment mask and the grayscale shading.
            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
            get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
            result['pattern_image'] = get_image_fir.astype(np.uint8)
            result['final_image'] = result['pattern_image']
            # Compose the recolored garment onto a white canvas for the standalone view.
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)
            result['alpha'] = 100 / 255.0
        else:
            # No recoloring: just mask the original image.
            closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            get_image_fir = result['image'] * (closed_mo / 255)
            result['pattern_image'] = get_image_fir.astype(np.uint8)
            result['final_image'] = result['pattern_image']
        return result

    @staticmethod
    def get_gradient(bucket_name, object_name):
        """Fetch a gradient image from OSS as BGR (any alpha channel is dropped)."""
        image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
        if image.shape[2] == 4:
            image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
        return image

    @staticmethod
    def crop_image(image, image_size_h, image_size_w):
        # NOTE(review): appears unused by this class (PrintPainting defines its own
        # crop_image); randint raises for inputs smaller than ~35 px per side -- confirm.
        x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
        y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
        image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image

    @staticmethod
    def get_pattern(single_color):
        """Build a 1x1 BGR pixel from an 'R G B' string.

        Bug fix: the original executed ``raise False``, which itself raises
        ``TypeError: exceptions must derive from BaseException``; raise a
        meaningful ValueError instead.

        Raises:
            ValueError: when ``single_color`` is None.
        """
        if single_color is None:
            raise ValueError("single_color must be an 'R G B' string, got None")
        R, G, B = single_color.split(' ')
        pattern = np.zeros([1, 1, 3], np.uint8)
        pattern[0, 0, 0] = int(B)
        pattern[0, 0, 1] = int(G)
        pattern[0, 0, 2] = int(R)
        return pattern
@PIPELINES.register_module()
class PrintPainting(object):
    """Pipeline step: composite print/element artwork onto the garment image.

    Handles three print modes carried in ``result['print']``:
      * ``overall`` -- one artwork tiled (and optionally rotated) across the garment;
      * ``single``  -- individual artworks pasted at given locations;
      * ``element`` -- decorative elements pasted on top of the final image.
    """
    def __init__(self, print_flag=True):
        # print_flag: pipeline configuration switch; not read elsewhere in this class.
        self.print_flag = print_flag
    # @ClassCallRunTime
    def __call__(self, result):
        """Apply overall, single, and element prints to ``result`` in that order.

        Updates ``print_image`` / ``pattern_image`` / ``final_image`` /
        ``single_image`` in place and returns ``result``.
        """
        single_print = result['print']['single']
        overall_print = result['print']['overall']
        element_print = result['print']['element']
        result['single_image'] = None
        result['print_image'] = None
        if overall_print['print_path_list']:
            painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
            result['print_image'] = result['pattern_image']
            if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
                # Angled overall print: tile on a square canvas, rotate, crop the black
                # border, then resize/crop back to the sketch size.
                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
                painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
                painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
                # resize back to the sketch size
                painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
                painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
            else:
                painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
            result['print_image'] = self.printpaint(result, painting_dict, print_=True)
            result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
        if single_print['print_path_list']:
            print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
            mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
            for i in range(len(single_print['print_path_list'])):
                image, image_mode = self.read_image(single_print['print_path_list'][i])
                if image_mode == "RGBA":
                    # RGBA artwork: paste with PIL using the alpha channel as the mask.
                    new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
                    mask = image.split()[3]
                    resized_source = image.resize(new_size)
                    resized_source_mask = mask.resize(new_size)
                    rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
                    rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
                    source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
                    ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
                else:
                    # Opaque artwork: derive an inverse background mask, then rotate/scale.
                    mask = self.get_mask_inv(image)
                    mask = np.expand_dims(mask, axis=2)
                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
                    mask = cv2.bitwise_not(mask)
                    # coordinates must be recomputed after rotation
                    rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
                    rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
                    x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
                    image_x = print_background.shape[1]
                    image_y = print_background.shape[0]
                    print_x = rotate_image.shape[1]
                    print_y = rotate_image.shape[0]
                    # Order matters: first clamp top/left overflow, then crop the
                    # bottom/right overflow (cropping the right edge first and then
                    # shifting the region left would corrupt the placement).
                    if x <= 0:
                        rotate_image = rotate_image[:, -x:]
                        rotate_mask = rotate_mask[:, -x:]
                        start_x = x = 0
                    else:
                        start_x = x
                    if y <= 0:
                        rotate_image = rotate_image[-y:, :]
                        rotate_mask = rotate_mask[-y:, :]
                        start_y = y = 0
                    else:
                        start_y = y
                    # If the print extends past the image, crop the print to fit.
                    if x + print_x > image_x:
                        rotate_image = rotate_image[:, :image_x - x]
                        rotate_mask = rotate_mask[:, :image_x - x]
                    if y + print_y > image_y:
                        rotate_image = rotate_image[:image_y - y, :]
                        rotate_mask = rotate_mask[:image_y - y, :]
                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
            # Keep only print pixels inside the garment mask, then shade by the gray map.
            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
            img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
            mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
            gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
            img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
            result['final_image'] = cv2.add(img_bg, img_fg)
            # Standalone view: composite the final image over a white canvas.
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)
        if element_print['element_path_list']:
            # Element pass mirrors the single-print pass but composites onto final_image.
            print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
            mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
            for i in range(len(element_print['element_path_list'])):
                image, image_mode = self.read_image(element_print['element_path_list'][i])
                if image_mode == "RGBA":
                    new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
                    mask = image.split()[3]
                    resized_source = image.resize(new_size)
                    resized_source_mask = mask.resize(new_size)
                    rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
                    rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
                    source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
                    source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
                    source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
                    source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
                    print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
                    mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
                    # NOTE(review): unlike the single-print branch, no binary threshold
                    # is applied to mask_background here -- confirm this is intentional.
                else:
                    mask = self.get_mask_inv(image)
                    mask = np.expand_dims(mask, axis=2)
                    mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
                    mask = cv2.bitwise_not(mask)
                    # coordinates must be recomputed after rotation
                    rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
                    rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
                    x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
                    image_x = print_background.shape[1]
                    image_y = print_background.shape[0]
                    print_x = rotate_image.shape[1]
                    print_y = rotate_image.shape[0]
                    # Order matters: clamp top/left overflow first, then crop bottom/right.
                    if x <= 0:
                        rotate_image = rotate_image[:, -x:]
                        rotate_mask = rotate_mask[:, -x:]
                        start_x = x = 0
                    else:
                        start_x = x
                    if y <= 0:
                        rotate_image = rotate_image[-y:, :]
                        rotate_mask = rotate_mask[-y:, :]
                        start_y = y = 0
                    else:
                        start_y = y
                    # If the element extends past the image, crop it to fit.
                    if x + print_x > image_x:
                        rotate_image = rotate_image[:, :image_x - x]
                        rotate_mask = rotate_mask[:, :image_x - x]
                    if y + print_y > image_y:
                        rotate_image = rotate_image[:image_y - y, :]
                        rotate_mask = rotate_mask[:image_y - y, :]
                    mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
                    print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
            print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
            img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
            # TODO element loses shading information: the gray-map modulation applied in
            # the single-print pass is skipped here.
            three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
            img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
            result['final_image'] = cv2.add(img_bg, img_fg)
            canvas = np.full_like(result['final_image'], 255)
            temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
            tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
            temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
            tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
            result['single_image'] = cv2.add(tmp1, tmp2)
        return result
    @staticmethod
    def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
        """Paste ``rotate_image`` into ``print_background`` at (start_x, start_y).

        Non-black pixels of the pasted artwork overwrite whatever was already
        in the background; black (zero) pixels leave the background untouched.
        """
        temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
        temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
        img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
        ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
        mask_inv = cv2.bitwise_not(mask_)
        img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
        img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
        print_background = img1_bg + img2_fg
        return print_background
    def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
        """Prepare the tiled print and its inverse background mask for ``printpaint``.

        Fills ``painting_dict`` with ``Trigger`` / ``location`` / ``tile_print`` /
        ``mask_inv_print`` / ``dim_print_h`` / ``dim_print_w``.
        """
        if print_trigger:
            print_ = self.get_print(print_dict)
            painting_dict['Trigger'] = not is_single
            painting_dict['location'] = print_['location']
            single_mask_inv_print = self.get_mask_inv(print_['image'])
            dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
            dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
            if not is_single:
                # NOTE(review): random_seed is created here on the fly and later
                # consumed by crop_image (via seed); confirm call ordering.
                self.random_seed = random.randint(0, 1000)
                # Overall mode with an angle: tile onto a SQUARE canvas so the later
                # rotation/crop keeps full coverage.
                if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
                else:
                    painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
                    painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
            else:
                painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
                painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
            painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
        return painting_dict
    def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
        """Resize ``pattern`` to ``dim``; when ``trigger``, tile it and crop to the target size."""
        tile = None
        if not trigger:
            tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
        else:
            resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
            # Repeat enough copies to cover the target area at this scale (2-D masks
            # and 3-D color images tile along different axis counts).
            if len(pattern.shape) == 2:
                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
            if len(pattern.shape) == 3:
                tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
            tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
        return tile
    def get_mask_inv(self, print_):
        """Return a mask selecting the near-white background of ``print_``.

        Only treats the background as removable when the top-left pixel is pure
        white; otherwise returns an all-zero mask (keep every pixel).
        """
        if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
            # Range-match in Lab space around the background pixel's color.
            bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
            print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
            bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
            bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
            bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
            bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
            lower = np.array([bg_L_low, bg_a_low, bg_b_low])
            upper = np.array([bg_L_high, bg_a_high, bg_b_high])
            mask_inv = cv2.inRange(print_tile, lower, upper)
            return mask_inv
        else:
            mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
            return mask_inv
    @staticmethod
    def printpaint(result, painting_dict, print_=False):
        """Blend the prepared print layer with the pattern image, shaded by the gray map."""
        if print_ and painting_dict['Trigger']:
            print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
            img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
        else:
            print_mask = result['mask']
            img_fg = result['final_image']
        if print_ and not painting_dict['Trigger']:
            index_ = None
            try:
                index_ = len(painting_dict['location'])
            except:
                # NOTE(review): asserting a non-empty string is always true, so this
                # never raises; index_ stays None and range(None) below would fail.
                assert f'there must be parameter of location if choose IfSingle'
            for i in range(index_):
                start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
                length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
                length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
                change_region = img_fg[start_h: length_h, start_w: length_w, :]
                # problem in change_mask
                change_mask = print_mask[start_h: length_h, start_w: length_w]
                # get real part into change mask
                _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
                mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
                # NOTE(review): this assigns the region back onto itself, and the
                # computed mask/change_mask are never used -- looks unfinished; confirm.
                img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
        clothes_mask_print = cv2.bitwise_not(print_mask)
        img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
        mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
        gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
        img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
        print_image = cv2.add(img_bg, img_fg)
        return print_image
    @staticmethod
    def get_print(print_dict):
        """Load the first print image from OSS as BGR and clamp its scale to >= 0.3."""
        if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
            print_dict['scale'] = 0.3
        else:
            print_dict['scale'] = print_dict['print_scale_list'][0]
        bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
        object_name = print_dict['print_path_list'][0].split("/", 1)[1]
        image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL")
        # If the image is RGBA, paste it onto solid white first so transparent
        # pixels do not turn black when the alpha channel is dropped.
        if image.mode == "RGBA":
            new_background = Image.new('RGB', image.size, (255, 255, 255))
            new_background.paste(image, mask=image.split()[3])
            image = new_background
        print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
        return print_dict
    def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
        """Crop a window from the tiled print so the tile phase lines up with ``location``."""
        print_w = print_shape[1]
        print_h = print_shape[0]
        random.seed(self.random_seed)
        # Offsets are the anchor location taken modulo the resized print size.
        # NOTE(review): x_offset subtracts from print_w where print_h looks intended
        # (and vice versa); harmless for square prints -- confirm for non-square art.
        x_offset = print_w - int(location[0][1] % print_w)
        y_offset = print_w - int(location[0][0] % print_h)
        if len(image.shape) == 2:
            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
        elif len(image.shape) == 3:
            image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
        return image
    @staticmethod
    def get_low_high_lab(Lab_value, L=False):
        """Return (high, low) bounds of +/-30 around a Lab channel value, clamped to [0, 255]."""
        # NOTE(review): both branches are identical; the L flag currently has no effect.
        if L:
            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
        else:
            high = Lab_value + 30 if Lab_value + 30 < 255 else 255
            low = Lab_value - 30 if Lab_value - 30 > 0 else 0
        return high, low
    @staticmethod
    def img_rotate(image, angel, scale):
        """Rotate ``image`` clockwise by ``angel`` degrees (with uniform ``scale``),
        expanding the canvas so nothing is clipped.

        Args:
            image (np.array): source image.
            angel (float): rotation angle in degrees (applied clockwise via -angel).
            scale (float): uniform scale factor applied by the rotation matrix.
        Returns:
            tuple: (rotated image, (dx, dy) offset of the scaled original inside
            the expanded canvas).
        """
        h, w = image.shape[:2]
        center = (w // 2, h // 2)
        M = cv2.getRotationMatrix2D(center, -angel, scale)
        # Grow the output canvas to the rotated bounding box and re-center.
        rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
        rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
        M[0, 2] += (rotated_w - w) // 2
        M[1, 2] += (rotated_h - h) // 2
        rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
        return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
    @staticmethod
    def rotate_crop_image(img, angle, crop):
        """Rotate ``img`` by ``angle`` degrees, optionally cropping away the black border.

        Args:
            angle: rotation angle in degrees.
            crop: whether to crop to the largest borderless inner rectangle.
        """
        crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
        # NOTE(review): shape[:2] is (height, width) but is unpacked as (w, h);
        # this is only correct for square inputs -- confirm callers pass squares.
        w, h = img.shape[:2]
        # Rotation has a 360-degree period.
        angle %= 360
        M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
        img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
        if crop:
            # The crop geometry has a 180-degree period.
            angle_crop = angle % 180
            if angle > 90:
                angle_crop = 180 - angle_crop
            # Degrees to radians.
            theta = angle_crop * np.pi / 180
            # Aspect ratio of the source.
            hw_ratio = float(h) / float(w)
            # Numerator of the crop-size coefficient.
            tan_theta = np.tan(theta)
            numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
            # Aspect-ratio-dependent term of the denominator.
            r = hw_ratio if h > w else 1 / hw_ratio
            denominator = r * tan_theta + 1
            # Final side-length coefficient.
            crop_mult = numerator / denominator
            # Resulting crop rectangle, centered.
            w_crop = int(crop_mult * w)
            h_crop = int(crop_mult * h)
            x0 = int((w - w_crop) / 2)
            y0 = int((h - h_crop) / 2)
            img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
        return img_rotated
    @staticmethod
    def read_image(image_url):
        """Fetch "<bucket>/<object>" from OSS; return (image, mode).

        RGBA sources come back as a PIL image with mode "RGBA"; everything else
        stays a cv2/numpy BGR array with mode "RGB".
        """
        image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
        if image.shape[2] == 4:
            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
            image = Image.fromarray(image_rgb)
            image_mode = "RGBA"
        else:
            image_mode = "RGB"
        return image, image_mode
    @staticmethod
    def resize_and_crop(img, target_width, target_height):
        """Resize ``img`` to cover (target_width, target_height), center-cropping overflow."""
        # Source dimensions and aspect ratios.
        original_height, original_width = img.shape[:2]
        target_ratio = target_width / target_height
        original_ratio = original_width / original_height
        if original_ratio > target_ratio:
            # Source is wider: match the height, then crop the width.
            new_height = target_height
            new_width = int(original_width * (target_height / original_height))
            resized_img = cv2.resize(img, (new_width, new_height))
            # Center-crop the width.
            start_x = (new_width - target_width) // 2
            cropped_img = resized_img[:, start_x:start_x + target_width]
        else:
            # Source is taller: match the width, then crop the height.
            new_width = target_width
            new_height = int(original_height * (target_width / original_width))
            resized_img = cv2.resize(img, (new_width, new_height))
            # Center-crop the height.
            start_y = (new_height - target_height) // 2
            cropped_img = resized_img[start_y:start_y + target_height, :]
        return cropped_img

View File

@@ -1,57 +0,0 @@
import math
import cv2
from app.service.utils.decorator import ClassCallRunTime
from ..builder import PIPELINES
@PIPELINES.register_module()
class Scaling(object):
    """Pipeline stage that writes result['scale']: the ratio between a body
    keypoint distance and the matching clothing keypoint distance.

    Branches by result['keypoint']:
      * waistband / shoulder / head_point — keypoint pair distances;
      * toe — body foot length vs. detected shoe contour width;
      * hand_point / ear_point — pre-computed bag/earring scales.
    """

    def __init__(self):
        pass

    # @ClassCallRunTime
    def __call__(self, result):
        kp = result['keypoint']
        if kp in ['waistband', 'shoulder', 'head_point']:
            clo = result['clothes_keypoint']
            body = result['body_point_test']
            distance_clo = math.sqrt(
                (int(clo[kp + '_left'][0]) - int(clo[kp + '_right'][0])) ** 2
                + (int(clo[kp + '_left'][1]) - int(clo[kp + '_right'][1])) ** 2)
            # body distance uses only the x component; +1 avoids a zero distance
            distance_bdy = math.sqrt(
                (int(body[kp + '_left'][0]) - int(body[kp + '_right'][0])) ** 2 + 1)
            result['scale'] = 1 if distance_clo == 0 else distance_bdy / distance_clo
        elif kp == 'toe':
            foot = result['body_point_test']['foot_length']
            distance_bdy = math.sqrt(
                (int(foot[0]) - int(foot[2])) ** 2 + (int(foot[1]) - int(foot[3])) ** 2)
            # measure the shoe: edge-detect, take the widest outer contour
            blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
            edge = cv2.Canny(blur, 10, 200)
            edge = cv2.dilate(edge, None)
            edge = cv2.erode(edge, None)
            contour, _ = cv2.findContours(edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            contours = sorted(contour, key=cv2.contourArea, reverse=True)
            x, y, w, h = cv2.boundingRect(contours[0])
            distance_clo = w
            # FIX: guard a zero-width contour, consistent with the branch above
            result['scale'] = 1 if distance_clo == 0 else distance_bdy / distance_clo
        elif kp == 'hand_point':
            result['scale'] = result['scale_bag']
        elif kp == 'ear_point':
            result['scale'] = result['scale_earrings']
        return result

View File

@@ -1,71 +0,0 @@
import logging
import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.oss_client import oss_get_image
from ..builder import PIPELINES
from ...utils.design_ensemble import get_seg_result
logger = logging.getLogger()
@PIPELINES.register_module()
class Segmentation(object):
    """Pipeline stage that attaches front/back/combined garment masks to `result`.

    Two sources:
      * result['seg_mask_url'] — a pre-rendered red/green mask on OSS
        (red-dominant pixels = front piece, green-dominant = back piece);
      * otherwise — model inference via get_seg_result, disk-cached as .npy
        (class value 1.0 = front, 2.0 = back).
    """

    @ClassCallRunTime
    def __call__(self, result):
        if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
            seg_mask = oss_get_image(
                bucket=result['seg_mask_url'].split('/')[0],
                object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:],
                data_type="cv2")
            # nearest-neighbour keeps mask labels crisp when resizing
            seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]),
                                  interpolation=cv2.INTER_NEAREST)
            # OpenCV loads BGR; convert so the channel names below are honest
            image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
            r, g, b = cv2.split(image_rgb)
            # red-dominant -> front, green-dominant -> back (mutually exclusive)
            result['front_mask'] = np.array(r > g, dtype=np.uint8) * 255
            result['back_mask'] = np.array(g > r, dtype=np.uint8) * 255
            result['mask'] = result['front_mask'] + result['back_mask']
        else:
            # consult the local on-disk seg cache first
            cached, seg_result = self.load_seg_result(result["image_id"])
            if not cached:
                # cache miss: run inference and persist the result
                seg_result = get_seg_result(result["image_id"], result['image'])[0]
                self.save_seg_result(seg_result, result['image_id'])
            # FIX: previously this was assigned *before* the cache-miss branch,
            # leaving result['seg_result'] = None whenever inference had to run.
            result['seg_result'] = seg_result
            # class 1.0 -> front piece, class 2.0 -> back piece (disjoint)
            result['front_mask'] = (255 * ((seg_result == 1.0) + 0).astype(np.uint8))
            result['back_mask'] = (255 * ((seg_result == 2.0) + 0).astype(np.uint8))
            result['mask'] = result['front_mask'] + result['back_mask']
        return result

    @staticmethod
    def save_seg_result(seg_result, image_id):
        """Persist one segmentation array into the local cache directory."""
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            np.save(file_path, seg_result)
            logger.debug(f"保存成功 {os.path.abspath(file_path)}")
        except Exception as e:
            logger.error(f"保存失败: {e}")

    @staticmethod
    def load_seg_result(image_id):
        """Load a cached segmentation array.

        Returns (hit, array); (False, None) on a normal cache miss or any
        load error (errors are logged, misses are silent).
        """
        file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
        try:
            return True, np.load(file_path)
        except FileNotFoundError:
            # expected cache miss — stay quiet
            return False, None
        except Exception as e:
            logger.error(f"加载失败: {e}")
            return False, None

View File

@@ -1,79 +0,0 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
from ..builder import PIPELINES
from ...utils.conversion_image import rgb_to_rgba
from ...utils.upload_image import upload_png_mask
@PIPELINES.register_module()
class Split(object):
    """
    Split image into front and back layer according to the segmentation result
    """
    # @ClassCallRunTime
    # KNet
    def __call__(self, result):
        # Splits result['final_image'] into an uploaded front layer, an optional
        # back layer (tops family only), a red/green colour-coded mask and a
        # pattern layer. Expects name, front_mask, back_mask, mask, final_image,
        # pattern_image, scale and resize_scale to be present in `result`.
        # NOTE(review): on any exception this logs a warning and implicitly
        # returns None instead of `result` — confirm callers tolerate that.
        try:
            if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
                front_mask = result['front_mask']
                back_mask = result['back_mask']
                # combined mask becomes the alpha channel of the full garment image
                rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
                # overall scale = keypoint scale * per-axis resize scale
                new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
                rgba_image = cv2.resize(rgba_image, new_size)
                # front layer: keep only the pixels under the (resized) front mask
                result_front_image = np.zeros_like(rgba_image)
                front_mask = cv2.resize(front_mask, new_size)
                result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
                result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
                result['front_image'], result["front_image_url"], _ = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=None)
                # colour-coded layer mask: red = front, (below, optionally) green = back
                height, width = front_mask.shape
                mask_image = np.zeros((height, width, 3))
                mask_image[front_mask != 0] = [0, 0, 255]
                if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
                    # garments with a distinct back piece: build and upload it too
                    result_back_image = np.zeros_like(rgba_image)
                    back_mask = cv2.resize(back_mask, new_size)
                    result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
                    result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
                    result['back_image'], result["back_image_url"], _ = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=None)
                    mask_image[back_mask != 0] = [0, 255, 0]
                    # upload the combined colour-coded mask as a transparent PNG
                    rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                else:
                    # front-only garments: upload the mask, clear the back-layer fields
                    rbga_mask = rgb_to_rgba(mask_image, front_mask)
                    mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
                    image_data = io.BytesIO()
                    mask_pil.save(image_data, format='PNG')
                    image_data.seek(0)
                    image_bytes = image_data.read()
                    req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
                    result['mask_url'] = req.bucket_name + "/" + req.object_name
                    result['back_image'] = None
                    result["back_image_url"] = None
                    # result["back_mask_url"] = None
                    # result['back_mask_image'] = None
                # create the middle (pattern) layer, alpha-masked by the combined mask
                result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
                result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
                result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
            return result
        except Exception as e:
            logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -1,121 +0,0 @@
import cv2
import numpy as np
from PIL import Image
from .builder import ITEMS
from .clothing import Clothing
from ..utils.conversion_image import rgb_to_rgba
from ..utils.upload_image import upload_png_mask
from ...utils.generate_uuid import generate_uuid
@ITEMS.register_module()
class Shoes(Clothing):
    """Shoes item: detects the two shoe contours, splits them into separate
    left/right uploaded layers, and positions each against the body keypoints.
    """
    # TODO location of shoes has little mismatch

    def __init__(self, **kwargs):
        pipeline = [
            dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
            dict(type='KeypointDetection'),
            dict(type='ContourDetection'),
            dict(type='Painting'),
            dict(type='Scaling'),
            dict(type='Split'),
            # dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
        ]
        kwargs.update(pipeline=pipeline)
        super(Shoes, self).__init__(**kwargs)

    def organize(self, layer):
        """Cut the pair apart and insert one layer per shoe."""
        left_shoe_mask, right_shoe_mask = self.cut()
        for side in ('left', 'right'):
            layer.insert(dict(
                name=f'{type(self).__name__.lower()}_{side}',
                image=self.result[f'shoes_{side}'],
                image_url=self.result[f'{side}_image_url'],
                mask_url=self.result[f'{side}_mask_url'],
                # NOTE(review): key spelled 'sacle' historically — kept for
                # backward compatibility with consumers of this layer dict.
                sacle=self.result['scale'],
                clothes_keypoint=self.result['clothes_keypoint'],
                position=self.calculate_start_point(self.result['keypoint'],
                                                    self.result['scale'],
                                                    self.result['clothes_keypoint'],
                                                    self.result['body_point'],
                                                    side)))

    def cut(self):
        """
        Cut shoes mask into two pieces
        Returns: (left_mask, right_mask) — blurred single-shoe masks
        """
        contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contour, key=cv2.contourArea, reverse=True)
        # keep the two largest contours, ordered left-to-right by bounding-box x
        bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]]
        (contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False))
        item_mask_left, uploaded_left = self._extract_side(contours[0])
        self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = uploaded_left
        item_mask_right, uploaded_right = self._extract_side(contours[1])
        self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = uploaded_right
        return item_mask_left, item_mask_right

    def _extract_side(self, contour):
        """Build one shoe's blurred mask, cut it out, scale it and upload it.

        Returns (item_mask, upload_result) where upload_result is the
        (image, image_url, mask_url) tuple from upload_png_mask.
        """
        epsilon = 0.001 * cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, epsilon, True)
        mask = np.zeros(self.result['final_image'].shape[:2], np.uint8)
        cv2.drawContours(mask, [approx], -1, 255, -1)
        # soften the mask edge before using it as alpha
        item_mask = cv2.GaussianBlur(mask, (5, 5), 0)
        # FIX: rgb_to_rgba takes (rgb_image, mask) — the old 3-argument call
        # (shape tuple first) matched a removed signature and raised TypeError.
        rgba_image = rgb_to_rgba(self.result['final_image'], item_mask)
        result_image = np.zeros_like(rgba_image)
        result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
        pil_image = Image.fromarray(result_image, 'RGBA')
        scale = self.result["scale"]
        pil_image = pil_image.resize((int(pil_image.width * scale), int(pil_image.height * scale)), Image.LANCZOS)
        return item_mask, upload_png_mask(pil_image, f"{generate_uuid()}")

    @staticmethod
    def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location):
        """
        left shoes align left
        right shoes align right
        Args:
            keypoint_type: string, "toe"
            scale: float
            clothes_point: dict of per-side keypoints
            body_point: dict, containing keypoint data of body figure
            location: string, indicates whether the start point belongs to right or left shoe
        Returns:
            start_point: tuple (x', y')
        Raises:
            KeyError: when location is not 'left'/'right'
        """
        if location not in ['left', 'right']:
            raise KeyError(f'location value must be left or right but got {location}')
        side_indicator = f'{keypoint_type}_{location}'
        # NOTE(review): clothes_point values are parsed here as "x_y" strings
        # (split on "_"); confirm that is what callers provide.
        start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
                       body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
        return start_point

View File

@@ -1,46 +0,0 @@
from .builder import ITEMS
from .clothing import Clothing
@ITEMS.register_module()
class Top(Clothing):
    """Base item for the tops family; installs the default processing pipeline
    when the caller does not supply one."""

    def __init__(self, pipeline, **kwargs):
        if pipeline is None:
            # Default stage order: load -> keypoints -> segmentation -> painting
            # -> print painting -> scaling -> front/back split.
            pipeline = [
                dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
                dict(type='KeypointDetection'),
                # dict(type='ContourDetection'),
                dict(type='Segmentation'),
                dict(type='Painting', painting_flag=True),
                dict(type='PrintPainting', print_flag=True),
                # dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
                dict(type='Scaling'),
                dict(type='Split'),
            ]
        kwargs.update(pipeline=pipeline)
        super().__init__(**kwargs)
@ITEMS.register_module()
class Blouse(Top):
    """Blouse item; uses the default Top pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Outwear(Top):
    """Outwear item; uses the default Top pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
@ITEMS.register_module()
class Dress(Top):
    """Dress item; uses the default Top pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)
# Men's clothing
@ITEMS.register_module()
class Tops(Top):
    """Men's tops item; uses the default Top pipeline."""

    def __init__(self, pipeline=None, **kwargs):
        super().__init__(pipeline, **kwargs)

View File

@@ -1,197 +0,0 @@
import concurrent.futures
import io
import cv2
from app.core.config import PRIORITY_DICT
from app.service.design.core.layer import Layer
from app.service.design.items import build_item
from app.service.design.utils.redis_utils import Redis
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
from app.service.utils.decorator import RunTime
from app.service.utils.oss_client import oss_upload_image
def process_item(item, layers):
    """Run one item's pipeline and register its layers.

    Returns the mannequin's body-image size when the item is the mannequin,
    otherwise None.
    """
    item.process()
    item.organize(layers)
    if item.result['name'] != "mannequin":
        return None
    return item.result['body_image'].size
def update_progress(process_id, total):
    """Advance the redis progress counter for `process_id` by one item's share.

    With total == 1 the job completes in one step, so progress jumps to 100.
    Otherwise each call adds int(100 / total); the stored value is capped at
    99 so that only final_progress() ever writes 100.

    Returns the value the key held *before* this update (str or None).
    """
    r = Redis()
    progress = r.read(key=process_id)
    if total == 1:
        r.write(key=process_id, value=100)
    elif progress:
        # FIX: previously progress + increment could be written even when it
        # exceeded 100 (the 99-clamp only fired once it was already over).
        r.write(key=process_id, value=min(int(progress) + int(100 / total), 99))
    else:
        # first update for this process id
        r.write(key=process_id, value=int(100 / total))
    return progress
def final_progress(process_id):
    """Force the progress key to 100 and return the value it held before."""
    redis_client = Redis()
    previous = redis_client.read(key=process_id)
    redis_client.write(key=process_id, value=100)
    return previous
@RunTime
def generate(request_data):
    """Process every object in the request concurrently and collect results.

    request_data is a pydantic-style model carrying 'objects' (list of item
    configs) and 'process_id' (redis progress key). Returns a dict mapping
    each object's index to its items_response.
    """
    return_response = {}
    return_png_mask = []
    request_data = request_data.dict()
    assert "process_id" in request_data.keys(), "Need process_id parameters"
    objects = request_data['objects']
    # insert_keypoint_cache(objects)
    process_id = request_data['process_id']
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # fan out one task per object; remember each future's original index
        futures = {executor.submit(process_object, cfg, process_id, len(objects)): idx
                   for idx, cfg in enumerate(objects)}
        for future in concurrent.futures.as_completed(futures):
            idx = futures[future]
            # FIX: fetch the result once instead of calling future.result() twice
            items_response, uploaded = future.result()
            return_response[idx] = items_response
            return_png_mask.extend(uploaded)
    # upload_results = process_images(return_png_mask)
    final_progress(process_id)
    return return_response
def process_object(cfg, process_id, total):
    # Process one request object in either 'overall' mode (full-figure layer
    # synthesis) or 'single' mode (one garment, front/back layers only).
    # Returns (items_response, uploaded_images); uploaded_images is currently
    # always empty because the bulk-upload path below is commented out.
    uploaded_images = []
    basic_info = cfg.get('basic')
    items_response = {
        'layers': []
    }
    if cfg.get('basic')['single_overall'] == 'overall':
        basic_info['debug'] = False
        items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
        layers = Layer()
        body_size = None
        futures = []
        for item in items:
            # NOTE(review): despite the name, items are processed serially;
            # process_item returns a size only for the mannequin item.
            futures = [process_item(item, layers)]
            for future in futures:
                if future is not None:
                    body_size = future
        # custom ordering (per-layer 'priority') vs. default PRIORITY_DICT order
        if basic_info.get('layer_order', False):
            layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
        else:
            layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
        # upload all images (disabled)
        # for layer in layers:
        #     if 'image' in layer.keys() and layer['image'] is not None:
        #         uploaded_images.append({'image_obj': layer['image'], 'image_url': layer['image_url'], 'image_type': 'image'})
        #     if 'pattern_image' in layer.keys() and layer['pattern_image'] is not None:
        #         uploaded_images.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url'], 'image_type': 'pattern_image'})
        #     if 'mask' in layer.keys() and layer['mask'] is not None and layer['mask_url'] is not None:
        #         uploaded_images.append({'image_obj': layer['mask'], 'image_url': layer['mask_url'], 'image_type': 'mask'})
        layers, new_size = update_base_size_priority(layers, body_size)
        # composite all layers into one result image
        items_response['synthesis_url'] = synthesis(layers, new_size, basic_info)
        for lay in layers:
            items_response['layers'].append({
                'image_category': lay['name'],
                'position': lay['position'],
                'priority': lay.get("priority", None),
                'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
                'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
                'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
                'mask_url': lay['mask_url'],
                'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
                'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
                # 'image': lay['image'],
                # 'mask_image': lay['mask_image'],
            })
    elif cfg.get('basic')['single_overall'] == 'single':
        assert cfg.get('basic')['switch_category'] in [x['type'] for x in cfg.get('items')], "Lack of switch_category parameters "
        basic_info['debug'] = False
        # only the item matching switch_category is processed in single mode
        for item in cfg.get('items'):
            if item['type'] == cfg.get('basic')['switch_category']:
                item = build_item(item, default_args=cfg.get('basic'))
                item.process()
                items_response['layers'].append({
                    'image_category': f"{item.result['name']}_front",
                    # NOTE(review): the front entry reports back_image.size and the
                    # back entry reports front_image.size — looks swapped; confirm
                    # against the consumer before changing.
                    'image_size': item.result['back_image'].size if item.result['back_image'] else None,
                    'position': None,
                    'priority': 0,
                    'image_url': item.result['front_image_url'],
                    'mask_url': item.result['mask_url'],
                    "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
                    'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
                })
                items_response['layers'].append({
                    'image_category': f"{item.result['name']}_back",
                    'image_size': item.result['front_image'].size if item.result['front_image'] else None,
                    'position': None,
                    'priority': 0,
                    'image_url': item.result['back_image_url'],
                    'mask_url': item.result['mask_url'],
                    "gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
                    'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
                })
                items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
                break
    update_progress(process_id, total)
    return items_response, uploaded_images
@RunTime
def process_images(images):
    """Upload every image descriptor concurrently.

    Returns the uploaded URLs in the same order as `images`.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        return list(pool.map(upload_images, images))
# @RunTime
def upload_images(image_obj):
    """Encode one layer image as PNG and push it to OSS.

    image_obj: dict with 'image_obj' (PIL image or cv2 mask array),
    'image_url' ("bucket/object") and 'image_type'.
    Returns the image_url unchanged.
    """
    bucket_name, object_name = image_obj['image_url'].split("/", 1)
    if image_obj['image_type'] in ('image', 'pattern_image'):
        # PIL path: serialize to an in-memory PNG buffer
        buffer = io.BytesIO()
        image_obj['image_obj'].save(buffer, format='PNG')
        buffer.seek(0)
        oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=buffer.read())
    else:
        # mask path: invert, add an alpha channel, make black fully transparent
        mask_inverted = cv2.bitwise_not(image_obj['image_obj'])
        rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
        rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
        oss_upload_image(bucket=bucket_name, object_name=object_name,
                         image_bytes=cv2.imencode('.png', rgba_image)[1])
    return image_obj['image_url']
def update_base_size_priority(layers, size):
    """Derive the transparent canvas size and shift layers to canvas space.

    Width spans from the leftmost layer position to the rightmost image edge;
    height is fixed at 700. Each layer gains an 'adaptive_position' whose x is
    rebased so the leftmost layer sits at 0. The `size` argument is accepted
    but not used here.

    Returns (layers, (new_width, new_height)).
    """
    left_edge = min(layer['position'][1] for layer in layers)
    right_edge = max(layer['position'][1] + layer['image'].width
                     for layer in layers if layer['image'] is not None)
    for layer in layers:
        layer['adaptive_position'] = (layer['position'][0], layer['position'][1] - left_edge)
    return layers, (right_edge - left_edge, 700)

View File

@@ -1,31 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File conversion_image.py
@Author :周成融
@Date 2023/8/21 10:40:29
@detail
"""
import numpy as np
# def rgb_to_rgba(rgb_size, rgb_image, mask):
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
# # 创建四通道的结果图像
# rgba_image = np.dstack((rgb_image, alpha_channel))
# alpha_channel = np.where(mask > 0, 255, 0)
# # 更新RGBA图像的透明度通道
# rgba_image[:, :, 3] = alpha_channel
# return rgba_image
def rgb_to_rgba(rgb_image, mask):
    """Attach an alpha channel derived from `mask` to a 3-channel image.

    Pixels where mask > 0 become fully opaque (255); all others fully
    transparent (0). Returns the stacked 4-channel array.
    """
    alpha_channel = np.zeros(mask.shape, dtype=np.uint8)
    alpha_channel[mask > 0] = 255
    return np.dstack((rgb_image, alpha_channel))
if __name__ == '__main__':
    # NOTE(review): leftover manual-test stub — open("") raises
    # FileNotFoundError if this module is executed directly.
    image = open("")

View File

@@ -1,143 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File design_ensemble.py
@Author :周成融
@Date 2023/8/16 19:36:21
@detail :发起请求 获取推理结果
"""
import logging
import cv2
import mmcv
import numpy as np
import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
"""
keypoint
预处理 推理 后处理
"""
def keypoint_preprocess(img_path):
    """Load an image, resize to 256x256 and normalize for the keypoint model.

    Returns (batch, (w_scale, h_scale)) where batch is a 1xCxHxW float array
    and the scales map model coordinates back to the original size.
    """
    img = mmcv.imread(img_path)
    target_size = (256, 256)
    src_h, src_w = img.shape[:2]
    img = cv2.resize(img, target_size)
    scale_factors = (target_size[0] / src_w, target_size[1] / src_h)
    # ImageNet mean/std normalization, BGR -> RGB
    img = mmcv.imnormalize(img,
                           mean=np.array([123.675, 116.28, 103.53]),
                           std=np.array([58.395, 57.12, 57.375]),
                           to_rgb=True)
    batch = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return batch, scale_factors
# inference entry point for the keypoint model
def get_keypoint_result(image, site):
    """Run the keypoint model on Triton for the given body site.

    Args:
        image: image path/array accepted by keypoint_preprocess.
        site: model-name suffix selector (e.g. "up").
    Returns:
        Post-processed keypoint coordinates array, or None when any step
        fails (the error is logged as a warning).
    """
    keypoint_result = None
    try:
        image, scale_factor = keypoint_preprocess(image)
        client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
        transformed_img = image.astype(np.float32)
        # FIX(idiom): plain literals instead of placeholder-less f-strings
        inputs = [httpclient.InferInput("input", transformed_img.shape, datatype="FP32")]
        inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
        outputs = [httpclient.InferRequestedOutput("output", binary_data=True)]
        results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
        inference_output = torch.from_numpy(results.as_numpy('output'))
        keypoint_result = keypoint_postprocess(inference_output, scale_factor)
    except Exception as e:
        logging.warning(f"get_keypoint_result : {e}")
    return keypoint_result
def keypoint_postprocess(output, scale_factor):
    # Convert per-channel heatmaps (B, C, H, W) into coordinates in
    # original-image space.
    # argmax over the flattened HxW plane -> flat index per (batch, channel)
    max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
    # flat index -> (row, col); NOTE(review): '/' is true division here, so the
    # row component is fractional and only the final np.ceil makes it integral —
    # confirm '//' was not intended.
    max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
    segment_result = max_coords.numpy()
    # invert the per-axis scale factors, reversed into (h, w) order
    scale_factor = [1 / x for x in scale_factor[::-1]]
    scale_matrix = np.diag(scale_factor)
    # zero out infinities arising from degenerate (zero) scales
    nan = np.isinf(scale_matrix)
    scale_matrix[nan] = 0
    # *4: presumably the heatmap-to-input stride — TODO confirm against the model
    return np.ceil(np.dot(segment_result, scale_matrix) * 4)
"""
seg
预处理 推理 后处理
"""
# KNet
def seg_preprocess(img_path):
    # Load + (conditionally) downscale + normalize an image for the KNet
    # segmentation model. Returns (1xCxHxW float batch, original (h, w)).
    img = mmcv.imread(img_path)
    ori_shape = img.shape[:2]
    # NOTE(review): names are swapped relative to convention — ori_shape is
    # (h, w), so img_scale_w actually holds the height and img_scale_h the
    # width; the cv2.resize call below depends on exactly this swap. Confirm
    # before renaming anything here.
    img_scale_w, img_scale_h = ori_shape
    if ori_shape[0] > 1024:
        img_scale_w = 1024
    if ori_shape[1] > 1024:
        img_scale_h = 1024
    # If either side exceeds 1024 the image is resized so that side becomes 1024
    if ori_shape != (img_scale_w, img_scale_h):
        # mmcv.imresize(img, img_scale_h, img_scale_w)  # old code kept as a cautionary note: h and w were swapped
        img = cv2.resize(img, (img_scale_h, img_scale_w))
    img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
    preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
    return preprocessed_img, ori_shape
# segmentation inference entry point
def get_seg_result(image_id, image):
    """Run semantic segmentation on the Triton server.

    Returns the post-processed logits resized back to the image's original
    (h, w) shape (first batch element).
    """
    batch, ori_shape = seg_preprocess(image)
    client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
    payload = batch.astype(np.float32)
    # request payload
    infer_input = httpclient.InferInput(SEGMENTATION['input'], payload.shape, datatype="FP32")
    infer_input.set_data_from_numpy(payload, binary_data=True)
    # requested output tensor
    requested_output = httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True)
    results = client.infer(model_name=SEGMENTATION['new_model_name'],
                           inputs=[infer_input],
                           outputs=[requested_output])
    raw_output = results.as_numpy(SEGMENTATION['output'])
    return seg_postprocess(int(image_id), raw_output, ori_shape)
# no cache
def seg_postprocess(image_id, output, ori_shape):
    """Bilinearly upsample raw model output back to the original (h, w).

    `image_id` is accepted for interface compatibility but unused here.
    Returns the first batch element as a numpy array.
    """
    logits = torch.tensor(output).float()
    resized = F.interpolate(logits, size=ori_shape, scale_factor=None,
                            mode='bilinear', align_corners=False)
    return resized.cpu().numpy()[0]
def key_point_show(image_path, key_point_result=None):
    """Debug helper: draw keypoints (given as (row, col) pairs) on an image
    and display it in an OpenCV window until a key is pressed."""
    canvas = cv2.imread(image_path)
    radius = 1
    color_bgr = (0, 0, 255)
    stroke = 4  # may be 0, 4 or 8
    for point in key_point_result:
        # cv2 expects (x, y), points arrive as (row, col) -> reverse
        cv2.circle(canvas, point[::-1], radius, color_bgr, stroke)
    cv2.imshow("0", canvas)
    cv2.waitKey(0)
if __name__ == '__main__':
    # Ad-hoc manual test of the keypoint path.
    image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
    a = get_keypoint_result(image, "up")
    new_list = []
    # NOTE(review): this prints the builtin `list`, probably meant new_list — confirm
    print(list)
    for i in a[0]:
        new_list.append((int(i[0]), int(i[1])))
    key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
    # a = get_seg_result(1, image)
    print(a)

View File

@@ -1,99 +0,0 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
    """
    Thin classmethod wrapper around redis-py StrictRedis (db 0). Plain
    writes and TTL refreshes default to a 100-second expiry.
    """

    @staticmethod
    def _get_r():
        # A fresh connection per call against the configured host/port, db 0.
        return redis.StrictRedis(REDIS_HOST, REDIS_PORT, 0)

    @classmethod
    def write(cls, key, value, expire=None):
        """Set key = value with a TTL (defaults to 100 seconds)."""
        ttl = expire if expire else 100
        cls._get_r().set(key, value, ex=ttl)

    @classmethod
    def read(cls, key):
        """Return a key's value decoded to str, or None when absent."""
        raw = cls._get_r().get(key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hset(cls, name, key, value):
        """Set one field of the hash `name`."""
        cls._get_r().hset(name, key, value)

    @classmethod
    def hget(cls, name, key):
        """Return one field of the hash `name`, decoded to str (None when absent)."""
        raw = cls._get_r().hget(name, key)
        return raw.decode('utf-8') if raw else raw

    @classmethod
    def hgetall(cls, name):
        """Return the whole hash as a dict (values left as raw bytes)."""
        return cls._get_r().hgetall(name)

    @classmethod
    def delete(cls, *names):
        """Delete one or more keys."""
        cls._get_r().delete(*names)

    @classmethod
    def hdel(cls, name, key):
        """Delete one field from the hash `name`."""
        cls._get_r().hdel(name, key)

    @classmethod
    def expire(cls, name, expire=None):
        """Refresh a key's TTL (defaults to 100 seconds)."""
        ttl = expire if expire else 100
        cls._get_r().expire(name, ttl)
if __name__ == '__main__':
    # Manual smoke test against the configured redis instance.
    redis_client = Redis()
    # print(redis_client.write(key="1230", value=0))
    redis_client.write(key="1230", value=10)
    # print(redis_client.read(key="1230"))

View File

@@ -1,181 +0,0 @@
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
"""
@Project trinity_client
@File synthesis_item.py
@Author :周成融
@Date 2023/8/26 14:13:04
@detail
"""
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
def positioning(all_mask_shape, mask_shape, offset):
    """Clamp a 1-D placement of a mask (length mask_shape) onto a canvas
    (length all_mask_shape) at the given offset.

    Returns (canvas_start, canvas_end, mask_start, mask_end) such that
    canvas[canvas_start:canvas_end] = mask[mask_start:mask_end]; both slices
    have the same length (possibly zero when there is no overlap).
    """
    if offset == 0:
        overlap = min(all_mask_shape, mask_shape)
        return 0, overlap, 0, overlap
    if offset > 0:
        # mask starts inside (or beyond) the canvas
        canvas_start = min(offset, all_mask_shape)
        canvas_end = min(offset + mask_shape, all_mask_shape)
        mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
        return canvas_start, canvas_end, 0, mask_end
    # offset < 0: the mask starts `shift` cells before the canvas origin
    shift = -offset
    if shift > mask_shape:
        # mask lies entirely left of the canvas — empty slices
        return 0, 0, mask_shape, mask_shape
    visible = mask_shape - shift
    canvas_end = min(visible, all_mask_shape)
    mask_end = all_mask_shape + shift if visible >= all_mask_shape else mask_shape
    return 0, canvas_end, shift, mask_end
# @RunTime
def synthesis(data, size, basic_info):
    # Composite all layers onto a transparent canvas of `size` (w, h): build a
    # body+garment silhouette mask, clip every non-body layer to it, paste the
    # body as-is, then upload the PNG to OSS and return "bucket/object".
    # Returns None when any step raises (a warning is logged).
    # create the transparent base canvas
    base_image = Image.new('RGBA', size, (0, 0, 0, 0))
    try:
        all_mask_shape = (size[1], size[0])  # (h, w) for numpy indexing
        body_mask = None
        for d in data:
            if d['name'] == 'body':
                # Paste the mannequin onto a fresh transparent canvas to obtain
                # its alpha silhouette. Indexing adaptive_position (instead of
                # passing the mutable sequence) prevents paste from mutating it.
                transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
                transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image'])
                body_mask = np.array(transparent_image.split()[3])
                # shift the shoulder keypoints into canvas coordinates
                left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
                # fill the rectangle above the shoulders so collars are not clipped
                body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
        _, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
        top_outer_mask = np.array(binary_body_mask)
        bottom_outer_mask = np.array(binary_body_mask)
        # Scan layers back-to-front; the first top garment and the first bottom
        # garment found each extend the clipping silhouette.
        top = True
        bottom = True
        i = len(data)
        while i:
            i -= 1
            if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
                top = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # clip the garment-mask placement to the canvas on both axes
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                # stamp the binarized garment mask into a canvas-sized buffer
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                top_outer_mask = background + top_outer_mask
            elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
                bottom = False
                mask_shape = data[i]['mask'].shape
                y_offset, x_offset = data[i]['adaptive_position']
                # same clipping as above, accumulated into the bottom silhouette
                all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
                all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
                _, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
                background = np.zeros_like(top_outer_mask)
                background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
                bottom_outer_mask = background + bottom_outer_mask
            elif bottom is False and top is False:
                break
        all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
        for layer in data:
            if layer['image'] is not None:
                if layer['name'] != "body":
                    # position the layer on a full-size canvas, then clip to silhouette
                    test_image = Image.new('RGBA', size, (0, 0, 0, 0))
                    test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
                    mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
                    mask_alpha = Image.fromarray(mask_data)
                    cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
                    base_image.paste(test_image, (0, 0), cropped_image)  # test_image is already positioned, so paste at the origin
                else:
                    base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
        result_image = base_image
        # encode to PNG in memory and upload
        image_data = io.BytesIO()
        result_image.save(image_data, format='PNG')
        image_data.seek(0)
        # oss upload
        image_bytes = image_data.read()
        bucket_name = "aida-results"
        object_name = f'result_{generate_uuid()}.png'
        req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
        return f"{bucket_name}/{object_name}"
        # return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
        # object_name = f'result_{generate_uuid()}.png'
        # response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
        # object_url = f"aida-results/{object_name}"
        # if response['ResponseMetadata']['HTTPStatusCode'] == 200:
        #     return object_url
        # else:
        #     return ""
    except Exception as e:
        logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image):
    """Composite a front and back layer image and upload the result as a PNG.

    Args:
        front_image: PIL.Image used as the base layer, or None.
        back_image: PIL.Image pasted on top of the base at (0, 0) using its
            own alpha channel as the paste mask, or None.

    Returns:
        str: "<bucket>/<object>" path of the uploaded PNG on OSS.

    Raises:
        ValueError: if both images are None — there is nothing to synthesize.
            (Previously this crashed with AttributeError on None.paste/save.)
    """
    # Choose the base layer explicitly. The old code left result_image as
    # None when front_image was missing and then crashed on .paste()/.save().
    if front_image is not None:
        result_image = front_image
        if back_image is not None:
            # Overlay the back layer; the image itself is its own alpha mask.
            result_image.paste(back_image, (0, 0), back_image)
    elif back_image is not None:
        # No front layer: the back layer alone is the result.
        result_image = back_image
    else:
        raise ValueError("synthesis_single requires at least one of front_image / back_image")
    # Serialize the composite to PNG bytes in memory.
    image_data = io.BytesIO()
    result_image.save(image_data, format='PNG')
    image_bytes = image_data.getvalue()
    # oss upload
    bucket_name = 'aida-results'
    object_name = f'result_{generate_uuid()}.png'
    oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
    return f"{bucket_name}/{object_name}"

View File

@@ -4,20 +4,20 @@ import threading
from celery import Celery
from minio import Minio
from app.core.config import *
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, AccessoriesItem
from app.core.config import settings
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, OthersItem
from app.service.design_batch.utils.MQ import publish_status
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_accessories
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_others
from app.service.design_batch.utils.save_json import oss_upload_json
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
id_lock = threading.Lock()
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
celery_app = Celery('tasks', broker=f'amqp://{settings.MQ_USERNAME}:{settings.MQ_PASSWORD}@{settings.MQ_HOST}:{settings.MQ_PORT}//', backend='rpc://')
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
celery_app.conf.worker_hijack_root_logger = False
logging.getLogger('pika').setLevel(logging.WARNING)
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
print("start")
@@ -33,8 +33,8 @@ def process_item(item, basic):
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
elif item['type'].lower() in ['accessories']:
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
elif item['type'].lower() in ['others']:
bottom_server = OthersItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
else:
raise NotImplementedError(f"Item type {item['type']} not implemented")
@@ -47,14 +47,16 @@ def process_layer(item, layers):
body_layer = organize_body(item)
layers.append(body_layer)
return item['body_image'].size
elif item['name'] == 'accessories':
front_layer, back_layer = organize_accessories(item)
elif item['name'] == 'others':
front_layer, back_layer = organize_others(item)
layers.append(front_layer)
layers.append(back_layer)
return None
else:
front_layer, back_layer = organize_clothing(item)
layers.append(front_layer)
layers.append(back_layer)
return None
@celery_app.task
@@ -76,12 +78,11 @@ def batch_design(objects_data, tasks_id, json_name):
for item in object['items']:
item_results.append(process_item(item, basic))
layers = []
body_size = None
for item in item_results:
body_size = process_layer(item, layers)
process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
layers, new_size = update_base_size_priority(layers, body_size)
layers, new_size = update_base_size_priority(layers)
for lay in layers:
items_response['layers'].append({

View File

@@ -9,10 +9,10 @@ class BaseItem:
self.result.update(basic)
class AccessoriesItem(BaseItem):
class OthersItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Accessories_pipeline = [
self.Others_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
ContourDetection(),
@@ -25,7 +25,7 @@ class AccessoriesItem(BaseItem):
]
def process(self):
for item in self.Accessories_pipeline:
for item in self.Others_pipeline:
self.result = item(self.result)
return self.result

View File

@@ -18,11 +18,11 @@ class BackPerspective:
result['back_perspective_url'] = file_path
return result
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
seg_result = result['seg_result']
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
back_sketch = result['image'].copy()
@@ -34,7 +34,8 @@ class BackPerspective:
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
return result
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
@staticmethod
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
mask = mask.astype(np.uint8) * 255
# 查找轮廓
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -48,9 +49,9 @@ class BackPerspective:
# 在空白图像上绘制白色的轮廓
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
# 找到轮廓的中心(可以用重心等方法近似)
M = cv2.moments(contour)
cx = int(M['m10'] / M['m00'])
cy = int(M['m01'] / M['m00'])
m = cv2.moments(contour)
cx = int(m['m10'] / m['m00'])
cy = int(m['m01'] / m['m00'])
# 进行距离变换,离中心越近的值越小
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留

View File

@@ -79,9 +79,9 @@ class Color:
def get_pattern(single_color):
if single_color is None:
raise False
R, G, B = single_color.split(' ')
r, g, b = single_color.split(' ')
pattern = np.zeros([1, 1, 3], np.uint8)
pattern[0, 0, 0] = int(B)
pattern[0, 0, 1] = int(G)
pattern[0, 0, 2] = int(R)
pattern[0, 0, 0] = int(b)
pattern[0, 0, 1] = int(g)
pattern[0, 0, 2] = int(r)
return pattern

View File

@@ -3,7 +3,7 @@ import logging
import numpy as np
from pymilvus import MilvusClient
from app.core.config import *
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
from app.service.utils.decorator import ClassCallRunTime, RunTime
@@ -21,12 +21,12 @@ class KeyPoint:
def __call__(self, result):
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
# keypoint_cache = self.keypoint_cache(result, site)
keypoint_cache = False
# 取消向量查询 直接过模型推理
if keypoint_cache is False:
if not keypoint_cache:
keypoint_infer_result, site = self.infer_keypoint_result(result)
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
else:
@@ -55,8 +55,8 @@ class KeyPoint:
}
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
@@ -79,7 +79,7 @@ class KeyPoint:
]
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(
collection_name=MILVUS_TABLE_KEYPOINT,
data=data
@@ -92,7 +92,7 @@ class KeyPoint:
@RunTime
def keypoint_cache(self, result, site):
try:
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
keypoint_id = result['image_id']
res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT,

View File

@@ -1,9 +1,6 @@
import io
import logging
import cv2
import numpy as np
from PIL import Image
from app.service.utils.new_oss_client import oss_get_image
@@ -74,8 +71,8 @@ class LoadImage:
keypoint = 'head_point'
elif name == 'earring':
keypoint = 'ear_point'
elif name == 'accessories':
keypoint = "accessories"
elif name == 'others':
keypoint = "others"
else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
f"bag, shoes, hairstyle, earring.")

View File

@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
class PrintPainting:
def __init__(self, minio_client):
self.random_seed = None
self.minio_client = minio_client
def __call__(self, result):
@@ -408,7 +409,7 @@ class PrintPainting:
change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask)

View File

@@ -46,7 +46,7 @@ class Scaling:
result['scale'] = result['scale_bag']
elif result['keypoint'] == 'ear_point':
result['scale'] = result['scale_earrings']
elif result['keypoint'] == 'accessories':
elif result['keypoint'] == 'others':
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
distance_clo = result['img_shape'][1]

View File

@@ -4,7 +4,7 @@ import os
import cv2
import numpy as np
from app.core.config import SEG_CACHE_PATH
from app.core.config import settings
from app.service.design_fast.utils.design_ensemble import get_seg_result
from app.service.utils.decorator import ClassCallRunTime
from app.service.utils.new_oss_client import oss_get_image
@@ -36,11 +36,11 @@ class Segmentation:
# preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
# submit 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
else:
@@ -49,7 +49,7 @@ class Segmentation:
# 判断缓存和实际图片size是否相同
if not _ or result["image"].shape[:2] != seg_result.shape:
# 推理获得seg 结果
seg_result = get_seg_result(result["image_id"], result['image'])
seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id'])
result['seg_result'] = seg_result
@@ -63,7 +63,7 @@ class Segmentation:
@staticmethod
def save_seg_result(seg_result, image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
try:
np.save(file_path, seg_result)
logger.debug(f"保存成功 {os.path.abspath(file_path)}")
@@ -72,7 +72,7 @@ class Segmentation:
@staticmethod
def load_seg_result(image_id):
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
try:
seg_result = np.load(file_path)

View File

@@ -4,9 +4,7 @@ import logging
import cv2
import numpy as np
from PIL import Image
from cv2 import cvtColor, COLOR_BGR2RGBA
from app.core.config import AIDA_CLOTHING
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent
from app.service.design_fast.utils.upload_image import upload_png_mask
@@ -21,7 +19,7 @@ class Split(object):
def __call__(self, result):
try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'):
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
front_mask = result['front_mask']
@@ -40,7 +38,7 @@ class Split(object):
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
@@ -98,21 +96,21 @@ class Split(object):
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
mask_image[back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# 创建中间图层
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
result_pattern_image_pil = Image.fromarray(cv2.cvtColor(result_pattern_image_rgba, cv2.COLOR_BGR2RGBA))
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
return result
except Exception as e:

View File

@@ -2,16 +2,17 @@ import json
import pika
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES
from app.core.config import settings
from app.core.rabbit_mq_config import RABBITMQ_PARAMS
def publish_status(task_id, progress, result):
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
channel = connection.channel()
channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
channel.queue_declare(queue=settings.BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
message = {'task_id': task_id, 'progress': progress, "result": result}
channel.basic_publish(exchange='',
routing_key=BATCH_DESIGN_RABBITMQ_QUEUES,
routing_key=settings.BATCH_DESIGN_RABBITMQ_QUEUES,
body=json.dumps(message),
properties=pika.BasicProperties(
delivery_mode=2,

View File

@@ -16,7 +16,7 @@ import torch
import torch.nn.functional as F
import tritonclient.http as httpclient
from app.core.config import *
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
"""
keypoint
@@ -91,29 +91,29 @@ def seg_preprocess(img_path):
# @ RunTime
def get_seg_result(image_id, image):
def get_seg_result(image):
image, ori_shape = seg_preprocess(image)
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
transformed_img = image.astype(np.float32)
# 输入集
inputs = [
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
httpclient.InferInput(DESIGN_MODEL_NAME, transformed_img.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
# 输出集
outputs = [
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
httpclient.InferRequestedOutput("seg_input__0", binary_data=True),
]
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
results = client.infer(model_name=DESIGN_MODEL_NAME, inputs=inputs, outputs=outputs)
# 推理
# 取结果
inference_output1 = results.as_numpy(SEGMENTATION['output'])
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
inference_output1 = results.as_numpy("seg_input__0")
seg_result = seg_postprocess(inference_output1, ori_shape)
return seg_result
# no cache
def seg_postprocess(image_id, output, ori_shape):
def seg_postprocess(output, ori_shape):
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
seg_pred = seg_logit.cpu().numpy()
return seg_pred[0]

View File

@@ -55,7 +55,7 @@ def organize_clothing(layer):
return front_layer, back_layer
def organize_accessories(layer):
def organize_others(layer):
# 起始坐标
start_point = (0, 0)
# 前片数据
@@ -98,6 +98,8 @@ def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offse
"""
Align left
Args:
offset:
resize_scale:
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
scale: float
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}

View File

@@ -1,6 +1,6 @@
import logging
from app.service.design_fast.utils.redis_utils import Redis
from app.service.utils.redis_utils import Redis
logger = logging.getLogger(__name__)

View File

@@ -1,99 +0,0 @@
import redis
from app.core.config import REDIS_HOST, REDIS_PORT
class Redis(object):
"""
redis数据库操作
"""
@staticmethod
def _get_r():
host = REDIS_HOST
port = REDIS_PORT
db = 0
r = redis.StrictRedis(host, port, db)
return r
@classmethod
def write(cls, key, value, expire=None):
"""
写入键值对
"""
# 判断是否有过期时间,没有就设置默认值
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.set(key, value, ex=expire_in_seconds)
@classmethod
def read(cls, key):
"""
读取键值对内容
"""
r = cls._get_r()
value = r.get(key)
return value.decode('utf-8') if value else value
@classmethod
def hset(cls, name, key, value):
"""
写入hash表
"""
r = cls._get_r()
r.hset(name, key, value)
@classmethod
def hget(cls, name, key):
"""
读取指定hash表的键值
"""
r = cls._get_r()
value = r.hget(name, key)
return value.decode('utf-8') if value else value
@classmethod
def hgetall(cls, name):
"""
获取指定hash表所有的值
"""
r = cls._get_r()
return r.hgetall(name)
@classmethod
def delete(cls, *names):
"""
删除一个或者多个
"""
r = cls._get_r()
r.delete(*names)
@classmethod
def hdel(cls, name, key):
"""
删除指定hash表的键值
"""
r = cls._get_r()
r.hdel(name, key)
@classmethod
def expire(cls, name, expire=None):
"""
设置过期时间
"""
if expire:
expire_in_seconds = expire
else:
expire_in_seconds = 100
r = cls._get_r()
r.expire(name, expire_in_seconds)
if __name__ == '__main__':
redis_client = Redis()
# print(redis_client.write(key="1230", value=0))
redis_client.write(key="1230", value=10)
# print(redis_client.read(key="1230"))

View File

@@ -13,9 +13,12 @@ import logging
import cv2
import numpy as np
from PIL import Image
from minio import Minio
from app.core.config import settings
from app.service.utils.generate_uuid import generate_uuid
from app.service.utils.oss_client import oss_upload_image
from app.service.utils.new_oss_client import oss_upload_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def positioning(all_mask_shape, mask_shape, offset):
@@ -136,7 +139,7 @@ def synthesis(data, size, basic_info):
image_bytes = image_data.read()
bucket_name = "aida-results"
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
@@ -177,11 +180,11 @@ def synthesis_single(front_image, back_image):
# oss upload
bucket_name = 'aida-results'
object_name = f'result_{generate_uuid()}.png'
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
def update_base_size_priority(layers, size):
def update_base_size_priority(layers):
# 计算透明背景图片的宽度
min_x = min(info['position'][1] for info in layers)
x_list = []

View File

@@ -12,7 +12,6 @@ import logging
import cv2
from app.core.config import *
from app.service.utils.new_oss_client import oss_upload_image
@@ -25,15 +24,15 @@ def upload_png_mask(minio_client, front_image, object_name, mask=None):
# 将掩模的3通道转换为4通道白色部分不透明黑色部分透明
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
mask_url = f"aida-clothing/mask/mask_{object_name}.png"
image_data = io.BytesIO()
front_image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
image_url = f"aida-clothing/image/image_{object_name}.png"
return front_image, image_url, mask_url
except Exception as e:
logging.warning(f"upload_png_mask runtime exception : {e}")

View File

@@ -5,36 +5,60 @@ import time
import requests
from minio import Minio
from app.core.config import *
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, AccessoriesItem
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_accessories
from app.core.config import settings
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem, TopMergeItem, BottomMergeItem, OthersMergeItem
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
from app.service.design_fast.utils.progress import final_progress, update_progress
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority, merge
from app.service.utils.decorator import RunTime
id_lock = threading.Lock()
logger = logging.getLogger()
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def process_item(item, basic):
# 处理project中单个item
if item['type'] == "Body":
body_server = BodyItem(data=item, basic=basic, minio_client=minio_client)
item_data = body_server.process()
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
item_data = top_server.process()
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
elif item['type'].lower() in ['accessories']:
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
item_data = bottom_server.process()
def process_item(item, basic, design_type):
# 1. 定义映射配置
# key 为 item_type 的小写value 为对应的处理类
DESIGN_MAP = {
'body': BodyItem,
'blouse': TopItem, 'outwear': TopItem,
'dress': TopItem, 'tops': TopItem,
'skirt': BottomItem, 'trousers': BottomItem,
'bottoms': BottomItem,
'others': OthersItem
}
MERGE_MAP = {
'body_merge': BodyItem,
'blouse_merge': TopMergeItem, 'outwear_merge': TopMergeItem,
'dress_merge': TopMergeItem, 'tops_merge': TopMergeItem,
'skirt_merge': BottomMergeItem, 'trousers_merge': BottomMergeItem,
'bottoms_merge': BottomMergeItem,
'others_merge': OthersMergeItem
}
# 2. 根据 design_type 选择映射表
mapping = MERGE_MAP if design_type == 'merge' else DESIGN_MAP
if design_type == 'merge':
item_type_key = f"{item['type'].lower()}_merge"
elif design_type == 'default':
item_type_key = item['type'].lower()
else:
raise NotImplementedError(f"Item type {item['type']} not implemented")
item_type_key = item['type'].lower()
handler_class = mapping.get(item_type_key)
if not handler_class:
raise NotImplementedError(f"Item type {item['type']} not implemented for design_type={design_type}")
# 4. 统一实例化并执行
# 注意:这里假设所有 Item 类构造函数签名一致
server = handler_class(data=item, basic=basic, minio_client=minio_client)
item_data = server.process()
return item_data
@@ -44,14 +68,16 @@ def process_layer(item, layers):
body_layer = organize_body(item)
layers.append(body_layer)
return item['body_image'].size
elif item['name'] == 'accessories':
front_layer, back_layer = organize_accessories(item)
elif item['name'] in ['others', 'others_merge']:
front_layer, back_layer = organize_others(item)
layers.append(front_layer)
layers.append(back_layer)
return None
else:
front_layer, back_layer = organize_clothing(item)
layers.append(front_layer)
layers.append(back_layer)
return None
@RunTime
@@ -68,18 +94,18 @@ def design_generate(request_data):
nonlocal active_threads
basic = object['basic']
items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""}
design_type = basic.get('design_type', "default")
if basic['single_overall'] == "overall":
item_results = []
for item in object['items']:
item_results.append(process_item(item, basic))
item_results.append(process_item(item, basic, design_type))
layers = []
body_size = None
for item in item_results:
body_size = process_layer(item, layers)
process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
layers, new_size = update_base_size_priority(layers, body_size)
layers, new_size = update_base_size_priority(layers)
# pattern_overall_image_url 、 pattern_print_image_url
for lay in layers:
items_response['layers'].append({
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
@@ -90,10 +116,19 @@ def design_generate(request_data):
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
'mask_url': lay['mask_url'],
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
'transpose': lay.get('transpose', None),
'rotate': lay.get('rotate', None),
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
})
if basic.get('design_type') == 'default':
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
elif basic.get('design_type') == 'merge':
items_response['synthesis_url'] = merge(layers, new_size, basic)
else:
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
else:
item_result = process_item(object['items'][0], basic)
items_response['layers'].append({
@@ -104,7 +139,9 @@ def design_generate(request_data):
'image_url': item_result['front_image_url'],
'mask_url': item_result['mask_url'],
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
})
items_response['layers'].append({
'image_category': f"{item_result['name']}_back",
@@ -114,7 +151,9 @@ def design_generate(request_data):
'image_url': item_result['back_image_url'],
'mask_url': item_result['mask_url'],
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
})
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
update_progress(process_id, total)
@@ -139,10 +178,11 @@ def design_generate(request_data):
@RunTime
def design_generate_v2(request_data):
objects_data = request_data.dict()['objects']
callback_url = request_data.callback_url
request_id = request_data.requestId
threads = []
def process_object(step, object):
def process_object(object, callback_url):
basic = object['basic']
items_response = {
'layers': [],
@@ -154,12 +194,11 @@ def design_generate_v2(request_data):
for item in object['items']:
item_results.append(process_item(item, basic))
layers = []
body_size = None
for item in item_results:
body_size = process_layer(item, layers)
process_layer(item, layers)
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
layers, new_size = update_base_size_priority(layers, body_size)
layers, new_size = update_base_size_priority(layers)
for lay in layers:
items_response['layers'].append({
@@ -171,7 +210,9 @@ def design_generate_v2(request_data):
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
'mask_url': lay['mask_url'],
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
})
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
@@ -185,7 +226,9 @@ def design_generate_v2(request_data):
'image_url': item_result['front_image_url'],
'mask_url': item_result['mask_url'],
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
})
items_response['layers'].append({
'image_category': f"{item_result['name']}_back",
@@ -195,16 +238,14 @@ def design_generate_v2(request_data):
'image_url': item_result['back_image_url'],
'mask_url': item_result['mask_url'],
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
})
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
# 发送结果给java端
url = JAVA_STREAM_API_URL
# xu_pei_test_url = "https://cd21b9110505.ngrok-free.app/api/third/party/receiveDesignResults"
tianxaing_test_url = "https://c2ae520723c9.ngrok-free.app/api/third/party/receiveDesignResults"
url = callback_url
logger.info(f"java 回调 -> {url}")
# logger.info(f"xupei java 回调 -> {xu_pei_test_url}")
logger.info(f"tianxiang java 回调 -> {tianxaing_test_url}")
headers = {
'Accept': "*/*",
@@ -219,16 +260,8 @@ def design_generate_v2(request_data):
# 打印结果
logger.info(response.text)
# test_response = post_request(xu_pei_test_url, json_data=items_response, headers=headers)
test_response = post_request(tianxaing_test_url, json_data=items_response, headers=headers)
if test_response:
# 打印结果
# logger.info(f"xupei test response : {test_response.text}")
logger.info(f"tianxiang test response : {test_response.text}")
for step, object in enumerate(objects_data):
t = threading.Thread(target=process_object, args=(step, object))
t = threading.Thread(target=process_object, args=(object, callback_url))
threads.append(t)
t.start()

View File

@@ -1,4 +1,4 @@
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection, NoSegPrintPainting
class BaseItem:
@@ -7,25 +7,21 @@ class BaseItem:
self.result['name'] = data['type'].lower()
self.result.pop("type")
self.result.update(basic)
self.result['design_type'] = basic.get('design_type', None)
class AccessoriesItem(BaseItem):
class OthersItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Accessories_pipeline = [
self.Others_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
# ContourDetection(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
PrintPainting(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.Accessories_pipeline:
for item in self.Others_pipeline:
self.result = item(self.result)
return self.result
@@ -39,6 +35,7 @@ class TopItem(BaseItem):
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
NoSegPrintPainting(minio_client),
PrintPainting(minio_client),
Scaling(),
Split(minio_client)
@@ -60,6 +57,7 @@ class BottomItem(BaseItem):
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
NoSegPrintPainting(minio_client),
PrintPainting(minio_client),
Scaling(),
Split(minio_client)
@@ -71,6 +69,65 @@ class BottomItem(BaseItem):
return self.result
"""merge"""
class OthersMergeItem(BaseItem):
    """Merge-mode pipeline for accessory / "other" garment items.

    Runs the full processing chain — load, segment, color, print painting,
    scaling, split — threading a single result dict through every stage.
    """

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # Ordered stages; each is a callable that takes the accumulated
        # result dict and returns the updated dict.
        self.Others_pipeline = [
            LoadImage(minio_client),
            # KeyPoint(),
            # ContourDetection(),
            Segmentation(minio_client),
            # BackPerspective(minio_client),
            Color(minio_client),
            NoSegPrintPainting(minio_client),
            PrintPainting(minio_client),
            Scaling(),
            Split(minio_client),
        ]

    def process(self):
        """Run every stage in order and return the final result dict."""
        for stage in self.Others_pipeline:
            self.result = stage(self.result)
        return self.result
class TopMergeItem(BaseItem):
    """Merge-mode pipeline for top garments.

    Only geometry stages run here (keypoints, segmentation, scaling,
    split) — no Color/PrintPainting; presumably coloring is composited
    elsewhere in merge mode — TODO confirm with the caller.
    """

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # Ordered stages; each callable maps result dict -> result dict.
        self.top_pipeline = [
            LoadImage(minio_client),
            KeyPoint(),
            Segmentation(minio_client),
            Scaling(),
            Split(minio_client),
        ]

    def process(self):
        """Run every stage in order and return the final result dict."""
        for stage in self.top_pipeline:
            self.result = stage(self.result)
        return self.result
class BottomMergeItem(BaseItem):
    """Merge-mode pipeline for bottom garments.

    Mirrors TopMergeItem: geometry-only stages (keypoints, segmentation,
    scaling, split), with no color or print-painting steps.
    """

    def __init__(self, data, basic, minio_client):
        super().__init__(data, basic)
        # Ordered stages; each callable maps result dict -> result dict.
        self.bottom_pipeline = [
            LoadImage(minio_client),
            KeyPoint(),
            Segmentation(minio_client),
            Scaling(),
            Split(minio_client),
        ]

    def process(self):
        """Run every stage in order and return the final result dict."""
        for stage in self.bottom_pipeline:
            self.result = stage(self.result)
        return self.result
class BodyItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)

View File

@@ -1,13 +1,18 @@
import io
from app.service.utils.oss_client import oss_get_image, oss_upload_image
from minio import Minio
from app.core.config import settings
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def model_transpose(image_path):
bucket = image_path.split("/", 1)[0]
object_name = image_path.split("/", 1)[1]
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="PIL")
image = image.convert("RGBA")
data = image.getdata()
#
@@ -23,6 +28,6 @@ def model_transpose(image_path):
image.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
oss_upload_image(oss_client=minio_client, bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
image_path = f"{bucket}/{new_object_name}"
return image_path

View File

@@ -5,6 +5,7 @@ from .keypoint import KeyPoint
from .keypoint import KeyPoint
from .loading import LoadImage, LoadBodyImage
from .print_painting import PrintPainting
from .no_seg_print_painting import NoSegPrintPainting
from .scale import Scaling
from .segmentation import Segmentation
from .split import Split
@@ -16,6 +17,7 @@ __all__ = [
'Segmentation',
'BackPerspective',
'Color',
'NoSegPrintPainting',
'PrintPainting',
'Scaling',
'Split'

View File

@@ -18,11 +18,11 @@ class BackPerspective:
result['back_perspective_url'] = file_path
return result
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
seg_result = result['seg_result']
else:
seg_result = get_seg_result("1", result['image'])[0]
seg_result = get_seg_result(result['image'])[0]
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
back_sketch = result['image'].copy()
@@ -34,7 +34,8 @@ class BackPerspective:
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
return result
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
@staticmethod
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
mask = mask.astype(np.uint8) * 255
# 查找轮廓
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
@@ -48,9 +49,9 @@ class BackPerspective:
# 在空白图像上绘制白色的轮廓
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
# 找到轮廓的中心(可以用重心等方法近似)
M = cv2.moments(contour)
cx = int(M['m10'] / M['m00'])
cy = int(M['m01'] / M['m00'])
m = cv2.moments(contour)
# cx = int(m['m10'] / m['m00'])
# cy = int(m['m01'] / m['m00'])
# 进行距离变换,离中心越近的值越小
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留

View File

@@ -22,7 +22,7 @@ class Color:
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
# 无色
elif "color" not in result.keys() or result['color'] == "":
result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
result['alpha'] = 100 / 255.0
return result
# 正常颜色
@@ -48,7 +48,7 @@ class Color:
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
gray_mo = np.expand_dims(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), axis=2).repeat(3, axis=2)
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
result['pattern_image'] = get_image_fir.astype(np.uint8)
result['final_image'] = result['pattern_image']
@@ -59,6 +59,8 @@ class Color:
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['single_image'] = cv2.add(tmp1, tmp2)
result['alpha'] = 100 / 255.0
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = result['final_image'].copy()
return result
def get_gradient(self, bucket_name, object_name):
@@ -79,9 +81,9 @@ class Color:
def get_pattern(single_color):
    """Build a 1x1 BGR pattern pixel from an "R G B" space-separated string.

    Args:
        single_color: color string of the form "R G B", e.g. "255 128 0".

    Returns:
        np.ndarray of shape (1, 1, 3), dtype uint8, in OpenCV BGR
        channel order (index 0 = blue, 1 = green, 2 = red).

    Raises:
        ValueError: if single_color is None, or does not split into
            exactly three space-separated integer components.
    """
    if single_color is None:
        # The original `raise False` actually raised
        # TypeError("exceptions must derive from BaseException");
        # raise a meaningful exception instead.
        raise ValueError("single_color must be an 'R G B' string, got None")
    r, g, b = single_color.split(' ')
    pattern = np.zeros([1, 1, 3], np.uint8)
    # Store channels in OpenCV's BGR order.
    pattern[0, 0, 0] = int(b)
    pattern[0, 0, 1] = int(g)
    pattern[0, 0, 2] = int(r)
    return pattern

Some files were not shown because too many files have changed in this diff Show More