Compare commits
85 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fb46a9521d | ||
|
|
b90688f835 | ||
| 7e30779aec | |||
| f7294f5966 | |||
| 0ac5a4e0a8 | |||
| 40b57b749c | |||
|
|
b8a538a8a1 | ||
|
|
29b4f43a27 | ||
|
|
69dc20207d | ||
|
|
18979af604 | ||
|
|
74406f9be4 | ||
|
|
df99e3ac76 | ||
|
|
19346c2eb7 | ||
|
|
2af9cbfe78 | ||
| fe12b5697d | |||
| c04d4877b0 | |||
| 91016e6cae | |||
| 0f4bb260ad | |||
| c792106f02 | |||
| deac5a4cab | |||
| 15682036b3 | |||
| 9ba3a0ca49 | |||
| f6963070fb | |||
| 12f5ca3ca3 | |||
| 19110f51bf | |||
| e04636ce21 | |||
| 2a50e7040e | |||
| a6f3bda9f7 | |||
| c18f45e549 | |||
| 4951fab71a | |||
| aa57478852 | |||
| 2a6c48d937 | |||
|
|
fed3fcdf85 | ||
| 417528f8cd | |||
| 18024a2d70 | |||
|
|
1be716e414 | ||
|
|
826bdcf9c1 | ||
|
|
f351184630 | ||
| fac1eab1bc | |||
| 832ca6fd05 | |||
| 673423131a | |||
| 6e15430a83 | |||
| 51068d2215 | |||
| d493d9eff6 | |||
| 7d970a7bba | |||
| 3fc6720bf7 | |||
| efa2e3a4a9 | |||
| c6af01bc51 | |||
|
|
448af4ab6b | ||
|
|
8a9f160cfa | ||
|
|
6e06c8b516 | ||
|
|
322fb9c46b | ||
|
|
30bfd22e3e | ||
|
|
e8d8b715ae | ||
|
|
7d2149dcaf | ||
|
|
fee9334b1f | ||
|
|
85c486c3dc | ||
|
|
0e7ef80eed | ||
|
|
8ccbbe41b1 | ||
|
|
98468ea7aa | ||
|
|
a9d9bdcb71 | ||
|
|
7459583377 | ||
|
|
385ff2d4aa | ||
|
|
02ad5db269 | ||
|
|
1d90963ded | ||
|
|
d1fefceebf | ||
|
|
242ebfc1df | ||
|
|
b8cf3d25b4 | ||
|
|
95647be610 | ||
|
|
e966ed5aa5 | ||
|
|
0d4d464e3f | ||
|
|
4bc79e62ca | ||
|
|
bf1fb8e514 | ||
|
|
d720bf2209 | ||
|
|
8f486867d5 | ||
|
|
1f45fe48a3 | ||
|
|
79865d9a96 | ||
|
|
a9a5964127 | ||
|
|
47e991cd76 | ||
|
|
8bc1ea576e | ||
|
|
31e848e8bb | ||
|
|
6da3712a76 | ||
|
|
e6da512a31 | ||
|
|
16d4844cca | ||
|
|
978e0d998d |
44
.gitea/workflows/develop_build_commit.yaml
Normal file
44
.gitea/workflows/develop_build_commit.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
name: git commit AiDA python develop 分支构建部署
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- develop
|
||||
|
||||
jobs:
|
||||
scheduled_deploy:
|
||||
runs-on: ubuntu-latest
|
||||
if: "contains(github.event.head_commit.message, '[run build]')"
|
||||
|
||||
env:
|
||||
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
|
||||
|
||||
steps:
|
||||
- name: 1.检出代码
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: 'develop'
|
||||
|
||||
- name: 2.复制文件到服务器
|
||||
uses: appleboy/scp-action@v0.1.7
|
||||
with:
|
||||
host: ${{ secrets.SERVER_HOST }}
|
||||
username: ${{ secrets.SERVER_USER }}
|
||||
password: ${{ secrets.SERVER_PASSWORD }}
|
||||
source: "."
|
||||
target: ${{ env.REMOTE_DEPLOY_PATH }}
|
||||
|
||||
- name: Restart Docker containers
|
||||
uses: appleboy/ssh-action@v0.1.10
|
||||
with:
|
||||
host: ${{ secrets.SERVER_HOST }}
|
||||
username: ${{ secrets.SERVER_USER }}
|
||||
password: ${{ secrets.SERVER_PASSWORD }}
|
||||
script: |
|
||||
# 进入项目目录
|
||||
cd ${{ env.REMOTE_DEPLOY_PATH }}
|
||||
|
||||
docker-compose down 2>&1
|
||||
docker-compose up -d --build --remove-orphans 2>&1
|
||||
|
||||
docker image prune -f 2>&1
|
||||
40
.gitea/workflows/develop_build_manual.yaml
Normal file
40
.gitea/workflows/develop_build_manual.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
name: 手动 AiDA python develop 分支构建部署
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
scheduled_deploy:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
|
||||
|
||||
steps:
|
||||
- name: 1.检出代码
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: 'develop'
|
||||
|
||||
- name: 2.复制文件到服务器
|
||||
uses: appleboy/scp-action@v0.1.7
|
||||
with:
|
||||
host: ${{ secrets.SERVER_HOST }}
|
||||
username: ${{ secrets.SERVER_USER }}
|
||||
password: ${{ secrets.SERVER_PASSWORD }}
|
||||
source: "."
|
||||
target: ${{ env.REMOTE_DEPLOY_PATH }}
|
||||
|
||||
- name: 3.重启docker-compose
|
||||
uses: appleboy/ssh-action@v0.1.10
|
||||
with:
|
||||
host: ${{ secrets.SERVER_HOST }}
|
||||
username: ${{ secrets.SERVER_USER }}
|
||||
password: ${{ secrets.SERVER_PASSWORD }}
|
||||
script: |
|
||||
# 进入项目目录
|
||||
cd ${{ env.REMOTE_DEPLOY_PATH }}
|
||||
|
||||
docker-compose down 2>&1
|
||||
docker-compose up -d --build --remove-orphans 2>&1
|
||||
|
||||
docker image prune -f 2>&1
|
||||
42
.gitea/workflows/develop_build_scheduled.yaml
Normal file
42
.gitea/workflows/develop_build_scheduled.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
name: 定时 AiDA python develop 分支构建部署
|
||||
on:
|
||||
# 使用 schedule 触发器,遵循标准的 Cron 格式 (分钟 小时-8 日期 月份 星期)
|
||||
schedule:
|
||||
- cron: '30 9 * * *'
|
||||
|
||||
jobs:
|
||||
scheduled_deploy:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
env:
|
||||
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
|
||||
|
||||
steps:
|
||||
- name: 1.检出代码
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: 'develop'
|
||||
|
||||
- name: 2.复制文件到服务器
|
||||
uses: appleboy/scp-action@v0.1.7
|
||||
with:
|
||||
host: ${{ secrets.SERVER_HOST }}
|
||||
username: ${{ secrets.SERVER_USER }}
|
||||
password: ${{ secrets.SERVER_PASSWORD }}
|
||||
source: "."
|
||||
target: ${{ env.REMOTE_DEPLOY_PATH }}
|
||||
|
||||
- name: Restart Docker containers
|
||||
uses: appleboy/ssh-action@v0.1.10
|
||||
with:
|
||||
host: ${{ secrets.SERVER_HOST }}
|
||||
username: ${{ secrets.SERVER_USER }}
|
||||
password: ${{ secrets.SERVER_PASSWORD }}
|
||||
script: |
|
||||
# 进入项目目录
|
||||
cd ${{ env.REMOTE_DEPLOY_PATH }}
|
||||
|
||||
docker-compose down 2>&1
|
||||
docker-compose up -d --build --remove-orphans 2>&1
|
||||
|
||||
docker image prune -f 2>&1
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -149,3 +149,5 @@ app/logs/*
|
||||
*.csv
|
||||
*.avi
|
||||
*.json
|
||||
*.env*
|
||||
config.backup.py
|
||||
22
Dockerfile
Normal file
22
Dockerfile
Normal file
@@ -0,0 +1,22 @@
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Install uv.
|
||||
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
||||
|
||||
# Copy the application into the container.
|
||||
COPY . /app
|
||||
|
||||
# Install the application dependencies.
|
||||
WORKDIR /app
|
||||
RUN mkdir /seg_cache
|
||||
# 更新索引并安装替代包
|
||||
RUN apt-get update && apt-get install -y \
|
||||
vim \
|
||||
libgl1 \
|
||||
libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN uv sync --frozen --no-cache
|
||||
|
||||
# Run the application.
|
||||
CMD ["/app/.venv/bin/fastapi", "run", "app/main.py", "--port", "80", "--host", "0.0.0.0"]
|
||||
@@ -23,11 +23,11 @@
|
||||
$ pip install mmcv==1.4.2 -f https://download.openmmlab.com/mmcv/dist/cu117/torch1.13/index.html
|
||||
|
||||
|
||||
2. 启动服务器
|
||||
1. 启动服务器
|
||||
|
||||
$ uvicorn app.main:app --host 0.0.0.0 --port 8000
|
||||
|
||||
3. 打开 http://127.0.0.1:8000/docs
|
||||
2. 打开 http://127.0.0.1:8000/docs
|
||||
|
||||
Docker 部署
|
||||
---------------
|
||||
|
||||
@@ -2,8 +2,7 @@ import json
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.core.config import DEBUG
|
||||
from app.core.config import settings
|
||||
from app.schemas.attribute_retrieve import *
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.attribute.config import const, local_debug_const
|
||||
@@ -35,13 +34,13 @@ def attribute_recognition(request_item: list[AttributeRecognitionModel]):
|
||||
"""
|
||||
try:
|
||||
for item in request_item:
|
||||
logger.debug(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
if DEBUG:
|
||||
logger.info(f"attribute_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
|
||||
if settings.DEBUG:
|
||||
service = AttributeRecognition(const=local_debug_const, request_data=request_item)
|
||||
else:
|
||||
service = AttributeRecognition(const=const, request_data=request_item)
|
||||
data = service.get_result()
|
||||
logger.debug(f"attribute_recognition response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"attribute_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"attribute_recognition Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -67,10 +66,10 @@ def category_recognition(request_item: list[CategoryRecognitionModel]):
|
||||
"""
|
||||
try:
|
||||
for item in request_item:
|
||||
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict())}")
|
||||
logger.info(f"category_recognition request item is : @@@@@@:{json.dumps(item.dict(), indent=4)}")
|
||||
service = CategoryRecognition(request_data=request_item)
|
||||
data = service.get_result()
|
||||
logger.info(f"category_recognition response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"category_recognition response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"category_recognition Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -26,7 +26,7 @@ def seg_product(request_item: BrandDnaModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"brand dna request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = BrandDna(request_item)
|
||||
result_url = service.get_result()
|
||||
except Exception as e:
|
||||
@@ -36,7 +36,7 @@ def seg_product(request_item: BrandDnaModel):
|
||||
|
||||
|
||||
@router.post("/GenerateBrand")
|
||||
def GenerateBrand(request_data: GenerateBrandModel):
|
||||
def generate_brand(request_data: GenerateBrandModel):
|
||||
"""
|
||||
通过prompt 生成 brand name ,brand slogan , brand logo。
|
||||
创建一个具有以下参数的请求体:
|
||||
|
||||
@@ -9,7 +9,6 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
import pymysql
|
||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||
from minio import Minio
|
||||
|
||||
@@ -5,10 +5,11 @@ import time
|
||||
|
||||
from PIL import ImageEnhance
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from minio import Minio
|
||||
from app.core.config import settings
|
||||
from app.schemas.brighten import BrightenModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
@@ -20,6 +21,9 @@ def increase_brightness(img, factor):
|
||||
return bright_img
|
||||
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
@router.post("/brighten")
|
||||
async def brighten(request_item: BrightenModel):
|
||||
"""
|
||||
@@ -35,14 +39,14 @@ async def brighten(request_item: BrightenModel):
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
image = oss_get_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
|
||||
logger.info(f"brighten request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
image = oss_get_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], data_type="PIL")
|
||||
new_image = increase_brightness(image, request_item.brighten_value)
|
||||
image_data = io.BytesIO()
|
||||
new_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=request_item.image_url.split('/')[0], object_name=request_item.image_url[request_item.image_url.find('/') + 1:], image_bytes=image_bytes)
|
||||
brighten_url = f"{req.bucket_name}/{req.object_name}"
|
||||
logger.info(f"run time is : {time.time() - start_time}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -30,9 +30,9 @@ def chat_robot(request_data: ChatRobotModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"chat_robot request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
||||
data = chat(post_data=request_data)
|
||||
logger.info(f"chat_robot response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"chat_robot response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"chat_robot Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -42,7 +42,7 @@ def clothing_seg(request_item: ClothingSegModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"clothing_seg request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
server = ClothingSeg(request_item)
|
||||
result_url = server.get_result()
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,64 +1,76 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from fastapi import APIRouter, HTTPException, UploadFile, File, Form, BackgroundTasks
|
||||
import requests
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
|
||||
from app.schemas.design import DesignModel, DesignProgressModel, ModelProgressModel, DBGConfigModel, DesignStreamModel
|
||||
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel, SAMRequestModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.design.model_process_service import model_transpose
|
||||
from app.service.design_batch.service import start_design_batch_generate
|
||||
from app.service.design_fast.design_generate import design_generate, design_generate_v2
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
from app.service.design_fast.model_process_service import model_transpose
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@router.post("/design")
|
||||
def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
def design(request_data: DesignModel):
|
||||
"""
|
||||
objects.items.transparent:
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
mask_url 为空"" -> 单件衣服透明
|
||||
mask_url 非空"mask_url" -> 区域透明
|
||||
- **objects.items.transparent**:
|
||||
```json
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
```
|
||||
- **mask_url** 为空"" -> 单件衣服透明
|
||||
- **mask_url** 非空"mask_url" -> 区域透明
|
||||
- **transpose** 镜像模式 ,:"top_bottom"或"left_right"
|
||||
- **rotate** 45,
|
||||
|
||||
创建一个具有以下参数的请求体:
|
||||
示例参数:
|
||||
- ** design 参数变更:
|
||||
design detail 请求参数中 basic -> preview_submit 替换为design_type 可选参数 default ,merge (移除preview和submit)
|
||||
design_type 参数说明:
|
||||
defuault模式下 请求参数不变
|
||||
merge模式下 items -> 每个item需要新增 merge_image_path , merge_image_path为前端处理 print color等操作后的单件结果图
|
||||
|
||||
**
|
||||
|
||||
- 创建一个具有以下参数的请求体:
|
||||
示例参数:
|
||||
```json
|
||||
{
|
||||
"objects": [
|
||||
{
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
200,
|
||||
241
|
||||
203,
|
||||
249
|
||||
],
|
||||
"hand_point_right": [
|
||||
223,
|
||||
297
|
||||
229,
|
||||
343
|
||||
],
|
||||
"waistband_left": [
|
||||
112,
|
||||
241
|
||||
119,
|
||||
248
|
||||
],
|
||||
"hand_point_left": [
|
||||
92,
|
||||
305
|
||||
97,
|
||||
343
|
||||
],
|
||||
"shoulder_left": [
|
||||
99,
|
||||
116
|
||||
108,
|
||||
107
|
||||
],
|
||||
"shoulder_right": [
|
||||
215,
|
||||
116
|
||||
212,
|
||||
107
|
||||
]
|
||||
},
|
||||
"layer_order": true,
|
||||
"design_type": "preview",
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
@@ -67,14 +79,19 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 270372,
|
||||
"color": "30 28 28",
|
||||
"image_id": 69780,
|
||||
"businessId": 2115382,
|
||||
"color": "",
|
||||
"image_id": 61686,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
|
||||
"path": "aida-sys-image/images/female/dress/0628000564.jpg",
|
||||
"transpose": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"rotate": 45,
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
@@ -83,10 +100,30 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"location": [
|
||||
[
|
||||
53.0,
|
||||
118.5
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [
|
||||
"aida-users/89/print/02d57aa8-f342-4e1d-b02c-b278f94dcfe6-3-89.png"
|
||||
],
|
||||
"print_scale_list": [
|
||||
[
|
||||
0.5,
|
||||
0.5
|
||||
]
|
||||
],
|
||||
"gap": [
|
||||
[
|
||||
10,
|
||||
10
|
||||
]
|
||||
]
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
@@ -100,104 +137,30 @@ def design(request_data: DesignModel, background_tasks: BackgroundTasks):
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
"seg_mask_url": "aida-clothing/mask/mask_9698b428-eb93-11f0-9327-0242c0a80003.png",
|
||||
"type": "Dress"
|
||||
},
|
||||
{
|
||||
"businessId": 270373,
|
||||
"color": "30 28 28",
|
||||
"image_id": 98243,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 11,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 270374,
|
||||
"color": "172 68 68",
|
||||
"image_id": 98244,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
"element_path_list": [],
|
||||
"element_scale_list": [],
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 12,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
|
||||
"image_id": 96090,
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
"process_id": "83"
|
||||
"process_id": "89"
|
||||
}
|
||||
"""
|
||||
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
```
|
||||
"""
|
||||
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
||||
# data = generate(request_data=request_data)
|
||||
# logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
# logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
#
|
||||
|
||||
try:
|
||||
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
data = design_generate(request_data=request_data)
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
@@ -215,47 +178,48 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"basic": {
|
||||
"body_point_test": {
|
||||
"waistband_right": [
|
||||
200,
|
||||
241
|
||||
203,
|
||||
249
|
||||
],
|
||||
"hand_point_right": [
|
||||
223,
|
||||
297
|
||||
229,
|
||||
343
|
||||
],
|
||||
"waistband_left": [
|
||||
112,
|
||||
241
|
||||
119,
|
||||
248
|
||||
],
|
||||
"hand_point_left": [
|
||||
92,
|
||||
305
|
||||
97,
|
||||
343
|
||||
],
|
||||
"shoulder_left": [
|
||||
99,
|
||||
116
|
||||
108,
|
||||
107
|
||||
],
|
||||
"relation_type": "System",
|
||||
"shoulder_right": [
|
||||
215,
|
||||
116
|
||||
]
|
||||
212,
|
||||
107
|
||||
],
|
||||
"relation_id": 1020356
|
||||
},
|
||||
"layer_order": true,
|
||||
"layer_order": false,
|
||||
"scale_bag": 0.7,
|
||||
"scale_earrings": 0.16,
|
||||
"self_template": true,
|
||||
"self_template": false,
|
||||
"single_overall": "overall",
|
||||
"switch_category": ""
|
||||
},
|
||||
"items": [
|
||||
{
|
||||
"businessId": 270372,
|
||||
"color": "30 28 28",
|
||||
"image_id": 69780,
|
||||
"color": "209 196 171",
|
||||
"image_id": 84093,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/trousers/0825000630.jpg",
|
||||
"path": "aida-users/89/sketchboard/female/Outwear/0943d209-7ce0-408c-bc61-83f15da94138.png",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
@@ -264,10 +228,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"location": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"print_scale_list": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
]
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
@@ -276,22 +253,20 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 10,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Trousers"
|
||||
"type": "Outwear"
|
||||
},
|
||||
{
|
||||
"businessId": 270373,
|
||||
"color": "30 28 28",
|
||||
"image_id": 98243,
|
||||
"color": "63 71 73",
|
||||
"image_id": 100496,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/blouse/0902003811.jpg",
|
||||
"path": "aida-sys-image/images/female/blouse/0628001684.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
@@ -300,10 +275,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"location": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"print_scale_list": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
]
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
@@ -312,7 +300,6 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 11,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
@@ -320,14 +307,14 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"type": "Blouse"
|
||||
},
|
||||
{
|
||||
"businessId": 270374,
|
||||
"color": "172 68 68",
|
||||
"image_id": 98244,
|
||||
"color": "111 78 63",
|
||||
"gradient": "aida-gradient/f69b98e8-4248-4f7a-98a2-21bac41bf3e0.png",
|
||||
"image_id": 92193,
|
||||
"offset": [
|
||||
0,
|
||||
0
|
||||
1,
|
||||
1
|
||||
],
|
||||
"path": "aida-sys-image/images/female/outwear/0825000410.jpg",
|
||||
"path": "aida-sys-image/images/female/trousers/0825001160.jpg",
|
||||
"print": {
|
||||
"element": {
|
||||
"element_angle_list": [],
|
||||
@@ -336,10 +323,23 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"location": []
|
||||
},
|
||||
"overall": {
|
||||
"location": [],
|
||||
"print_angle_list": [],
|
||||
"location": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
],
|
||||
"print_angle_list": [
|
||||
0.0,
|
||||
0.0
|
||||
],
|
||||
"print_path_list": [],
|
||||
"print_scale_list": []
|
||||
"print_scale_list": [
|
||||
[
|
||||
0.0,
|
||||
0.0
|
||||
]
|
||||
]
|
||||
},
|
||||
"single": {
|
||||
"location": [],
|
||||
@@ -348,31 +348,37 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
"print_scale_list": []
|
||||
}
|
||||
},
|
||||
"priority": 12,
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"transparent":{
|
||||
"mask_url":"test/transparent_test/transparent_mask.png",
|
||||
"scale":0.1
|
||||
},
|
||||
"type": "Outwear"
|
||||
"type": "Trousers"
|
||||
},
|
||||
{
|
||||
"body_path": "aida-sys-image/models/female/5bdfe7ca-64eb-44e4-b03d-8e517520c795.png",
|
||||
"image_id": 96090,
|
||||
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
|
||||
"image_id": 67277,
|
||||
"offset": [
|
||||
1,
|
||||
1
|
||||
],
|
||||
"resize_scale": [
|
||||
1.0,
|
||||
1.0
|
||||
],
|
||||
"type": "Body"
|
||||
}
|
||||
]
|
||||
],
|
||||
"objectSign": "65830966"
|
||||
}
|
||||
],
|
||||
"process_id": "83"
|
||||
"process_id": "4802946666428422",
|
||||
"requestId": "1d1e7641-0d62-4da2-adc0-b4404910723c",
|
||||
"callback_url": "https://api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
# 异步
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
background_tasks.add_task(design_generate_v2, request_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
@@ -380,30 +386,76 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.post('/get_progress')
|
||||
def get_progress(request_data: DesignProgressModel):
|
||||
@router.post("/seg_anything")
|
||||
async def seg_anything(request_data: SAMRequestModel):
|
||||
"""
|
||||
获取design 进度
|
||||
创建一个具有以下参数的请求体:
|
||||
- **process_id**: 进度id
|
||||
**Segment Anything 交互式分割接口**
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"process_id": "6878547032381675"
|
||||
}
|
||||
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
|
||||
|
||||
### 参数说明:
|
||||
- **user_id**:用户id 用于存储分割图
|
||||
- **image_path**: 图片在服务器或云端的相对路径。
|
||||
- **type**: 推理类型
|
||||
- **box**: 框选矩形点位信息
|
||||
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
|
||||
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
|
||||
- 1: **前景点** (代表想要分割出的区域)
|
||||
- 0: **背景点** (代表想要排除的区域)
|
||||
|
||||
### 请求体示例:
|
||||
```json
|
||||
point
|
||||
{
|
||||
"user_id": 1,
|
||||
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
|
||||
"type":"point",
|
||||
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
|
||||
"labels": [1, 1, 0, 1]
|
||||
}
|
||||
|
||||
box
|
||||
{
|
||||
"user_id": 1,
|
||||
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
|
||||
"type":"box",
|
||||
"box": [350, 286, 544, 520]
|
||||
}
|
||||
```
|
||||
"""
|
||||
try:
|
||||
logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
process_id = request_data.process_id
|
||||
r = Redis()
|
||||
data = r.read(key=process_id)
|
||||
if data is None:
|
||||
raise ValueError(f"No progress ID: {process_id}")
|
||||
logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data)}")
|
||||
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
data = requests.post("http://10.1.1.240:10075/predict", json=request_data.dict())
|
||||
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
|
||||
return ResponseModel(data=json.loads(data.content))
|
||||
except Exception as e:
|
||||
logger.warning(f"get_progress Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
logger.warning(f"seg_anything Run Exception @@@@@@:{e}")
|
||||
|
||||
|
||||
# @router.post('/get_progress')
|
||||
# def get_progress(request_data: DesignProgressModel):
|
||||
# """
|
||||
# 获取design 进度
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **process_id**: 进度id
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "process_id": "6878547032381675"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"get_progress request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
# process_id = request_data.process_id
|
||||
# r = Redis()
|
||||
# data = r.read(key=process_id)
|
||||
# if data is None:
|
||||
# raise ValueError(f"No progress ID: {process_id}")
|
||||
# logging.info(f"get_progress process_id @@@@@@ : {process_id} , progress : {json.dumps(data, indent=4)}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"get_progress Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.post('/model_process')
|
||||
@@ -419,44 +471,42 @@ def model_process(request_data: ModelProgressModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"model_process request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
|
||||
data = model_transpose(image_path=request_data.model_path)
|
||||
logger.info(f"model_process response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"model_process response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"model_process Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
# ##############################################################
|
||||
|
||||
|
||||
@router.post("/design_batch_generate")
|
||||
async def design_batch(file: UploadFile = File(...),
|
||||
tasks_id: str = Form(...),
|
||||
user_id: str = Form(...),
|
||||
file_name: str = Form(...),
|
||||
total: int = Form(...)
|
||||
):
|
||||
dbg_config = DBGConfigModel(
|
||||
tasks_id=tasks_id,
|
||||
user_id=user_id,
|
||||
file_name=file_name,
|
||||
total=total
|
||||
)
|
||||
contents = await file.read()
|
||||
file_name = file.filename
|
||||
await save_request_file(contents, file_name)
|
||||
return await start_design_batch_generate(dbg_config, contents)
|
||||
|
||||
|
||||
async def save_request_file(contents, file_name):
|
||||
# 创建保存文件的目录(如果不存在)
|
||||
save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
# 处理文件
|
||||
file_path = os.path.join(save_dir, file_name)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(contents)
|
||||
"""design 批量处理 停用"""
|
||||
# @router.post("/design_batch_generate")
|
||||
# async def design_batch(file: UploadFile = File(...),
|
||||
# tasks_id: str = Form(...),
|
||||
# user_id: str = Form(...),
|
||||
# file_name: str = Form(...),
|
||||
# total: int = Form(...)
|
||||
# ):
|
||||
# dbg_config = DBGConfigModel(
|
||||
# tasks_id=tasks_id,
|
||||
# user_id=user_id,
|
||||
# file_name=file_name,
|
||||
# total=total
|
||||
# )
|
||||
# contents = await file.read()
|
||||
# file_name = file.filename
|
||||
# await save_request_file(contents, file_name)
|
||||
# return await start_design_batch_generate(dbg_config, contents)
|
||||
#
|
||||
#
|
||||
# async def save_request_file(contents, file_name):
|
||||
# # 创建保存文件的目录(如果不存在)
|
||||
# save_dir = os.path.join(os.getcwd(), "service/design_batch", "request_data")
|
||||
# if not os.path.exists(save_dir):
|
||||
# os.makedirs(save_dir)
|
||||
# # 处理文件
|
||||
# file_path = os.path.join(save_dir, file_name)
|
||||
# with open(file_path, "wb") as f:
|
||||
# f.write(contents)
|
||||
|
||||
@@ -30,10 +30,10 @@ def design_pre_processing(request_data: DesignPreProcessingModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"design_pre_processing request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
server = DesignPreprocessing()
|
||||
data = server.pipeline(image_list=request_data.sketches)
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"design response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"design Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -33,18 +33,30 @@ def generate_image(request_item: GenerateImageModel, background_tasks: Backgroun
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
|
||||
示例参数:
|
||||
1. txt 2 img
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "skeleton sitting by the side of a river looking soulful, concert poster, 4k, artistic",
|
||||
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg",
|
||||
"mode": "img2img",
|
||||
"category": "sketch",
|
||||
"gender": "male",
|
||||
"version": "fast"
|
||||
"tasks_id": "bd2cf809-24bc-49a6-91c9-193c6272a52e-2-89",
|
||||
"prompt": "a single item of sketch of dress, 4k, white background",
|
||||
"image_url": "",
|
||||
"mode": "txt2img",
|
||||
"category": "sketch",
|
||||
"gender": "Female",
|
||||
"version": "fast"
|
||||
}
|
||||
2. img 2 img
|
||||
{
|
||||
"tasks_id": "b861d4fa-5ae3-4a30-9c7a-7ba6bb9aa37b-1-89",
|
||||
"prompt": "a single item of sketch of dress, 4k, white background",
|
||||
"image_url": "aida-collection-element/89/Sketchboard/548da3a2-834f-49a7-b52c-e729c5ab5062.png",
|
||||
"mode": "img2img",
|
||||
"category": "sketch",
|
||||
"gender": "Female",
|
||||
"version": "fast"
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"generate_image request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
|
||||
service = GenerateImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -65,42 +77,41 @@ def generate_image(tasks_id: str):
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''multi view'''
|
||||
'''multi view 停用'''
|
||||
|
||||
# @router.post("/generate_multi_view")
|
||||
# def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **image_url**: 前视角图的输入,minio或S3 url 地址
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
# service = GenerateMultiView(request_item)
|
||||
# background_tasks.add_task(service.get_result)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel()
|
||||
|
||||
|
||||
@router.post("/generate_multi_view")
|
||||
def generate_multi_view(request_item: GenerateMultiViewModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 前视角图的输入,minio或S3 url 地址
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-collection-element/87/Printboard/842c09cf-7297-42d9-9e6e-9c17d4a13cb5.jpg"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_multi_view request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateMultiView(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_multi_view Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_multi_view_cancel/{tasks_id}")
|
||||
def generate_multi_view(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_multi_view_cancel(tasks_id)
|
||||
logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
# @router.get("/generate_multi_view_cancel/{tasks_id}")
|
||||
# def generate_multi_view(tasks_id: str):
|
||||
# try:
|
||||
# logger.info(f"generate_cancel request item is : @@@@@@:{tasks_id}")
|
||||
# data = generate_multi_view_cancel(tasks_id)
|
||||
# logger.info(f"generate_cancel response @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_cancel Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''single logo'''
|
||||
@@ -122,7 +133,7 @@ def generate_single_logo(request_item: GenerateSingleLogoImageModel, background_
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"generate_single_logo request item is : @@@@@@:{json.dumps(request_item.dict(), indent=4)}")
|
||||
service = GenerateSingleLogoImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -167,7 +178,7 @@ def generate_product_image(request_item: GenerateProductImageModel, background_t
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"generate_product_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = GenerateProductImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
@@ -188,166 +199,164 @@ def generate_product_image(tasks_id: str):
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
'''relight image'''
|
||||
'''relight image 停用'''
|
||||
|
||||
# @router.post("/generate_relight_image")
|
||||
# def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
# - **product_type**: 输入single item 还是 overall item
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "direction": "Right Light",
|
||||
# "product_type": "overall"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
# service = GenerateRelightImage(request_item)
|
||||
# background_tasks.add_task(service.get_result)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel()
|
||||
#
|
||||
#
|
||||
# @router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
|
||||
# def generate_relight_image(tasks_id: str):
|
||||
# try:
|
||||
# logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
|
||||
# data = generate_relight_image_cancel(tasks_id)
|
||||
# logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
@router.post("/generate_relight_image")
|
||||
def generate_relight_image(request_item: GenerateRelightImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
"""batch generate img 停用"""
|
||||
|
||||
# @router.post("/batch_generate_product_image")
|
||||
# async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于获取生成结果
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **image_strength**: 生成强度,越低越接近原图
|
||||
# - **product_type**: 输入single item 还是 overall item
|
||||
# - **batch_size**: 批生成数量
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||
# "image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||
# "image_strength": 0.8,
|
||||
# "product_type": "overall",
|
||||
# "batch_size": 1
|
||||
# }
|
||||
# """
|
||||
# return await start_product_batch_generate(request_batch_item)
|
||||
#
|
||||
#
|
||||
# @router.post("/batch_generate_relight_image")
|
||||
# async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于获取生成结果
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
# - **product_type**: 输入single item 还是 overall item
|
||||
# - **batch_size**: 批生成数量
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "direction": "Right Light",
|
||||
# "product_type": "overall",
|
||||
# "batch_size": 1
|
||||
# }
|
||||
# """
|
||||
# return await start_relight_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"direction": "Right Light",
|
||||
"product_type": "overall"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"generate_relight_image request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = GenerateRelightImage(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_relight_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/generate_relight_image_cancel_cancel/{tasks_id}")
|
||||
def generate_relight_image(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"generate_relight_image_cancel_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = generate_relight_image_cancel(tasks_id)
|
||||
logger.info(f"generate_relight_image_cancel_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"generate_relight_image_cancel_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
"""batch generate img"""
|
||||
|
||||
|
||||
@router.post("/batch_generate_product_image")
|
||||
async def batch_generate_product(request_batch_item: BatchGenerateProductImageModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **image_strength**: 生成强度,越低越接近原图
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "the best quality, masterpiece. detailed, high-res, simple background, studio photography, extremely detailed, updo, detailed face, face, close-up, HDR, UHD, 8K realistic, Highly detailed, simple background, Studio lighting",
|
||||
"image_url": "aida-results/result_00097282-ebb2-11ee-a822-b48351119060.png",
|
||||
"image_strength": 0.8,
|
||||
"product_type": "overall",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_product_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
@router.post("/batch_generate_relight_image")
|
||||
async def batch_generate_relight(request_batch_item: BatchGenerateRelightImageModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于获取生成结果
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **direction**: 光源方向 Right Light Left Light Top Light Bottom Light
|
||||
- **product_type**: 输入single item 还是 overall item
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"direction": "Right Light",
|
||||
"product_type": "overall",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_relight_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
@router.post("/batch_generate_pose_transform_image")
|
||||
async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **pose_id**: 1
|
||||
- **batch_size**: 批生成数量
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "123-89",
|
||||
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
"pose_id": "1",
|
||||
"batch_size": 1
|
||||
}
|
||||
"""
|
||||
return await start_pose_transform_batch_generate(request_batch_item)
|
||||
|
||||
|
||||
"""agent tool"""
|
||||
|
||||
|
||||
@router.post("/agent_tool_generate_image")
|
||||
def agent_tool_generate_image(request_item: AgentTollGenerateImageModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **prompt**: 想要生成图片的描述词
|
||||
- **category**: 生成图片的类别,sketch print 等等
|
||||
- **gender**: 生成sketch专用,服装类别
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
- **size**: 生成数量
|
||||
- **version**: 使用模型版本 fast 或者 high
|
||||
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||
"category": "sketch",
|
||||
"gender": "male",
|
||||
"size":2,
|
||||
"version":"high"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
|
||||
request_data = request_item.dict()
|
||||
service = AgentToolGenerateImage(request_data['version'])
|
||||
image_url_list, clothing_category_list = service.get_result(
|
||||
prompt=request_data['prompt'],
|
||||
size=request_data['size'],
|
||||
version=request_data['version'],
|
||||
category=request_data['category'],
|
||||
gender=request_data['gender']
|
||||
)
|
||||
data = {
|
||||
"image_url_list": image_url_list,
|
||||
"clothing_category_list": clothing_category_list
|
||||
}
|
||||
logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
# @router.post("/batch_generate_pose_transform_image")
|
||||
# async def batch_generate_pose_transform(request_batch_item: BatchPoseTransformModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **pose_id**: 1
|
||||
# - **batch_size**: 批生成数量
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "pose_id": "1",
|
||||
# "batch_size": 1
|
||||
# }
|
||||
# """
|
||||
# return await start_pose_transform_batch_generate(request_batch_item)
|
||||
#
|
||||
#
|
||||
# """agent tool"""
|
||||
#
|
||||
#
|
||||
# @router.post("/agent_tool_generate_image")
|
||||
# def agent_tool_generate_image(request_item: AgentTollGenerateImageModel):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **prompt**: 想要生成图片的描述词
|
||||
# - **category**: 生成图片的类别,sketch print 等等
|
||||
# - **gender**: 生成sketch专用,服装类别
|
||||
# - **version**: 使用模型版本 fast 或者 high
|
||||
# - **size**: 生成数量
|
||||
# - **version**: 使用模型版本 fast 或者 high
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "prompt": "a single item of sketch of Wabi-sabi, skirt, tiered, 4k, white background",
|
||||
# "category": "sketch",
|
||||
# "gender": "male",
|
||||
# "size":2,
|
||||
# "version":"high"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"agent_tool_generate_image request item is : @@@@@@:{request_item.dict()}")
|
||||
# request_data = request_item.dict()
|
||||
# service = AgentToolGenerateImage(request_data['version'])
|
||||
# image_url_list, clothing_category_list = service.get_result(
|
||||
# prompt=request_data['prompt'],
|
||||
# size=request_data['size'],
|
||||
# version=request_data['version'],
|
||||
# category=request_data['category'],
|
||||
# gender=request_data['gender']
|
||||
# )
|
||||
# data = {
|
||||
# "image_url_list": image_url_list,
|
||||
# "clothing_category_list": clothing_category_list
|
||||
# }
|
||||
# logger.info(f"agent_tool_generate_image response item is : @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"agent_tool_generate_image Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data)
|
||||
|
||||
@@ -14,22 +14,22 @@ logger = logging.getLogger()
|
||||
@router.post("/image2sketch")
|
||||
def image2sketch(request_item: Image2SketchModel):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提取图片url
|
||||
- **default_style**: 原始、 1、2、3、4、5
|
||||
- **sketch_bucket**: sketch保存的bucket
|
||||
- **sketch_name**: sketch保存的object name
|
||||
创建一个具有以下参数的请求体:
|
||||
- **image_url**: 提取图片url
|
||||
- **default_style**: 原始、 1、2、3、4、5
|
||||
- **sketch_bucket**: sketch保存的bucket
|
||||
- **sketch_name**: sketch保存的object name
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||
"default_style": 0,
|
||||
"sketch_bucket": "test",
|
||||
"sketch_name": "image2sketch/area_fill_img.png"
|
||||
}
|
||||
"""
|
||||
示例参数:
|
||||
{
|
||||
"image_url": "test/image2sketch/real_Dress_3200fecdc83d0c556c2bd96aedbd7fbf.jpg_Img.jpg",
|
||||
"default_style": 0,
|
||||
"sketch_bucket": "test",
|
||||
"sketch_name": "image2sketch/area_fill_img.png"
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"image2sketch request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = LineArtService(request_item)
|
||||
result_url = service.get_result()
|
||||
except Exception as e:
|
||||
|
||||
116
app/api/api_import_sys_sketch.py
Normal file
116
app/api/api_import_sys_sketch.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import logging
|
||||
import sys
|
||||
from typing import Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import threading
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# 使用线程池执行器来运行长时间任务
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
# 用于跟踪任务状态
|
||||
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
|
||||
"""在后台线程中运行导入任务"""
|
||||
original_argv = None
|
||||
try:
|
||||
task_status["running"] = True
|
||||
# 保存原始 sys.argv
|
||||
original_argv = sys.argv.copy()
|
||||
|
||||
# 模拟命令行参数
|
||||
sys.argv = [
|
||||
"import_sys_sketch_to_milvus.py",
|
||||
"--batch-size", str(batch_size),
|
||||
"--retry-times", str(retry_times),
|
||||
]
|
||||
if limit is not None:
|
||||
sys.argv.extend(["--limit", str(limit)])
|
||||
if offset > 0:
|
||||
sys.argv.extend(["--offset", str(offset)])
|
||||
if skip_create_collection:
|
||||
sys.argv.append("--skip-create-collection")
|
||||
|
||||
import_main()
|
||||
task_status["running"] = False
|
||||
logger.info("导入任务完成")
|
||||
except Exception as e:
|
||||
task_status["running"] = False
|
||||
logger.error(f"导入任务失败: {e}", exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
# 恢复原始 sys.argv
|
||||
if original_argv is not None:
|
||||
sys.argv = original_argv
|
||||
|
||||
|
||||
@router.post("/import-sys-sketch", response_model=ResponseModel)
|
||||
async def import_sys_sketch(
|
||||
batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
|
||||
retry_times: int = Query(3, description="失败重试次数(默认:3)"),
|
||||
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
|
||||
offset: int = Query(0, description="起始偏移量(默认:0)"),
|
||||
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
|
||||
):
|
||||
"""
|
||||
从 t_sys_file 导入系统图向量到 Milvus
|
||||
|
||||
该接口会异步执行导入任务,任务在后台运行。
|
||||
"""
|
||||
try:
|
||||
# 检查是否有任务正在运行
|
||||
if task_status["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="已有导入任务正在运行,请等待完成后再试"
|
||||
)
|
||||
|
||||
# 在后台线程中执行任务
|
||||
executor.submit(
|
||||
run_import_task,
|
||||
batch_size,
|
||||
retry_times,
|
||||
limit,
|
||||
offset,
|
||||
skip_create_collection
|
||||
)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="导入任务已启动,正在后台执行",
|
||||
data={
|
||||
"status": "started",
|
||||
"batch_size": batch_size,
|
||||
"retry_times": retry_times,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
"skip_create_collection": skip_create_collection
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"启动导入任务失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
|
||||
async def get_import_status():
|
||||
"""
|
||||
获取导入任务状态
|
||||
"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="OK",
|
||||
data={
|
||||
"running": task_status["running"]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -35,10 +35,10 @@ def mannequins_edit(request_data: MannequinModel):
|
||||
}**
|
||||
"""
|
||||
try:
|
||||
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"mannequins_edit request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
|
||||
service = MannequinEditService(request_data)
|
||||
data = service()
|
||||
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"mannequins_edit response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"mannequins_edit Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -1,18 +1,67 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiI2VModel, ComfyuiFLF2VModel
|
||||
from app.schemas.pose_transform import PoseTransformModel
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
|
||||
from app.service.comfyui_I2V.flf2v_server import ComfyUIServerFLF2V
|
||||
from app.service.comfyui_I2V.i2v_server import ComfyUIServerI2V
|
||||
from app.service.comfyui_I2V.pose2v_server import ComfyUIServerPose2V
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
|
||||
"""停用"""
|
||||
|
||||
@router.post("/pose_transform")
|
||||
def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
|
||||
# @router.post("/pose_transform")
|
||||
# def pose_transform(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
|
||||
# """
|
||||
# 创建一个具有以下参数的请求体:
|
||||
# - **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
# - **image_url**: 被生成图片的S3或minio url地址
|
||||
# - **pose_id**: 1
|
||||
#
|
||||
#
|
||||
# 示例参数:
|
||||
# {
|
||||
# "tasks_id": "123-89",
|
||||
# "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||
# "pose_id": "1"
|
||||
# }
|
||||
# """
|
||||
# try:
|
||||
# logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
# service = PoseTransformService(request_item)
|
||||
# background_tasks.add_task(service.get_result)
|
||||
# except Exception as e:
|
||||
# logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel()
|
||||
|
||||
|
||||
# @router.get("/pose_transform_cancel/{tasks_id}")
|
||||
# def pose_transform_cancel(tasks_id: str):
|
||||
# try:
|
||||
# logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
|
||||
# data = pose_transform_infer_cancel(tasks_id)
|
||||
# logger.info(f"pose_transform_cancel response @@@@@@:{data}")
|
||||
# except Exception as e:
|
||||
# logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
|
||||
# raise HTTPException(status_code=404, detail=str(e))
|
||||
# return ResponseModel(data=data['data'])
|
||||
|
||||
|
||||
"""
|
||||
骨架 + 产品图 => 视频
|
||||
"""
|
||||
|
||||
|
||||
@router.post("/comfyui_image_pose_2_video")
|
||||
def comfyui_image_pose_2_video(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
@@ -28,22 +77,92 @@ def pose_transform(request_item: PoseTransformModel, background_tasks: Backgroun
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"pose_transform request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
service = PoseTransformService(request_item)
|
||||
logger.info(f"image_pose_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = ComfyUIServerPose2V(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"pose_transform Run Exception @@@@@@:{e}")
|
||||
logger.warning(f"image_pose_2_video Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/pose_transform_cancel/{tasks_id}")
|
||||
def pose_transform_cancel(tasks_id: str):
|
||||
"""
|
||||
产品图 + 文 => 视频
|
||||
"""
|
||||
|
||||
|
||||
@router.post("/comfyui_image_2_video")
|
||||
def comfyui_image_2_video(request_item: ComfyuiI2VModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **image_url**: 被生成图片的S3或minio url地址
|
||||
- **prompt**: 动作表述
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "12222515151123-89111",
|
||||
"image_url": "aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
|
||||
"prompt": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"pose_transform_cancel request item is : @@@@@@:{tasks_id}")
|
||||
data = pose_transform_infer_cancel(tasks_id)
|
||||
logger.info(f"pose_transform_cancel response @@@@@@:{data}")
|
||||
logger.info(f"image_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = ComfyUIServerI2V(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
|
||||
logger.warning(f"image_2_video Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
"""
|
||||
首尾帧 + 文 => 视频
|
||||
"""
|
||||
|
||||
|
||||
@router.post("/comfyui_flf_2_video")
|
||||
def comfyui_flf_2_video(request_item: ComfyuiFLF2VModel, background_tasks: BackgroundTasks):
|
||||
"""
|
||||
创建一个具有以下参数的请求体:
|
||||
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||
- **start_image_url**: 首帧
|
||||
- **end_image_url**: 尾帧
|
||||
- **prompt**: 动作描述
|
||||
|
||||
示例参数:
|
||||
{
|
||||
"tasks_id": "202511051619-89111",
|
||||
"start_image_url": "test/start.png",
|
||||
"end_image_url": "test/end.png",
|
||||
"prompt": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"flf_2_video request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = ComfyUIServerFLF2V(request_item)
|
||||
background_tasks.add_task(service.get_result)
|
||||
except Exception as e:
|
||||
logger.warning(f"flf_2_video Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel()
|
||||
|
||||
|
||||
@router.get("/comfyui_i_2_video_cancel/{tasks_id}")
|
||||
def comfyui_i_2_video_cancel(tasks_id: str):
|
||||
try:
|
||||
logger.info(f"comfyui_i_2_video_cancel request item is : @@@@@@:{tasks_id}")
|
||||
response = requests.post(
|
||||
f"http://{settings.COMFYUI_SERVER_ADDRESS}/interrupt",
|
||||
json={"prompt_id": tasks_id}
|
||||
)
|
||||
data = {}
|
||||
if response.status_code == 200:
|
||||
data['data']['message'] = "任务已成功中断"
|
||||
else:
|
||||
data['data']['message'] = f"中断失败:{response.text}"
|
||||
logger.info(f"comfyui_i_2_video_cancel response @@@@@@:{data}")
|
||||
except Exception as e:
|
||||
logger.warning(f"comfyui_i_2_video_cancel Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data['data'])
|
||||
|
||||
85
app/api/api_precompute.py
Normal file
85
app/api/api_precompute.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import logging
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommendation_system.precompute import run_precompute
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
# 使用线程池执行器来运行长时间任务
|
||||
executor = ThreadPoolExecutor(max_workers=1)
|
||||
# 用于跟踪任务状态
|
||||
task_status = {"running": False}
|
||||
|
||||
|
||||
def run_precompute_task():
|
||||
"""在后台线程中运行预计算任务"""
|
||||
try:
|
||||
task_status["running"] = True
|
||||
logger.info("开始执行预计算任务...")
|
||||
run_precompute()
|
||||
task_status["running"] = False
|
||||
logger.info("预计算任务完成")
|
||||
except Exception as e:
|
||||
task_status["running"] = False
|
||||
logger.error(f"预计算任务失败: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
@router.post("/precompute", response_model=ResponseModel)
|
||||
async def precompute():
|
||||
"""
|
||||
运行预计算任务
|
||||
|
||||
该接口会异步执行预计算任务,包括:
|
||||
1. 优化数据库表结构
|
||||
2. 历史数据迁移
|
||||
3. 初始用户偏好向量生成
|
||||
|
||||
任务在后台运行。
|
||||
"""
|
||||
try:
|
||||
# 检查是否有任务正在运行
|
||||
if task_status["running"]:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="已有预计算任务正在运行,请等待完成后再试"
|
||||
)
|
||||
|
||||
# 在后台线程中执行任务
|
||||
executor.submit(run_precompute_task)
|
||||
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="预计算任务已启动,正在后台执行",
|
||||
data={
|
||||
"status": "started",
|
||||
"tasks": [
|
||||
"优化数据库表结构",
|
||||
"历史数据迁移",
|
||||
"初始用户偏好向量生成"
|
||||
]
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"启动预计算任务失败: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/precompute/status", response_model=ResponseModel)
|
||||
async def get_precompute_status():
|
||||
"""
|
||||
获取预计算任务状态
|
||||
"""
|
||||
return ResponseModel(
|
||||
code=200,
|
||||
msg="OK",
|
||||
data={
|
||||
"running": task_status["running"]
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from app.schemas.prompt_generation import PromptGenerationImageModel, ImageRequest
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, \
|
||||
get_prompt_from_image
|
||||
from app.service.prompt_generation.chatgpt_for_translation import get_translation_from_llama3, get_prompt_from_image
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger()
|
||||
@@ -34,19 +31,19 @@ def prompt_generation(request_data: PromptGenerationImageModel):
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ResponseModel(data=data)
|
||||
|
||||
|
||||
@router.post("/img2prompt")
|
||||
def get_prompt_from_img(img: ImageRequest):
|
||||
"""
|
||||
自动识别图片并输出为prompt
|
||||
|
||||
:param img: 图片的minio地址
|
||||
:return: 图片的文字描述
|
||||
"""
|
||||
text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
|
||||
"The description should allow for the reconstruction of the corresponding line art based on the details "
|
||||
"given.")
|
||||
logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
|
||||
description = get_prompt_from_image(img, text)
|
||||
logger.info(f"生成的图片描述 response @@@@@@:{description}")
|
||||
return description
|
||||
# 停用
|
||||
# @router.post("/img2prompt")
|
||||
# def get_prompt_from_img(img: ImageRequest):
|
||||
# """
|
||||
# 自动识别图片并输出为prompt
|
||||
#
|
||||
# :param img: 图片的minio地址
|
||||
# :return: 图片的文字描述
|
||||
# """
|
||||
# text = ("Please describe the clothing in the image and provide a line art description of the outfit. "
|
||||
# "The description should allow for the reconstruction of the corresponding line art based on the details "
|
||||
# "given.")
|
||||
# logger.info(f"get_prompt_from_img request item is : @@@@@@:{img}")
|
||||
# description = get_prompt_from_image(img, text)
|
||||
# logger.info(f"生成的图片描述 response @@@@@@:{description}")
|
||||
# return description
|
||||
|
||||
@@ -26,9 +26,9 @@ def query_image(request_data: QueryImageModel):
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict())}")
|
||||
logger.info(f"query_image request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
|
||||
data = query(request_data.gender, request_data.content)
|
||||
logger.info(f"query_image response @@@@@@:{json.dumps(data)}")
|
||||
logger.info(f"query_image response @@@@@@:{json.dumps(data, indent=4)}")
|
||||
except Exception as e:
|
||||
logger.warning(f"query_image Run Exception @@@@@@:{e}")
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
|
||||
@@ -1,204 +1,206 @@
|
||||
import io
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from typing import List
|
||||
import os
|
||||
import json
|
||||
import math
|
||||
import random
|
||||
import numpy as np
|
||||
from typing import List, Optional
|
||||
from fastapi import HTTPException, APIRouter, Query
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import HTTPException, APIRouter
|
||||
|
||||
from app.service.recommend.service import load_resources, matrix_data
|
||||
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
|
||||
from app.service.recommendation_system.incremental_listener import start_background_listener
|
||||
from app.service.recommendation_system.milvus_client import create_collection
|
||||
|
||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
|
||||
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
# """
|
||||
# :param user_id: 4
|
||||
# :param category: female_skirt
|
||||
# :param num_recommendations: 1
|
||||
# :return:
|
||||
# [
|
||||
# "aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
# ]
|
||||
# """
|
||||
# try:
|
||||
# start_time = time.time()
|
||||
# cache_key = (user_id, category)
|
||||
# # === 新增:用户存在性检查 ===
|
||||
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
# user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
#
|
||||
# # 任一矩阵不存在用户则返回随机推荐
|
||||
# if not (user_exists_inter and user_exists_feat):
|
||||
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
# return get_random_recommendations(category, num_recommendations)
|
||||
#
|
||||
# # 检查缓存
|
||||
# if cache_key in matrix_data["cached_scores"]:
|
||||
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
# else:
|
||||
# # 实时计算逻辑(同原代码)
|
||||
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
#
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
# valid_sketch_idxs_inter = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
#
|
||||
# # 处理交互分数
|
||||
# raw_inter_scores = []
|
||||
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
# processed_inter = raw_inter_scores * 0.7
|
||||
#
|
||||
# # 处理特征分数
|
||||
# valid_sketch_idxs_feature = [
|
||||
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
# if iid in category_iids
|
||||
# ]
|
||||
# raw_feat_scores = []
|
||||
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
# processed_feat = raw_feat_scores
|
||||
# else:
|
||||
# processed_feat = np.array([])
|
||||
#
|
||||
# # 更新缓存
|
||||
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
#
|
||||
# # 合并分数
|
||||
# if brand_id is not None:
|
||||
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
#
|
||||
# brand_feat_valid = (
|
||||
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
# brand_idx_feature is not None and
|
||||
# valid_sketch_idxs_feature # 有可用索引
|
||||
# )
|
||||
#
|
||||
# if brand_feat_valid:
|
||||
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
# brand_idx_feature, valid_sketch_idxs_feature
|
||||
# ]
|
||||
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
# )
|
||||
# processed_brand_feat = raw_brand_feat_scores
|
||||
#
|
||||
# # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
# if processed_feat.size == 0:
|
||||
# processed_feat = np.zeros_like(processed_brand_feat)
|
||||
#
|
||||
# final_scores = processed_inter + 0.3 * (
|
||||
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
# )
|
||||
# else:
|
||||
# # brand 信息不可用
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
# else:
|
||||
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
#
|
||||
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
#
|
||||
# # 概率采样
|
||||
# scores = np.array(final_scores)
|
||||
#
|
||||
# # 调整后的概率转换(带温度控制的softmax)
|
||||
# def calibrated_softmax(scores, temperature=1.0):
|
||||
# scores = scores / temperature
|
||||
# scale = scores - max(scores)
|
||||
# exps = np.exp(scale)
|
||||
# return exps / np.sum(exps)
|
||||
#
|
||||
# probs = calibrated_softmax(scores, 0.09)
|
||||
#
|
||||
# chosen_indices = np.random.choice(
|
||||
# len(valid_sketch_idxs),
|
||||
# size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
# p=probs,
|
||||
# replace=False
|
||||
# )
|
||||
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
#
|
||||
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
# return recommendations
|
||||
# except Exception as e:
|
||||
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
# raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
# 初始加载
|
||||
load_resources()
|
||||
"""启动时初始化增量监听任务"""
|
||||
try:
|
||||
# 屏蔽 apscheduler 的 INFO 日志
|
||||
logging.getLogger("apscheduler").setLevel(logging.WARNING)
|
||||
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
scheduler.add_job(
|
||||
load_resources,
|
||||
trigger=CronTrigger(hour=0, minute=30),
|
||||
name="每日资源刷新"
|
||||
)
|
||||
scheduler.start()
|
||||
logger.info("定时任务已启动")
|
||||
# 确保 Milvus 集合已创建(若已存在则直接返回)
|
||||
try:
|
||||
create_collection()
|
||||
except Exception as exc:
|
||||
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
|
||||
|
||||
def softmax(scores):
|
||||
max_score = max(scores)
|
||||
exp_scores = [math.exp(s - max_score) for s in scores]
|
||||
sum_exp = sum(exp_scores)
|
||||
return [s / sum_exp for s in exp_scores]
|
||||
|
||||
# def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
# """根据预加载热度向量推荐(冷启动)"""
|
||||
# try:
|
||||
# heat_data = matrix_data.get("heat_data", {})
|
||||
#
|
||||
# if category not in heat_data:
|
||||
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
|
||||
#
|
||||
# heat_dict = heat_data[category] # {url: score}
|
||||
# urls = list(heat_dict.keys())
|
||||
# scores = list(heat_dict.values())
|
||||
#
|
||||
# if not urls:
|
||||
# raise ValueError("该类别下无热度记录,使用随机推荐")
|
||||
#
|
||||
# probs = softmax(scores)
|
||||
# sample_size = min(num, len(urls))
|
||||
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
|
||||
#
|
||||
# return sampled_urls
|
||||
#
|
||||
# except Exception as e:
|
||||
# # 回退:完全随机推荐
|
||||
# all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# sample_size = min(num, len(category_iids))
|
||||
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
|
||||
def get_random_recommendations(category: str, num: int) -> List[str]:
|
||||
"""全品类随机推荐"""
|
||||
all_iids = list(matrix_data["iid_to_sketch"].keys())
|
||||
# 优先从当前品类选择
|
||||
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
||||
# 确保不超出实际数量
|
||||
sample_size = min(num, len(category_iids))
|
||||
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
||||
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
||||
# 配置定时任务
|
||||
scheduler = BackgroundScheduler()
|
||||
start_background_listener(scheduler)
|
||||
scheduler.start()
|
||||
logger.info("增量监听定时任务已启动")
|
||||
except Exception as e:
|
||||
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
|
||||
async def recommend(
|
||||
user_id: int,
|
||||
category: str,
|
||||
style: Optional[str] = Query(
|
||||
None,
|
||||
description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
|
||||
),
|
||||
):
|
||||
"""新版推荐接口(Milvus + Redis 偏好向量)。"""
|
||||
try:
|
||||
results = get_new_recommendations(user_id, category, style)
|
||||
path = results[0] if results else ""
|
||||
return [path]
|
||||
except Exception as e:
|
||||
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/redis/user_pref")
|
||||
async def get_all_user_preferences():
|
||||
"""
|
||||
:param user_id: 4
|
||||
:param category: female_skirt
|
||||
:param num_recommendations: 1
|
||||
:return:
|
||||
[
|
||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
||||
]
|
||||
获取所有以 user_pref 为前缀的 Redis key 信息
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
cache_key = (user_id, category)
|
||||
# === 新增:用户存在性检查 ===
|
||||
user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||
user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||
from app.service.utils.redis_utils import Redis
|
||||
from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX
|
||||
|
||||
# 任一矩阵不存在用户则返回随机推荐
|
||||
if not (user_exists_inter and user_exists_feat):
|
||||
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||
return get_random_recommendations(category, num_recommendations)
|
||||
# 扫描所有匹配 user_pref:* 的 key
|
||||
pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*"
|
||||
keys = Redis.scan_keys(pattern)
|
||||
|
||||
# 检查缓存
|
||||
if cache_key in matrix_data["cached_scores"]:
|
||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||
else:
|
||||
# 实时计算逻辑(同原代码)
|
||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||
# 直接返回所有 key 和原始 value
|
||||
result = {}
|
||||
for key in keys:
|
||||
# 读取对应的值
|
||||
value = Redis.read(key)
|
||||
if value:
|
||||
result[key] = value
|
||||
|
||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||
valid_sketch_idxs_inter = [
|
||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
|
||||
# 处理交互分数
|
||||
raw_inter_scores = []
|
||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||
processed_inter = raw_inter_scores * 0.7
|
||||
|
||||
# 处理特征分数
|
||||
valid_sketch_idxs_feature = [
|
||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||
if iid in category_iids
|
||||
]
|
||||
raw_feat_scores = []
|
||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||
processed_feat = raw_feat_scores
|
||||
else:
|
||||
processed_feat = np.array([])
|
||||
|
||||
# 更新缓存
|
||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||
|
||||
# 合并分数
|
||||
if brand_id is not None:
|
||||
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||
|
||||
brand_feat_valid = (
|
||||
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||
brand_idx_feature is not None and
|
||||
valid_sketch_idxs_feature # 有可用索引
|
||||
)
|
||||
|
||||
if brand_feat_valid:
|
||||
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||
brand_idx_feature, valid_sketch_idxs_feature
|
||||
]
|
||||
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||
)
|
||||
processed_brand_feat = raw_brand_feat_scores
|
||||
|
||||
# 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||
if processed_feat.size == 0:
|
||||
processed_feat = np.zeros_like(processed_brand_feat)
|
||||
|
||||
final_scores = processed_inter + 0.3 * (
|
||||
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||
)
|
||||
else:
|
||||
# brand 信息不可用
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
else:
|
||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||
|
||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||
|
||||
# 概率采样
|
||||
scores = np.array(final_scores)
|
||||
|
||||
# 调整后的概率转换(带温度控制的softmax)
|
||||
def calibrated_softmax(scores, temperature=1.0):
|
||||
scores = scores / temperature
|
||||
scale = scores - max(scores)
|
||||
exps = np.exp(scale)
|
||||
return exps / np.sum(exps)
|
||||
|
||||
probs = calibrated_softmax(scores, 0.09)
|
||||
|
||||
chosen_indices = np.random.choice(
|
||||
len(valid_sketch_idxs),
|
||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||
p=probs,
|
||||
replace=False
|
||||
)
|
||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||
|
||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
return recommendations
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||
logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,38 +1,42 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import api_attribute_retrieve, api_query_image
|
||||
from app.api import api_attribute_retrieve
|
||||
from app.api import api_brand_dna
|
||||
from app.api import api_brighten
|
||||
from app.api import api_chat_robot
|
||||
from app.api import api_clothing_seg
|
||||
from app.api import api_design
|
||||
from app.api import api_design_pre_processing
|
||||
from app.api import api_extraction_project_info
|
||||
from app.api import api_generate_image
|
||||
from app.api import api_image2sketch
|
||||
from app.api import api_mannequins_edit
|
||||
from app.api import api_pose_transform
|
||||
from app.api import api_precompute
|
||||
from app.api import api_prompt_generation
|
||||
from app.api import api_recommendation
|
||||
from app.api import api_super_resolution
|
||||
from app.api import api_test
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(api_test.router, tags=["test"], prefix="/test")
|
||||
router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
|
||||
router.include_router(api_generate_image.router, tags=["generate_image"], prefix="/api")
|
||||
router.include_router(api_attribute_retrieve.router, tags=["attribute_retrieve"], prefix="/api")
|
||||
router.include_router(api_design.router, tags=['design'], prefix="/api")
|
||||
router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
|
||||
router.include_router(api_prompt_generation.router, tags=['prompt_generation'], prefix="/api")
|
||||
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
|
||||
router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||
router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
|
||||
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
|
||||
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
|
||||
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
|
||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
|
||||
"""停用"""
|
||||
# from app.api import api_chat_robot
|
||||
# from app.api import api_query_image
|
||||
# from app.api import api_brighten
|
||||
# from app.api import api_extraction_project_info
|
||||
# from app.api import api_image2sketch
|
||||
# from app.api import api_super_resolution
|
||||
# router.include_router(api_image2sketch.router, tags=['api_image2sketch'], prefix="/api")
|
||||
# router.include_router(api_query_image.router, tags=['api_query_image'], prefix="/api")
|
||||
# router.include_router(api_chat_robot.router, tags=['chat_robot'], prefix="/api")
|
||||
# router.include_router(api_super_resolution.router, tags=["super_resolution"], prefix="/api")
|
||||
# router.include_router(api_brighten.router, tags=['api_brighten'], prefix="/api")
|
||||
# router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||
|
||||
@@ -27,7 +27,7 @@ def super_resolution(request_item: SuperResolutionModel, background_tasks: Backg
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||
logger.info(f"super_resolution request item is : @@@@@@:{json.dumps(request_item.dict(),indent=4)}")
|
||||
service = SuperResolution(request_item)
|
||||
background_tasks.add_task(service.sr_result)
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,8 +4,7 @@ import logging
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.core.config import SR_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, OSS, JAVA_STREAM_API_URL, GMV_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GEN_SINGLE_LOGO_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, \
|
||||
BATCH_PS_RABBITMQ_QUEUES, RABBITMQ_ENV
|
||||
from app.core.config import settings, SR_RABBITMQ_QUEUES, GMV_RABBITMQ_QUEUES, PS_RABBITMQ_QUEUES, SLOGAN_RABBITMQ_QUEUES, GI_RABBITMQ_QUEUES, GPI_RABBITMQ_QUEUES, GRI_RABBITMQ_QUEUES, BATCH_GPI_RABBITMQ_QUEUES, BATCH_GRI_RABBITMQ_QUEUES, BATCH_PS_RABBITMQ_QUEUES
|
||||
from app.schemas.response_template import ResponseModel
|
||||
|
||||
logger = logging.getLogger()
|
||||
@@ -15,9 +14,9 @@ router = APIRouter()
|
||||
@router.get("{id}")
|
||||
def test(id: int):
|
||||
data = {
|
||||
"RABBITMQ_ENV":RABBITMQ_ENV,
|
||||
"超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||
"多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||
"RABBITMQ_ENV": settings.SERVE_ENV,
|
||||
# "超分 SR_RABBITMQ_QUEUES": SR_RABBITMQ_QUEUES,
|
||||
# "多视角 GMV_RABBITMQ_QUEUES": GMV_RABBITMQ_QUEUES,
|
||||
"pose transform PS_RABBITMQ_QUEUES": PS_RABBITMQ_QUEUES,
|
||||
"logan SLOGAN_RABBITMQ_QUEUES": SLOGAN_RABBITMQ_QUEUES,
|
||||
"image and single logo GI_RABBITMQ_QUEUES": GI_RABBITMQ_QUEUES,
|
||||
@@ -29,10 +28,9 @@ def test(id: int):
|
||||
"batch relight BATCH_GRI_RABBITMQ_QUEUES": BATCH_GRI_RABBITMQ_QUEUES,
|
||||
"batch pose transform BATCH_PS_RABBITMQ_QUEUES": BATCH_PS_RABBITMQ_QUEUES,
|
||||
|
||||
"JAVA_STREAM_API_URL": JAVA_STREAM_API_URL,
|
||||
"local_oss_server": OSS
|
||||
"JAVA_STREAM_API_URL": settings.JAVA_STREAM_API_URL,
|
||||
}
|
||||
logger.info(json.dumps(data))
|
||||
logger.info(json.dumps(data, ensure_ascii=False, indent=4))
|
||||
if id == 1:
|
||||
raise HTTPException(status_code=404, detail="Item not found")
|
||||
|
||||
|
||||
235
app/core/config.backup.py
Normal file
235
app/core/config.backup.py
Normal file
@@ -0,0 +1,235 @@
|
||||
import os
|
||||
|
||||
import pika
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseSettings
|
||||
|
||||
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
||||
load_dotenv(os.path.join(BASE_DIR, '.env'))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = 'FASTAPI BASE'
|
||||
SECRET_KEY: str = ''
|
||||
API_PREFIX: str = ''
|
||||
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||
DATABASE_URL: str = ''
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||
SECURITY_ALGORITHM: str = 'HS256'
|
||||
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
|
||||
|
||||
OSS = "minio"
|
||||
DEBUG = False
|
||||
if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "../seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "service/recommend/"
|
||||
CHROMADB_PATH = "./chromadb/"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "/seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||
CHROMADB_PATH = "/chromadb/"
|
||||
|
||||
# RABBITMQ_ENV = "" # 生产环境
|
||||
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
|
||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||
|
||||
if RABBITMQ_ENV == "-dev":
|
||||
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
elif RABBITMQ_ENV == "-prod":
|
||||
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# minio 配置
|
||||
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
|
||||
# S3 配置
|
||||
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
|
||||
S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8"
|
||||
S3_REGION_NAME = "ap-east-1"
|
||||
|
||||
# redis 配置
|
||||
REDIS_HOST = "10.1.1.240"
|
||||
REDIS_PORT = "6379"
|
||||
REDIS_DB = "2"
|
||||
|
||||
# rabbitmq config
|
||||
RABBITMQ_PARAMS = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 5672,
|
||||
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
|
||||
"virtual_host": "/"
|
||||
}
|
||||
|
||||
# milvus 配置
|
||||
MILVUS_URL = "http://10.1.1.240:19530"
|
||||
MILVUS_TOKEN = "root:Milvus"
|
||||
MILVUS_ALIAS = "default"
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
|
||||
MILVUS_TABLE_SEG = "seg_cache"
|
||||
|
||||
# Mysql 配置
|
||||
DB_HOST = '18.167.251.121' # 数据库主机地址
|
||||
# DB_PORT = int( 33006)
|
||||
DB_PORT = 33008 # 数据库端口
|
||||
DB_USERNAME = 'aida_con_python' # 数据库用户名
|
||||
DB_PASSWORD = '123456' # 数据库密码
|
||||
DB_NAME = 'aida' # 数据库库名
|
||||
|
||||
# openai
|
||||
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
||||
OPENAI_STREAM = True
|
||||
BUFFER_THRESHOLD = 6 # must be even number
|
||||
SINGLE_TOKEN_THRESHOLD = 200
|
||||
TOKEN_THRESHOLD = 600
|
||||
OPENAI_TEMPERATURE = 0
|
||||
|
||||
# OPENAI_API_KEY = "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"
|
||||
OPENAI_API_KEY = "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"
|
||||
OPENAI_MODEL = "gpt-3.5-turbo-0613"
|
||||
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613", }
|
||||
|
||||
# SR service config
|
||||
SR_MODEL_NAME = "super_resolution"
|
||||
SR_TRITON_URL = "10.1.1.240:10031"
|
||||
SR_MINIO_BUCKET = "aida-users"
|
||||
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
|
||||
|
||||
# GenerateImage service config
|
||||
FAST_GI_MODEL_URL = '10.1.1.243:10011'
|
||||
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
||||
|
||||
GI_MODEL_URL = '10.1.1.240:10061'
|
||||
GI_MODEL_NAME = 'flux'
|
||||
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
|
||||
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
|
||||
|
||||
GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
|
||||
# SLOGAN service config
|
||||
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Single Logo service config
|
||||
GSL_MODEL_URL = '10.1.1.243:10041'
|
||||
GSL_MINIO_BUCKET = "aida-users"
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Product service config
|
||||
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
|
||||
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
|
||||
# GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Product service config 旧版product img 模型
|
||||
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
|
||||
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Single Logo service config
|
||||
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
|
||||
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
|
||||
# Pose Transform service config
|
||||
|
||||
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
|
||||
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
|
||||
PT_MODEL_URL = '10.1.1.243:10061'
|
||||
|
||||
# SEG service config
|
||||
SEGMENTATION = {
|
||||
"new_model_name": "seg_knet",
|
||||
"name": "seg_ocrnet_hr18",
|
||||
"input": "seg_input__0",
|
||||
"output": "seg_output__0",
|
||||
}
|
||||
# ollama config
|
||||
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
|
||||
|
||||
# design batch
|
||||
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
|
||||
|
||||
# DESIGN config
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
AIDA_CLOTHING = "aida-clothing"
|
||||
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
|
||||
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
|
||||
|
||||
# DESIGN 预处理
|
||||
IF_DEBUG_SHOW = False
|
||||
|
||||
# 优先级
|
||||
PRIORITY_DICT = {
|
||||
'earring_front': 99,
|
||||
'bag_front': 98,
|
||||
'hairstyle_front': 97,
|
||||
'outwear_front': 20,
|
||||
'tops_front': 19,
|
||||
'dress_front': 18,
|
||||
'blouse_front': 17,
|
||||
'skirt_front': 16,
|
||||
'trousers_front': 15,
|
||||
'bottoms_front': 14,
|
||||
'shoes_right': 1,
|
||||
'shoes_left': 1,
|
||||
'body': 0,
|
||||
'bottoms_back': -14,
|
||||
'trousers_back': -15,
|
||||
'skirt_back': -16,
|
||||
'blouse_back': -17,
|
||||
'dress_back': -18,
|
||||
'tops_back': -19,
|
||||
'outwear_back': -20,
|
||||
'hairstyle_back': -97,
|
||||
'bag_back': -98,
|
||||
'earring_back': -99,
|
||||
}
|
||||
|
||||
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": "QWa998345",
|
||||
"database": "aida",
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
|
||||
# --- ComfyUI 配置信息 ---
|
||||
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址
|
||||
@@ -1,188 +1,91 @@
|
||||
import os
|
||||
|
||||
import pika
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseSettings
|
||||
|
||||
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
||||
load_dotenv(os.path.join(BASE_DIR, '.env'))
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME: str = 'FASTAPI BASE'
|
||||
SECRET_KEY: str = ''
|
||||
API_PREFIX: str = ''
|
||||
BACKEND_CORS_ORIGINS: list[str] = ['*']
|
||||
DATABASE_URL: str = ''
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||
SECURITY_ALGORITHM: str = 'HS256'
|
||||
LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
"""
|
||||
应用配置类。Pydantic Settings 会自动从环境变量和 .env 文件中加载这些值。
|
||||
"""
|
||||
model_config = SettingsConfigDict(
|
||||
env_file='.env',
|
||||
env_file_encoding='utf-8',
|
||||
# extra='ignore' # 忽略环境变量中多余的键
|
||||
)
|
||||
# --- 服务端口配置信息 ---
|
||||
PORT: int = Field(default=8001, description="")
|
||||
# --- 服务环境 配置信息 ---
|
||||
SERVE_ENV: str = Field(default='', description="")
|
||||
# --- 开发状态 配置信息 ---
|
||||
DEBUG: bool = Field(default=False, description="")
|
||||
# --- 千问api 配置信息 ---
|
||||
QWEN_API_KEY: str = Field(default="", description="")
|
||||
|
||||
# --- ComfyUI 配置信息 ---
|
||||
COMFYUI_SERVER_ADDRESS: str = Field(default='', description="")
|
||||
|
||||
OSS = "minio"
|
||||
DEBUG = False
|
||||
if DEBUG:
|
||||
LOGS_PATH = "logs/"
|
||||
CATEGORY_PATH = "service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "../seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "../pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "service/recommend/"
|
||||
CHROMADB_PATH = "./chromadb/"
|
||||
else:
|
||||
LOGS_PATH = "app/logs/"
|
||||
CATEGORY_PATH = "app/service/attribute/config/descriptor/category/category_dis.csv"
|
||||
SEG_CACHE_PATH = "/seg_cache/"
|
||||
POSE_TRANSFORM_VIDEO_PATH = "/pose_transform_video/"
|
||||
RECOMMEND_PATH_PREFIX = "app/service/recommend/"
|
||||
CHROMADB_PATH = "/chromadb/"
|
||||
# --- minio 配置信息 ---
|
||||
MINIO_URL: str = Field(default='', description="")
|
||||
MINIO_ACCESS: str = Field(default='', description="")
|
||||
MINIO_SECRET: str = Field(default='', description="")
|
||||
MINIO_SECURE: bool = Field(default=True, description="")
|
||||
|
||||
# RABBITMQ_ENV = "" # 生产环境
|
||||
RABBITMQ_ENV = os.getenv("RABBITMQ_ENV", "-dev")
|
||||
# RABBITMQ_ENV = "-local" # 本地测试环境
|
||||
# --- redis 配置信息 ---
|
||||
REDIS_HOST: str = Field(default='', description="")
|
||||
REDIS_PORT: str = Field(default='', description="")
|
||||
REDIS_DB: int = Field(default=0, description="")
|
||||
|
||||
# --- mysql 配置信息 ---
|
||||
MYSQL_HOST: str = Field(default='', description="")
|
||||
MYSQL_PORT: int = Field(default='', description="")
|
||||
MYSQL_USER: str = Field(default='', description="")
|
||||
MYSQL_PASSWORD: str = Field(default='', description="")
|
||||
MYSQL_DB: str = Field(default='', description="")
|
||||
MYSQL_CHARSET: str = Field(default='utf8mb4', description="")
|
||||
|
||||
# --- rabbit-mq 配置信息 ---
|
||||
MQ_HOST: str = Field(default='', description="")
|
||||
MQ_PORT: str = Field(default='', description="")
|
||||
MQ_USERNAME: str = Field(default='', description="")
|
||||
MQ_PASSWORD: str = Field(default='', description="")
|
||||
MQ_VIRTUAL_HOST: str = Field(default='/', description="")
|
||||
MQ_ENV: str = Field(default='', description="")
|
||||
|
||||
# --- milvus 配置信息 ---
|
||||
MILVUS_URL: str = Field(default='', description="")
|
||||
MILVUS_TOKEN: str = Field(default='', description="")
|
||||
MILVUS_ALIAS: str = Field(default='', description="")
|
||||
|
||||
# --- ollama 配置信息 ---
|
||||
CHROMADB_PATH: str = Field(default='', description="")
|
||||
|
||||
# --- ollama 配置信息 ---
|
||||
OLLAMA_URL: str = Field(default='', description="")
|
||||
|
||||
# --- Design Callback Java 接口 ---
|
||||
JAVA_STREAM_API_URL: str = Field(default='', description="")
|
||||
|
||||
# --- 其他配置信息 以下均为Docker容器内配置---
|
||||
LOGS_PATH: str = Field(default="/logs/", description="")
|
||||
CATEGORY_PATH: str = Field(default="/app/service/attribute/config/descriptor/category/category_dis.csv", description="")
|
||||
SEG_CACHE_PATH: str = Field(default="/seg_cache/", description="")
|
||||
RECOMMEND_PATH_PREFIX: str = Field(default="/app/service/recommend/", description="")
|
||||
|
||||
if RABBITMQ_ENV == "-dev":
|
||||
JAVA_STREAM_API_URL = f"https://develop.api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
elif RABBITMQ_ENV == "-prod":
|
||||
JAVA_STREAM_API_URL = f"https://api.aida.com.hk/api/third/party/receiveDesignResults"
|
||||
|
||||
settings = Settings()
|
||||
|
||||
# minio 配置
|
||||
MINIO_URL = "www.minio-api.aida.com.hk"
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
|
||||
# S3 配置
|
||||
S3_ACCESS_KEY = "AKIAVD3OJIMF6UJFLSHZ"
|
||||
S3_AWS_SECRET_ACCESS_KEY = "LNIwFFB27/QedtZ+Q/viVUoX9F5x1DbuM8N0DkD8"
|
||||
S3_REGION_NAME = "ap-east-1"
|
||||
|
||||
# redis 配置
|
||||
REDIS_HOST = "10.1.1.240"
|
||||
REDIS_PORT = "6379"
|
||||
REDIS_DB = "2"
|
||||
|
||||
# rabbitmq config
|
||||
RABBITMQ_PARAMS = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 5672,
|
||||
"credentials": pika.credentials.PlainCredentials(username='rabbit', password='123456'),
|
||||
"virtual_host": "/"
|
||||
"""Design 服务"""
|
||||
# 推荐服装类别映射
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
|
||||
# milvus 配置
|
||||
MILVUS_URL = "http://10.1.1.240:19530"
|
||||
MILVUS_TOKEN = "root:Milvus"
|
||||
MILVUS_ALIAS = "default"
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
|
||||
MILVUS_TABLE_SEG = "seg_cache"
|
||||
|
||||
# Mysql 配置
|
||||
DB_HOST = '18.167.251.121' # 数据库主机地址
|
||||
# DB_PORT = int( 33006)
|
||||
DB_PORT = 33008 # 数据库端口
|
||||
DB_USERNAME = 'aida_con_python' # 数据库用户名
|
||||
DB_PASSWORD = '123456' # 数据库密码
|
||||
DB_NAME = 'aida' # 数据库库名
|
||||
|
||||
# openai
|
||||
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
||||
OPENAI_STREAM = True
|
||||
BUFFER_THRESHOLD = 6 # must be even number
|
||||
SINGLE_TOKEN_THRESHOLD = 200
|
||||
TOKEN_THRESHOLD = 600
|
||||
OPENAI_TEMPERATURE = 0
|
||||
|
||||
# OPENAI_API_KEY = "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"
|
||||
OPENAI_API_KEY = "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"
|
||||
OPENAI_MODEL = "gpt-3.5-turbo-0613"
|
||||
OPENAI_MODEL_LIST = {"gpt-3.5-turbo-0613",
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-4-0314",
|
||||
"gpt-4-32k-0314",
|
||||
"gpt-4-0613",
|
||||
"gpt-4-32k-0613", }
|
||||
|
||||
# SR service config
|
||||
SR_MODEL_NAME = "super_resolution"
|
||||
SR_TRITON_URL = "10.1.1.240:10031"
|
||||
SR_MINIO_BUCKET = "aida-users"
|
||||
SR_RABBITMQ_QUEUES = f"SuperResolution{RABBITMQ_ENV}"
|
||||
|
||||
# GenerateImage service config
|
||||
FAST_GI_MODEL_URL = '10.1.1.243:10011'
|
||||
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
||||
|
||||
GI_MODEL_URL = '10.1.1.240:10061'
|
||||
GI_MODEL_NAME = 'flux'
|
||||
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
|
||||
GMV_RABBITMQ_QUEUES = f"GenerateMultiView{RABBITMQ_ENV}"
|
||||
|
||||
GI_MINIO_BUCKET = "aida-users"
|
||||
GI_RABBITMQ_QUEUES = f"GenerateImage{RABBITMQ_ENV}"
|
||||
GI_SYS_IMAGE_URL = "aida-sys-image/generate_image/white_image.jpg"
|
||||
|
||||
# SLOGAN service config
|
||||
SLOGAN_RABBITMQ_QUEUES = f"Slogan{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Single Logo service config
|
||||
GSL_MODEL_URL = '10.1.1.243:10041'
|
||||
GSL_MINIO_BUCKET = "aida-users"
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
GEN_SINGLE_LOGO_RABBITMQ_QUEUES = f"GenSingleLogo{RABBITMQ_ENV}"
|
||||
|
||||
# Generate Product service config
|
||||
# GPI_RABBITMQ_QUEUES = os.getenv("GEN_PRODUCT_IMAGE_RABBITMQ_QUEUES", f"ToProductImage{RABBITMQ_ENV}")
|
||||
# GPI_MODEL_NAME_OVERALL = 'sdxl_ensemble_all'
|
||||
# GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Product service config 旧版product img 模型
|
||||
GPI_RABBITMQ_QUEUES = f"ToProductImage{RABBITMQ_ENV}"
|
||||
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage{RABBITMQ_ENV}"
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
|
||||
# Generate Single Logo service config
|
||||
GRI_RABBITMQ_QUEUES = f"Relight{RABBITMQ_ENV}"
|
||||
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight{RABBITMQ_ENV}"
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
|
||||
# Pose Transform service config
|
||||
|
||||
PS_RABBITMQ_QUEUES = f"PoseTransform{RABBITMQ_ENV}"
|
||||
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform{RABBITMQ_ENV}"
|
||||
PT_MODEL_URL = '10.1.1.243:10061'
|
||||
|
||||
# SEG service config
|
||||
SEGMENTATION = {
|
||||
"new_model_name": "seg_knet",
|
||||
"name": "seg_ocrnet_hr18",
|
||||
"input": "seg_input__0",
|
||||
"output": "seg_output__0",
|
||||
}
|
||||
# ollama config
|
||||
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
|
||||
|
||||
# design batch
|
||||
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch{RABBITMQ_ENV}"
|
||||
|
||||
# DESIGN config
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
AIDA_CLOTHING = "aida-clothing"
|
||||
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right',
|
||||
'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
|
||||
|
||||
# DESIGN 预处理
|
||||
IF_DEBUG_SHOW = False
|
||||
|
||||
# 优先级
|
||||
# Design前后排优先级
|
||||
PRIORITY_DICT = {
|
||||
'earring_front': 99,
|
||||
'bag_front': 98,
|
||||
@@ -208,25 +111,71 @@ PRIORITY_DICT = {
|
||||
'bag_back': -98,
|
||||
'earring_back': -99,
|
||||
}
|
||||
# Design 关键点字段
|
||||
KEYPOINT_RESULT_TABLE_FIELD_SET = ('neckline_left', 'neckline_right', 'shoulder_left', 'shoulder_right', 'armpit_left', 'armpit_right', 'cuff_left_in', 'cuff_left_out', 'cuff_right_in', 'cuff_right_out', 'waistband_left', 'waistband_right')
|
||||
# milvus配置信息
|
||||
MILVUS_TABLE_KEYPOINT = "keypoint_cache_2"
|
||||
|
||||
QWEN_API_KEY = "sk-f31c29e61ac2498ba5e307aaa6dc10e0"
|
||||
# ollama 地址
|
||||
OLLAMA_URL = "http://10.1.1.240:11434/api/embeddings"
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": "18.167.251.121",
|
||||
"port": 3306,
|
||||
"user": "root",
|
||||
"password": "QWa998345",
|
||||
"database": "aida",
|
||||
"charset": "utf8mb4"
|
||||
}
|
||||
"""Triton Server Config"""
|
||||
# Design
|
||||
DESIGN_MODEL_URL = '10.1.1.240:10000'
|
||||
DESIGN_MODEL_NAME = 'seg_knet'
|
||||
# Generate Image
|
||||
GI_MODEL_URL = '10.1.1.240:10061'
|
||||
GI_MODEL_NAME = 'flux'
|
||||
# Generate Single Logo
|
||||
GSL_MODEL_URL = '10.1.1.243:10041'
|
||||
GSL_MODEL_NAME = 'stable_diffusion_xl_transparent'
|
||||
# Generate Product (整套和单品)
|
||||
GPI_MODEL_URL = '10.1.1.243:10051'
|
||||
GPI_MODEL_NAME_OVERALL = 'diffusion_ensemble_all'
|
||||
GPI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_cnet'
|
||||
|
||||
TABLE_CATEGORIES = {
|
||||
"female_dress": "female/dress",
|
||||
"female_outwear": "female/outwear",
|
||||
"female_trousers": "female/trousers",
|
||||
"female_skirt": "female/skirt",
|
||||
"female_blouse": "female/blouse",
|
||||
"male_tops": "male/tops",
|
||||
"male_bottoms": "male/bottoms",
|
||||
"male_outwear": "male/outwear"
|
||||
}
|
||||
# 以下停用中...*************
|
||||
# 多视角生成
|
||||
GMV_MODEL_URL = '10.1.1.243:10081'
|
||||
GMV_MODEL_NAME = 'multi_view'
|
||||
# 超分
|
||||
SR_MODEL_NAME = "super_resolution"
|
||||
SR_TRITON_URL = "10.1.1.240:10031"
|
||||
# 打光
|
||||
GRI_MODEL_URL = '10.1.1.240:10051'
|
||||
GRI_MODEL_NAME_OVERALL = 'diffusion_relight_ensemble'
|
||||
GRI_MODEL_NAME_SINGLE = 'stable_diffusion_1_5_relight'
|
||||
# agent 图片生成
|
||||
FAST_GI_MODEL_URL = '10.1.1.243:10011'
|
||||
FAST_GI_MODEL_NAME = 'stable_diffusion_xl'
|
||||
# 图转视频 triton版
|
||||
PT_MODEL_URL = '10.1.1.243:10061'
|
||||
|
||||
# *************
|
||||
|
||||
"""MQ 队列信息"""
|
||||
# 生成图片 moodboard printboard sketchboard
|
||||
GI_RABBITMQ_QUEUES = f"GenerateImage-{settings.SERVE_ENV}"
|
||||
# 生成slogan
|
||||
SLOGAN_RABBITMQ_QUEUES = f"Slogan-{settings.SERVE_ENV}"
|
||||
# 转产品图
|
||||
GPI_RABBITMQ_QUEUES = f"ToProductImage-{settings.SERVE_ENV}"
|
||||
# 产品图转视频
|
||||
PS_RABBITMQ_QUEUES = f"PoseTransform-{settings.SERVE_ENV}"
|
||||
|
||||
# 以下停用中...*************
|
||||
# 产品图打光
|
||||
GRI_RABBITMQ_QUEUES = f"Relight-{settings.SERVE_ENV}"
|
||||
# 超分
|
||||
SR_RABBITMQ_QUEUES = f"SuperResolution-{settings.SERVE_ENV}"
|
||||
# 生成多视图
|
||||
GMV_RABBITMQ_QUEUES = f"GenerateMultiView-{settings.SERVE_ENV}"
|
||||
# 批量转产品图
|
||||
BATCH_GPI_RABBITMQ_QUEUES = f"BatchToProductImage-{settings.SERVE_ENV}"
|
||||
# 批量打光
|
||||
BATCH_GRI_RABBITMQ_QUEUES = f"BatchRelight-{settings.SERVE_ENV}"
|
||||
# 批量图片转视频
|
||||
BATCH_PS_RABBITMQ_QUEUES = f"BatchPoseTransform-{settings.SERVE_ENV}"
|
||||
# 批量design
|
||||
BATCH_DESIGN_RABBITMQ_QUEUES = f"DesignBatch-{settings.SERVE_ENV}"
|
||||
# *************
|
||||
|
||||
10
app/core/mysql_config.py
Normal file
10
app/core/mysql_config.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from app.core.config import settings
|
||||
|
||||
DB_CONFIG = {
|
||||
"host": settings.MYSQL_HOST,
|
||||
"port": settings.MYSQL_PORT,
|
||||
"user": settings.MYSQL_USER,
|
||||
"password": settings.MYSQL_PASSWORD,
|
||||
"database": settings.MYSQL_DB,
|
||||
"charset": settings.MYSQL_CHARSET,
|
||||
}
|
||||
10
app/core/rabbit_mq_config.py
Normal file
10
app/core/rabbit_mq_config.py
Normal file
@@ -0,0 +1,10 @@
|
||||
# rabbitmq config
|
||||
import pika
|
||||
from app.core.config import settings
|
||||
|
||||
RABBITMQ_PARAMS = {
|
||||
"host": settings.MQ_HOST,
|
||||
"port": settings.MQ_PORT,
|
||||
"credentials": pika.credentials.PlainCredentials(username=settings.MQ_USERNAME, password=settings.MQ_PASSWORD),
|
||||
"virtual_host": settings.MQ_VIRTUAL_HOST,
|
||||
}
|
||||
@@ -79,12 +79,8 @@
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
],
|
||||
"process_id": "87",
|
||||
"tasks_id": ,
|
||||
"tasks_id": ""
|
||||
}
|
||||
|
||||
|
||||
//用 openai jsonl
|
||||
//
|
||||
35
app/main.py
35
app/main.py
@@ -1,31 +1,40 @@
|
||||
# 1. 这里的顺序至关重要!必须在最顶端
|
||||
import sys
|
||||
|
||||
try:
|
||||
import asyncore
|
||||
except ImportError:
|
||||
import pyasyncore
|
||||
|
||||
sys.modules['asyncore'] = pyasyncore
|
||||
import logging.config
|
||||
|
||||
import uvicorn
|
||||
from apscheduler.schedulers.background import BackgroundScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException, Request
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.api.api_route import router
|
||||
from app.core.config import settings
|
||||
from app.core.record_api_count import count_api_calls
|
||||
from app.schemas.response_template import ResponseModel
|
||||
from app.service.recommend.service import load_resources
|
||||
from logging_env import LOGGER_CONFIG_DICT
|
||||
from dotenv import load_dotenv
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||
logging.getLogger("pika").setLevel(logging.WARNING)
|
||||
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(
|
||||
title=settings.PROJECT_NAME, docs_url="/docs", redoc_url='/re-docs',
|
||||
openapi_url=f"{settings.API_PREFIX}/openapi.json",
|
||||
docs_url="/docs",
|
||||
redoc_url='/re-docs',
|
||||
openapi_url=f"/openapi.json",
|
||||
description='''
|
||||
Base frame with FastAPI
|
||||
- Super Resolution API
|
||||
@@ -34,13 +43,13 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
application.middleware("http")(count_api_calls)
|
||||
application.include_router(router=router, prefix=settings.API_PREFIX)
|
||||
application.include_router(router=router)
|
||||
return application
|
||||
|
||||
|
||||
@@ -48,14 +57,12 @@ app = get_application()
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request: Request, exc: HTTPException):
|
||||
async def http_exception_handler(exc: HTTPException):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content=ResponseModel(code=exc.status_code, msg=exc.detail, data=exc.detail).dict()
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
uvicorn.run(app, host="0.0.0.0", port=settings.PORT)
|
||||
|
||||
23
app/schemas/comfyui_i2v.py
Normal file
23
app/schemas/comfyui_i2v.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ComfyuiPose2VModel(BaseModel):
|
||||
# 骨架生成视频
|
||||
image_url: str
|
||||
tasks_id: str
|
||||
pose_id: str
|
||||
|
||||
|
||||
class ComfyuiI2VModel(BaseModel):
|
||||
# 图生视频
|
||||
image_url: str
|
||||
prompt: str
|
||||
tasks_id: str
|
||||
|
||||
|
||||
class ComfyuiFLF2VModel(BaseModel):
|
||||
# 首尾帧生视频
|
||||
start_image_url: str
|
||||
end_image_url: str
|
||||
prompt: str
|
||||
tasks_id: str
|
||||
@@ -1,4 +1,15 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SAMRequestModel(BaseModel):
|
||||
user_id: int = Field(..., description="用户id, 必填字段")
|
||||
image_path: str = Field(..., description="图片路径,必填字段")
|
||||
type: str = Field(..., description="推理类型,必填字段")
|
||||
points: Optional[List[List[float]]] = None
|
||||
labels: Optional[List[int]] = None
|
||||
box: Optional[List[int]] = None
|
||||
|
||||
|
||||
class DesignModel(BaseModel):
|
||||
@@ -10,6 +21,7 @@ class DesignStreamModel(BaseModel):
|
||||
objects: list[dict]
|
||||
process_id: str
|
||||
requestId: str
|
||||
callback_url: str
|
||||
|
||||
|
||||
class DesignProgressModel(BaseModel):
|
||||
|
||||
@@ -1,22 +1,24 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
import logging
|
||||
from pprint import pprint
|
||||
import torch
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from minio import Minio
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
from app.core.config import *
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import settings, DESIGN_MODEL_URL
|
||||
from app.schemas.attribute_retrieve import AttributeRecognitionModel
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class AttributeRecognition:
|
||||
def __init__(self, const, request_data):
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.request_data = []
|
||||
for i, sketch in enumerate(request_data):
|
||||
self.request_data.append(
|
||||
@@ -96,11 +98,12 @@ class AttributeRecognition:
|
||||
res = {**dict1, **dict2}
|
||||
return res
|
||||
|
||||
def get_image(self, url):
|
||||
@staticmethod
|
||||
def get_image(url):
|
||||
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) #
|
||||
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
|
||||
@@ -7,24 +7,25 @@
|
||||
@Date :2023/9/16 18:31:08
|
||||
@detail :
|
||||
"""
|
||||
from minio import Minio
|
||||
from skimage import transform
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from minio import Minio
|
||||
import tritonclient.http as httpclient
|
||||
import torch
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings, DESIGN_MODEL_URL
|
||||
from app.schemas.attribute_retrieve import CategoryRecognitionModel
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class CategoryRecognition:
|
||||
def __init__(self, request_data):
|
||||
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
# self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
|
||||
self.request_data = []
|
||||
self.triton_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
for sketch in request_data:
|
||||
@@ -46,13 +47,14 @@ class CategoryRecognition:
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img
|
||||
|
||||
def get_image(self, url):
|
||||
@staticmethod
|
||||
def get_image(url):
|
||||
# Get data of an object.
|
||||
# Read data from response.
|
||||
# response = self.minio_client.get_object(url.split("/", 1)[0], url.split("/", 1)[1])
|
||||
# img = np.frombuffer(response.data, np.uint8) # 转成8位无符号整型
|
||||
# img = cv2.imdecode(img, cv2.IMREAD_COLOR) # 解码
|
||||
img = oss_get_image(bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = oss_get_image(oss_client=minio_client, bucket=url.split("/", 1)[0], object_name=url.split("/", 1)[1], data_type="cv2")
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
return img
|
||||
|
||||
@@ -68,7 +70,7 @@ class CategoryRecognition:
|
||||
|
||||
colattr = list(self.attr_type['labelName'])
|
||||
|
||||
task = self.attr_type['taskName'][0]
|
||||
# self.attr_type['taskName'][0]
|
||||
|
||||
maxsc = np.max(scores[0][:5])
|
||||
indexs = np.argwhere(scores == maxsc)[:, 1]
|
||||
|
||||
@@ -9,15 +9,16 @@ import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, DESIGN_MODEL_URL, CATEGORY_PATH
|
||||
from app.core.config import DESIGN_MODEL_URL
|
||||
from app.core.config import settings
|
||||
from app.schemas.brand_dna import BrandDnaModel
|
||||
from app.service.attribute.config import local_debug_const, const
|
||||
from app.service.attribute.config import const
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image, oss_get_image
|
||||
|
||||
logger = logging.getLogger()
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class BrandDna:
|
||||
@@ -25,7 +26,7 @@ class BrandDna:
|
||||
self.sketch_bucket = "test"
|
||||
self.image_url = request_item.image_url
|
||||
self.is_brand_dna = request_item.is_brand_dna
|
||||
self.attr_type = pd.read_csv(CATEGORY_PATH)
|
||||
self.attr_type = pd.read_csv(settings.CATEGORY_PATH)
|
||||
# self.attr_type = pd.read_csv(r"E:\workspace\trinity_client_aida\app\service\attribute\config\descriptor\category\category_dis.csv")
|
||||
self.att_client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
self.seg_client = httpclient.InferenceServerClient(url='10.1.1.243:30000')
|
||||
|
||||
@@ -3,23 +3,25 @@ import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
# from langchain_openai import ChatOpenAI
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import GI_MODEL_URL, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, GI_MODEL_NAME
|
||||
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
|
||||
from app.schemas.brand_dna import GenerateBrandModel
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class GenerateBrandInfo:
|
||||
def __init__(self, request_data):
|
||||
# minio client init
|
||||
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
self.generate_logo_prompt = None
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
# user info init
|
||||
self.user_id = request_data.user_id
|
||||
@@ -55,7 +57,7 @@ class GenerateBrandInfo:
|
||||
return self.result_data
|
||||
|
||||
def llm_generate_brand_info(self):
|
||||
output = self.model(self._input.to_messages())
|
||||
output = self.model.invoke(self._input.to_messages())
|
||||
brand_data = self.output_parser.parse(output.content)
|
||||
self.result_data = brand_data
|
||||
self.generate_logo_prompt = brand_data['brand_logo_prompt']
|
||||
@@ -87,8 +89,8 @@ class GenerateBrandInfo:
|
||||
def upload_logo_image(self, image, object_name):
|
||||
try:
|
||||
_, img_byte_array = cv2.imencode('.jpg', image)
|
||||
object_name = f'{self.user_id}/{self.category}/{object_name}'
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
||||
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
|
||||
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
from dotenv import load_dotenv
|
||||
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# 加载.env文件的环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 创建一个大语言模型,model指定了大语言模型的种类
|
||||
model = ChatOpenAI(model="qwen2.5-14b-instruct")
|
||||
|
||||
# 想要接收的响应模式
|
||||
response_schemas = [
|
||||
ResponseSchema(name="brand_name", description="Brand name."),
|
||||
ResponseSchema(name="brand_slogan", description="Brand slogan."),
|
||||
ResponseSchema(name="brand_logo_prompt", description="prompt required for brand logo generation.")
|
||||
]
|
||||
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
|
||||
format_instructions = output_parser.get_format_instructions()
|
||||
prompt = PromptTemplate(
|
||||
template="你是一个时装品牌的设计师。根据用户输入提取出brand name,brand slogan,brand logo 描述。如果没有以上内容,需要你根据用户输入随意发挥。随后根据brand logo 描述生成一个prompt,这个prompt用于生成模型.\n{format_instructions}\n{question}",
|
||||
input_variables=["question"],
|
||||
partial_variables={"format_instructions": format_instructions}
|
||||
)
|
||||
_input = prompt.format_prompt(question="brand name: cat home")
|
||||
|
||||
output = model(_input.to_messages())
|
||||
brand_data = output_parser.parse(output.content)
|
||||
|
||||
|
||||
def generate_logo(bucket_name, object_name, prompt):
|
||||
pass
|
||||
@@ -3,27 +3,20 @@ import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.callbacks.manager import Callbacks, CallbackManager
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema import RUN_KEY, RunInfo
|
||||
from langchain_classic.agents import AgentExecutor
|
||||
from langchain_classic.schema import RUN_KEY
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import Callbacks, CallbackManager
|
||||
from langchain_core.load import dumpd
|
||||
from langchain_core.outputs import RunInfo
|
||||
|
||||
|
||||
class CustomAgentExecutor(AgentExecutor):
|
||||
def __call__(
|
||||
self,
|
||||
inputs: Union[Dict[str, Any], Any],
|
||||
return_only_outputs: bool = False,
|
||||
callbacks: Callbacks = None,
|
||||
session_key: str = "",
|
||||
*,
|
||||
tags: Optional[List[str]] = None,
|
||||
include_run_info: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
def __call__(self, inputs: Union[Dict[str, Any], Any], return_only_outputs: bool = False, callbacks: Callbacks = None, session_key: str = "", *, tags: Optional[List[str]] = None, include_run_info: bool = False, **kwargs) -> Dict[str, Any]:
|
||||
"""Run the logic of this chain and add to output if desired.
|
||||
|
||||
Args:
|
||||
**kwargs:
|
||||
inputs: Dictionary of inputs, or single input if chain expects
|
||||
only one param.
|
||||
return_only_outputs: boolean for whether to return only outputs in the
|
||||
@@ -72,7 +65,7 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
"""Validate and prep outputs."""
|
||||
self._validate_outputs(outputs)
|
||||
if self.memory is not None and outputs['need_record']:
|
||||
self.memory.save_context(inputs, outputs, session_key)
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
return outputs
|
||||
else:
|
||||
@@ -95,7 +88,7 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
)
|
||||
inputs = {list(_input_keys)[0]: inputs}
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs, session_key)
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
return inputs
|
||||
@@ -119,7 +112,8 @@ class CustomAgentExecutor(AgentExecutor):
|
||||
{return_value_key: observation},
|
||||
"",
|
||||
)
|
||||
except:
|
||||
except Exception as e:
|
||||
print(e)
|
||||
pass
|
||||
|
||||
# Invalid tools won't be in the map, so we return False.
|
||||
|
||||
@@ -1,26 +1,15 @@
|
||||
import json
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from json import JSONDecodeError
|
||||
from typing import List, Tuple, Any, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.agents import (
|
||||
OpenAIFunctionsAgent,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
BaseMessage,
|
||||
OutputParserException
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage
|
||||
)
|
||||
from langchain.tools import BaseTool, StructuredTool
|
||||
# from langchain.tools.convert_to_openai import FunctionDescription
|
||||
from langchain.utils.openai_functions import FunctionDescription
|
||||
from langchain_classic.agents import OpenAIFunctionsAgent
|
||||
from langchain_community.utils.ernie_functions import FunctionDescription
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
from langchain_core.callbacks import Callbacks
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.messages import BaseMessage, AIMessage, FunctionMessage
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -76,7 +65,6 @@ def _create_function_message(
|
||||
content = observation
|
||||
return FunctionMessage(
|
||||
name=agent_action.tool,
|
||||
content=content,
|
||||
)
|
||||
|
||||
|
||||
@@ -177,6 +165,7 @@ class ConversationalFunctionsAgent(OpenAIFunctionsAgent):
|
||||
into it.
|
||||
|
||||
Args:
|
||||
callbacks:
|
||||
intermediate_steps: Steps the LLM has taken to date, along with observations
|
||||
**kwargs: User inputs.
|
||||
**kwargs: Including user's input string
|
||||
|
||||
@@ -2,18 +2,16 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain_community.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, \
|
||||
get_openai_token_cost_for_model
|
||||
|
||||
|
||||
# from langchain.callbacks.openai_info import standardize_model_name, MODEL_COST_PER_1K_TOKENS, get_openai_token_cost_for_model
|
||||
from langchain_core.outputs import LLMResult
|
||||
|
||||
|
||||
class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
need_record: bool = True
|
||||
response_type: str = "string"
|
||||
"""Callback Handler that tracks OpenAI info and write to redis after agent finish"""
|
||||
|
||||
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
"""Collect token usage."""
|
||||
if response.llm_output is None:
|
||||
@@ -22,7 +20,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
if "token_usage" not in response.llm_output:
|
||||
return None
|
||||
if "function_call" in response.generations[0][0].message.additional_kwargs:
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema","tutorial_tool"]:
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] in ["sql_db_query", "sql_db_schema", "tutorial_tool"]:
|
||||
self.need_record = False
|
||||
if response.generations[0][0].message.additional_kwargs["function_call"]["name"] == "sql_db_query":
|
||||
self.response_type = "image"
|
||||
@@ -39,6 +37,7 @@ class OpenAITokenRecordCallbackHandler(OpenAICallbackHandler):
|
||||
self.total_tokens += token_usage.get("total_tokens", 0)
|
||||
self.prompt_tokens += prompt_tokens
|
||||
self.completion_tokens += completion_tokens
|
||||
return None
|
||||
|
||||
def on_chain_end(self, outputs: Dict, **kwargs: Any) -> None:
|
||||
"""Write token usage to redis."""
|
||||
|
||||
@@ -44,12 +44,17 @@ class CustomDatabase(SQLDatabase):
|
||||
final_str = "\n\n".join(tables)
|
||||
return final_str
|
||||
|
||||
def run(self, command: str, fetch: str = "all") -> str:
|
||||
def run(self, command: str, fetch: str = "all", **kwargs) -> str:
|
||||
"""Execute a SQL command and return a string representing the results.
|
||||
|
||||
If the statement returns rows, a string of the results is returned.
|
||||
If the statement returns no rows, an empty string is returned.
|
||||
|
||||
Args:
|
||||
command:
|
||||
fetch:
|
||||
**kwargs:
|
||||
|
||||
"""
|
||||
with self._engine.begin() as connection:
|
||||
if self._schema is not None:
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from langchain.agents import Tool
|
||||
from langchain.callbacks import FileCallbackHandler
|
||||
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
||||
from langchain.schema import SystemMessage, AIMessage
|
||||
from langchain.utilities import SerpAPIWrapper
|
||||
from langchain_community.utilities import SerpAPIWrapper
|
||||
from langchain_core.callbacks import FileCallbackHandler
|
||||
from langchain_core.messages import SystemMessage, AIMessage
|
||||
from langchain_core.prompts import MessagesPlaceholder, HumanMessagePromptTemplate, ChatPromptTemplate
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.chat_robot.script.agents import CustomAgentExecutor, ConversationalFunctionsAgent
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.memory import UserConversationBufferWindowMemory
|
||||
@@ -30,10 +30,10 @@ log_handler = FileCallbackHandler(logfile)
|
||||
# # callbacks=[OpenAICallbackHandler()]
|
||||
# )
|
||||
|
||||
llm = ChatTongyi(api_key=QWEN_API_KEY)
|
||||
llm = ChatTongyi(api_key=settings.QWEN_API_KEY)
|
||||
|
||||
search = SerpAPIWrapper()
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.DB_USERNAME}:{settings.DB_PASSWORD}@{settings.DB_HOST}:{settings.DB_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
@@ -43,11 +43,11 @@ tools = [
|
||||
description="Can be used to perform Internet searches",
|
||||
func=search.run
|
||||
),
|
||||
QuerySQLDataBaseTool(db=db, return_direct=False),
|
||||
QuerySQLDataBaseTool(db=db),
|
||||
InfoSQLDatabaseTool(db=db),
|
||||
ListSQLDatabaseTool(db=db),
|
||||
# QuerySQLCheckerTool(db=db, llm=OpenAI(temperature=0, openai_api_key=OPENAI_API_KEY)),
|
||||
QuerySQLCheckerTool(db=db, llm=ChatTongyi(temperature=0, api_key=QWEN_API_KEY)),
|
||||
QuerySQLCheckerTool(db=db, llm=ChatTongyi(api_key=settings.QWEN_API_KEY)),
|
||||
# Tool(
|
||||
# name="tutorial_tool",
|
||||
# description="Utilize this tool to retrieve specific statements related to user guidance tutorials."
|
||||
@@ -133,5 +133,5 @@ def chat(post_data):
|
||||
'completion_tokens': final_outputs['completion_tokens'],
|
||||
'response_type': final_outputs["response_type"]
|
||||
}
|
||||
logging.info(json.dumps(api_response))
|
||||
logging.info(json.dumps(api_response, indent=4))
|
||||
return api_response
|
||||
|
||||
@@ -3,13 +3,12 @@ from typing import Any, Dict, List, Tuple
|
||||
import json
|
||||
|
||||
import redis
|
||||
from langchain_classic.memory.chat_memory import BaseChatMemory
|
||||
from langchain_classic.memory.utils import get_prompt_input_key
|
||||
from langchain_core.messages import messages_from_dict, get_buffer_string, BaseMessage, HumanMessage, AIMessage, message_to_dict
|
||||
from redis import Redis
|
||||
from langchain.memory.chat_memory import BaseChatMemory
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string, HumanMessage, AIMessage
|
||||
from langchain.schema.messages import _message_to_dict, messages_from_dict
|
||||
from langchain.memory.utils import get_prompt_input_key
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
@@ -24,8 +23,8 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
@classmethod
|
||||
def from_redis(
|
||||
cls,
|
||||
host: str = REDIS_HOST,
|
||||
port: int = REDIS_PORT,
|
||||
host: str = settings.REDIS_HOST,
|
||||
port: int = settings.REDIS_PORT,
|
||||
db: int = 3,
|
||||
**kwargs
|
||||
):
|
||||
@@ -79,7 +78,7 @@ class UserConversationBufferWindowMemory(BaseChatMemory):
|
||||
return inputs[prompt_input_key], outputs[output_key]
|
||||
|
||||
def add_message(self, key: str, message: BaseMessage) -> None:
|
||||
self.redis_client.lpush(key, json.dumps(_message_to_dict(message)))
|
||||
self.redis_client.lpush(key, json.dumps(message_to_dict(message)))
|
||||
|
||||
def save_context(self, inputs: Dict[str, Any], outputs: Dict[str, str], key: str = "") -> None:
|
||||
"""Save context from this conversation to buffer."""
|
||||
|
||||
@@ -5,10 +5,10 @@ from dashscope import Generation
|
||||
from retry import retry
|
||||
from urllib3.exceptions import NewConnectionError
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.service.chat_robot.script.callbacks.qwen_callback_handler import QWenCallbackHandler
|
||||
from app.service.chat_robot.script.database import CustomDatabase
|
||||
from app.service.chat_robot.script.prompt import FASHION_CHAT_BOT_PREFIX, TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
from app.service.chat_robot.script.prompt import TOOLS_FUNCTIONS_SUFFIX, TUTORIAL_TOOL_RETURN, \
|
||||
GET_LANGUAGE_PREFIX, FASHION_CHAT_BOT_PREFIX_TEMP
|
||||
from app.service.search_image_with_text.service import query
|
||||
|
||||
@@ -149,7 +149,7 @@ tools = [
|
||||
}
|
||||
]
|
||||
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{DB_USERNAME}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/attribute_retrieval_V3',
|
||||
db = CustomDatabase.from_uri(f'mysql+pymysql://{settings.MYSQL_USER}:{settings.MYSQL_PASSWORD}@{settings.MYSQL_HOST}:{settings.MYSQL_PORT}/attribute_retrieval_V3',
|
||||
include_tables=['female_top', 'female_skirt', 'female_pants', 'female_dress',
|
||||
'female_outwear', 'male_bottom', 'male_top', 'male_outwear'],
|
||||
engine_args={"pool_recycle": 7200})
|
||||
@@ -159,7 +159,7 @@ qwen = QWenCallbackHandler()
|
||||
def search_from_internet(message):
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=message,
|
||||
prompt='The output must be in English.Keep the final result under 200 words.'
|
||||
# tools=tools,
|
||||
@@ -190,7 +190,7 @@ def get_image_from_vector_db(gender, content):
|
||||
def get_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
@@ -203,7 +203,7 @@ def get_response(messages):
|
||||
def get_assistant_response(messages):
|
||||
response = Generation.call(
|
||||
model='qwen-max',
|
||||
api_key=QWEN_API_KEY,
|
||||
api_key=settings.QWEN_API_KEY,
|
||||
messages=messages,
|
||||
# seed=random.randint(1, 10000), # 设置随机数种子seed,如果没有设置,则随机数种子默认为1234
|
||||
result_format='message', # 将输出设置为message形式
|
||||
@@ -212,8 +212,10 @@ def get_assistant_response(messages):
|
||||
return response
|
||||
|
||||
|
||||
global tool_info
|
||||
|
||||
|
||||
def call_with_messages(message):
|
||||
global tool_info
|
||||
user_input = message
|
||||
print('\n')
|
||||
|
||||
@@ -241,7 +243,7 @@ def call_with_messages(message):
|
||||
response_type = "chat"
|
||||
|
||||
while flag and count <= 3:
|
||||
first_response = get_response(messages)
|
||||
first_response = get_response
|
||||
assistant_output = first_response.output.choices[0].message
|
||||
QWenCallbackHandler.on_llm_end(qwen, first_response.usage)
|
||||
print(f"\n大模型第 {count} 轮输出信息:{first_response}\n")
|
||||
@@ -260,7 +262,7 @@ def call_with_messages(message):
|
||||
]
|
||||
tool_info['content'] = search_from_internet(message)
|
||||
flag = False
|
||||
result_content = tool_info['content'].output.text
|
||||
result_content = tool_info['content']
|
||||
# 如果模型选择的工具是get_database_table
|
||||
# elif assistant_output.tool_calls[0]['function']['name'] == 'get_database_table':
|
||||
# tool_info = {"name": "get_database_table", "role": "tool", 'content': get_database_table()}
|
||||
|
||||
@@ -2,21 +2,15 @@
|
||||
"""Tools for interacting with a SQL database."""
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.prompts import PromptTemplate
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
from langchain_community.tools.sql_database.tool import _QuerySQLCheckerToolInput
|
||||
# from langchain.sql_database import SQLDatabase
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain_community.tools.sql_database.prompt import QUERY_CHECKER
|
||||
|
||||
from langchain_community.tools.sql_database.tool import QuerySQLCheckerTool, _QuerySQLCheckerToolInput
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun, AsyncCallbackManagerForToolRun
|
||||
from langchain_core.language_models import BaseLanguageModel
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import BaseModel, Extra, Field, root_validator
|
||||
|
||||
|
||||
class BaseSQLDatabaseTool(BaseModel):
|
||||
@@ -62,7 +56,7 @@ class QuerySQLDataBaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"LIMIT 1'"
|
||||
"Example Input 2: 'SELECT img_name FROM top WHERE sleeve_length = 'Long' AND type = 'Blouse' "
|
||||
"order by rand() LIMIT 2'"
|
||||
)
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -97,7 +91,7 @@ class InfoSQLDatabaseTool(BaseSQLDatabaseTool, BaseTool):
|
||||
"female_skirt, female_outwear, male_bottom, male_top, and male_outwear."
|
||||
|
||||
"Example Input: 'female_outwear, male_top'"
|
||||
)
|
||||
)
|
||||
|
||||
def _run(
|
||||
self,
|
||||
@@ -183,11 +177,11 @@ class QuerySQLCheckerTool(BaseSQLDatabaseTool, BaseTool):
|
||||
args_schema: Type[BaseModel] = _QuerySQLCheckerToolInput
|
||||
|
||||
@root_validator(pre=True)
|
||||
def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def initialize_llm_chain(self, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
if "llm_chain" not in values:
|
||||
# from langchain.chains.llm import LLMChain
|
||||
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
llm = values.get("llm") # type: ignore[arg-type]
|
||||
prompt = PromptTemplate(
|
||||
template=QUERY_CHECKER, input_variables=["dialect", "query"]
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
|
||||
from app.service.chat_robot.script.prompt import TUTORIAL_TOOL_RETURN
|
||||
|
||||
|
||||
@@ -9,14 +9,14 @@ from PIL import Image
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import settings
|
||||
from app.schemas.clothing_seg import ClothingSegModel
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
class ClothingSeg:
|
||||
@@ -64,9 +64,9 @@ class ClothingSeg:
|
||||
if image_type == "sketch":
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
seg_mask = get_seg_result(1, image[:, :, :3])
|
||||
seg_mask = get_seg_result(image[:, :, :3])
|
||||
else:
|
||||
seg_mask = get_seg_result(1, image[:, :, :3])
|
||||
seg_mask = get_seg_result(image[:, :, :3])
|
||||
temp = seg_mask != 0.0
|
||||
mask = (255 * (temp + 0).astype(np.uint8))
|
||||
x_min, y_min, x_max, y_max = get_bounding_box(mask)
|
||||
|
||||
646
app/service/comfyui_I2V/flf2v_server.py
Normal file
646
app/service/comfyui_I2V/flf2v_server.py
Normal file
@@ -0,0 +1,646 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from minio import Minio, S3Error
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from app.core.config import PS_RABBITMQ_QUEUES
|
||||
from app.core.config import settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiFLF2VModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
# 首尾帧 + 文字 = 视频 工作流
|
||||
workflow_json = {
|
||||
"6": {
|
||||
"inputs": {
|
||||
"text": "A bearded man with red facial hair wearing a yellow straw hat and dark coat in Van Gogh's self-portrait style, slowly and continuously transforms into a space astronaut. The transformation flows like liquid paint - his beard fades away strand by strand, the yellow hat melts and reforms smoothly into a silver space helmet, dark coat gradually lightens and restructures into a white spacesuit. The background swirling brushstrokes slowly organize and clarify into realistic stars and space, with Earth appearing gradually in the distance. Every change happens in seamless waves, maintaining visual continuity throughout the metamorphosis.\n\nConsistent soft lighting throughout, medium close-up maintaining same framing, central composition stays fixed, gentle color temperature shift from warm to cool, gradual contrast increase, smooth style transition from painterly to photorealistic. Static camera with subtle slow zoom, emphasizing the flowing transformation process without abrupt changes.",
|
||||
"clip": [
|
||||
"38",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Positive Prompt)"
|
||||
}
|
||||
},
|
||||
"7": {
|
||||
"inputs": {
|
||||
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
"clip": [
|
||||
"38",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Negative Prompt)"
|
||||
}
|
||||
},
|
||||
"8": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"58",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"39",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE解码"
|
||||
}
|
||||
},
|
||||
"37": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "UNet加载器"
|
||||
}
|
||||
},
|
||||
"38": {
|
||||
"inputs": {
|
||||
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||
"type": "wan",
|
||||
"device": "default"
|
||||
},
|
||||
"class_type": "CLIPLoader",
|
||||
"_meta": {
|
||||
"title": "加载CLIP"
|
||||
}
|
||||
},
|
||||
"39": {
|
||||
"inputs": {
|
||||
"vae_name": "wan_2.1_vae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "加载VAE"
|
||||
}
|
||||
},
|
||||
"54": {
|
||||
"inputs": {
|
||||
"shift": 5,
|
||||
"model": [
|
||||
"91",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "采样算法(SD3)"
|
||||
}
|
||||
},
|
||||
"55": {
|
||||
"inputs": {
|
||||
"shift": 5,
|
||||
"model": [
|
||||
"92",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "采样算法(SD3)"
|
||||
}
|
||||
},
|
||||
"56": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "UNet加载器"
|
||||
}
|
||||
},
|
||||
"57": {
|
||||
"inputs": {
|
||||
"add_noise": "enable",
|
||||
"noise_seed": 984937593540091,
|
||||
"steps": 4,
|
||||
"cfg": 1,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 0,
|
||||
"end_at_step": 2,
|
||||
"return_with_leftover_noise": "enable",
|
||||
"model": [
|
||||
"54",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"67",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"67",
|
||||
1
|
||||
],
|
||||
"latent_image": [
|
||||
"67",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "K采样器(高级)"
|
||||
}
|
||||
},
|
||||
"58": {
|
||||
"inputs": {
|
||||
"add_noise": "disable",
|
||||
"noise_seed": 0,
|
||||
"steps": 4,
|
||||
"cfg": 1,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 2,
|
||||
"end_at_step": 10000,
|
||||
"return_with_leftover_noise": "disable",
|
||||
"model": [
|
||||
"55",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"67",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"67",
|
||||
1
|
||||
],
|
||||
"latent_image": [
|
||||
"57",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "K采样器(高级)"
|
||||
}
|
||||
},
|
||||
"60": {
|
||||
"inputs": {
|
||||
"fps": 16,
|
||||
"images": [
|
||||
"8",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CreateVideo",
|
||||
"_meta": {
|
||||
"title": "创建视频"
|
||||
}
|
||||
},
|
||||
"61": {
|
||||
"inputs": {
|
||||
"filename_prefix": "video/ComfyUI",
|
||||
"format": "auto",
|
||||
"codec": "auto",
|
||||
"video": [
|
||||
"60",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveVideo",
|
||||
"_meta": {
|
||||
"title": "保存视频"
|
||||
}
|
||||
},
|
||||
"62": {
|
||||
"inputs": {
|
||||
"image": "video_wan2_2_14B_flf2v_start_image.png"
|
||||
},
|
||||
"class_type": "LoadImage",
|
||||
"_meta": {
|
||||
"title": "加载end图像"
|
||||
}
|
||||
},
|
||||
"67": {
|
||||
"inputs": {
|
||||
"width": 640,
|
||||
"height": 640,
|
||||
"length": 81,
|
||||
"batch_size": 1,
|
||||
"positive": [
|
||||
"6",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"7",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"39",
|
||||
0
|
||||
],
|
||||
"start_image": [
|
||||
"68",
|
||||
0
|
||||
],
|
||||
"end_image": [
|
||||
"62",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "WanFirstLastFrameToVideo",
|
||||
"_meta": {
|
||||
"title": "WanFirstLastFrameToVideo"
|
||||
}
|
||||
},
|
||||
"68": {
|
||||
"inputs": {
|
||||
"image": "video_wan2_2_14B_flf2v_end_image.png"
|
||||
},
|
||||
"class_type": "LoadImage",
|
||||
"_meta": {
|
||||
"title": "加载start图像"
|
||||
}
|
||||
},
|
||||
"91": {
|
||||
"inputs": {
|
||||
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
|
||||
"strength_model": 1,
|
||||
"model": [
|
||||
"37",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "LoraLoaderModelOnly",
|
||||
"_meta": {
|
||||
"title": "LoRA加载器(仅模型)"
|
||||
}
|
||||
},
|
||||
"92": {
|
||||
"inputs": {
|
||||
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
|
||||
"strength_model": 1,
|
||||
"model": [
|
||||
"56",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "LoraLoaderModelOnly",
|
||||
"_meta": {
|
||||
"title": "LoRA加载器(仅模型)"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ComfyUIServerFLF2V:
|
||||
def __init__(self, request_data):
|
||||
self.pose_transform_data = None
|
||||
self.start_image_url = request_data.start_image_url
|
||||
self.end_image_url = request_data.end_image_url
|
||||
self.prompt = request_data.prompt
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
def get_result(self):
|
||||
workflow_json['6']['inputs']['text'] = self.prompt
|
||||
workflow_json['57']['inputs']["noise_seed"] = random.randint(0, 10 ** 18)
|
||||
|
||||
if self.start_image_url:
|
||||
# 下载图片 上传 comfyui server
|
||||
# TODO 设置视频宽度为480,高度自适应
|
||||
workflow_json['67']['inputs']["width"] = 480
|
||||
workflow_json['67']['inputs']["height"] = 848
|
||||
if self.start_image_url:
|
||||
start_in_memory_file, start_object_name = self.download_from_minio_in_memory(self.start_image_url)
|
||||
# 上传图片到comfyui server
|
||||
filename = self.upload_in_memory_file_to_comfyui(start_in_memory_file, start_object_name)
|
||||
workflow_json['68']['inputs']['image'] = filename
|
||||
else:
|
||||
assert "start_image_url is None"
|
||||
|
||||
if self.end_image_url:
|
||||
end_in_memory_file, end_object_name = self.download_from_minio_in_memory(self.end_image_url)
|
||||
# 上传图片到comfyui server
|
||||
filename = self.upload_in_memory_file_to_comfyui(end_in_memory_file, end_object_name)
|
||||
workflow_json['62']['inputs']['image'] = filename
|
||||
else:
|
||||
assert "end_image_url is None"
|
||||
|
||||
# 1. 提交任务
|
||||
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||
if not prompt_response:
|
||||
return None
|
||||
|
||||
prompt_id = prompt_response.get("prompt_id")
|
||||
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||
outputs = self.poll_history(prompt_id)
|
||||
file_list = {}
|
||||
for node_id, node_output in outputs.items():
|
||||
# 检查当前节点输出中是否包含 'images' 列表
|
||||
if 'images' in node_output and isinstance(node_output['images'], list):
|
||||
# 'images' 列表中的每个元素都是一个文件对象
|
||||
for file_info in node_output['images']:
|
||||
# 确保关键字段存在
|
||||
if all(key in file_info for key in ['filename', 'subfolder', 'type']):
|
||||
file_list = {
|
||||
'filename': file_info['filename'],
|
||||
'subfolder': file_info['subfolder'],
|
||||
'type': file_info['type']
|
||||
}
|
||||
logger.info(file_list)
|
||||
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||
return None
|
||||
|
||||
def download_from_minio_in_memory(self, image_url):
|
||||
bucket = image_url.split('/')[0]
|
||||
object_name = image_url[image_url.find('/') + 1:]
|
||||
|
||||
try:
|
||||
# get_object 返回一个 ResponseStream 对象
|
||||
response_stream = self.minio_client.get_object(
|
||||
bucket,
|
||||
object_name,
|
||||
)
|
||||
|
||||
# 读取整个流到内存 (BytesIO),避免写入本地文件
|
||||
image_bytes = response_stream.read()
|
||||
|
||||
response_stream.close()
|
||||
response_stream.release_conn()
|
||||
|
||||
in_memory_file = io.BytesIO(image_bytes)
|
||||
|
||||
# print(f"✅ 图片已下载到内存 ({len(image_bytes)} 字节)。")
|
||||
return in_memory_file, object_name.rsplit('/')[-1]
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
|
||||
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
|
||||
data = {
|
||||
"overwrite": "true",
|
||||
"type": "input"
|
||||
}
|
||||
|
||||
# 构建 multipart/form-data: (文件名, 内存文件对象, MIME 类型)
|
||||
# MIME 类型可以根据实际图片类型修改,这里使用常见的 png/jpeg
|
||||
mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
|
||||
|
||||
files = {
|
||||
'image': (filename, in_memory_file, mime_type)
|
||||
}
|
||||
|
||||
# print(f"⬆️ 正在上传图片 ({filename}) 到 ComfyUI...")
|
||||
try:
|
||||
comfyui_response = requests.post(upload_url, data=data, files=files)
|
||||
comfyui_response.raise_for_status()
|
||||
|
||||
result = comfyui_response.json()
|
||||
uploaded_name = result.get('name')
|
||||
|
||||
# print(f"🎉 ComfyUI 上传成功! 服务器文件名: {uploaded_name}")
|
||||
return uploaded_name
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"❌ ComfyUI 上传失败: {e}")
|
||||
logger.error(f" 响应内容: {comfyui_response.text}")
|
||||
return None
|
||||
|
||||
def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str, ):
|
||||
"""
|
||||
完整的自动化流程:获取 ComfyUI 视频 -> 转换 GIF 并提取帧 -> 上传所有结果到 MinIO。
|
||||
"""
|
||||
# 1. 从 ComfyUI 获取视频二进制数据
|
||||
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||
if not mp4_bytes:
|
||||
return None
|
||||
|
||||
# 2. 准备进行视频处理
|
||||
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||
# 为了避免写磁盘,我们将使用 BytesIO,但 MoviePy 内部依赖 FFmpeg,有时需要一个可寻址的本地文件路径。
|
||||
# 最可靠且避免写本地的方案是在内存中操作,然后将结果上传。
|
||||
|
||||
# ⚠️ 关键点:将 mp4_bytes 写入 BytesIO 以模拟文件,供 moviepy 读取
|
||||
|
||||
# 定义输出对象名
|
||||
|
||||
output_base_name = uuid.uuid4().hex
|
||||
MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
|
||||
GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
|
||||
FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
|
||||
|
||||
# --- 视频处理和帧提取 ---
|
||||
try:
|
||||
# 1. 创建一个临时的 MP4 文件路径
|
||||
# delete=False 确保文件在关闭后仍然存在,直到我们手动删除
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||
tmp_file.write(mp4_bytes) # 将内存数据写入磁盘
|
||||
temp_mp4_path = tmp_file.name # 记录文件路径
|
||||
|
||||
# print(f"临时文件已写入: {temp_mp4_path}")
|
||||
|
||||
# 2. 使用 moviepy 打开临时文件 (传入文件路径字符串)
|
||||
clip = VideoFileClip(temp_mp4_path)
|
||||
|
||||
# --- 在这里进行所有的视频处理和提取操作 ---
|
||||
|
||||
# 提取第一帧 (保持原尺寸)
|
||||
frame_array = clip.get_frame(t=0.0)
|
||||
image = Image.fromarray(frame_array)
|
||||
|
||||
frame_stream = io.BytesIO()
|
||||
image.save(frame_stream, 'JPEG')
|
||||
frame_bytes = frame_stream.getvalue()
|
||||
|
||||
logger.info("✅ 成功提取第一帧图片。")
|
||||
|
||||
# 视频转 GIF (使用另一个临时文件来保存 GIF)
|
||||
temp_gif_path = ""
|
||||
with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
|
||||
temp_gif_path = tmp_file.name
|
||||
|
||||
target_fps = int(round(clip.fps)) if clip.fps else 24
|
||||
clip.write_gif(temp_gif_path, fps=target_fps)
|
||||
|
||||
with open(temp_gif_path, 'rb') as f:
|
||||
gif_bytes = f.read()
|
||||
|
||||
logger.info("✅ 成功生成 GIF。")
|
||||
|
||||
# 返回结果 (例如: 上传到 MinIO)
|
||||
# return mp4_bytes, gif_bytes, frame_bytes
|
||||
|
||||
# -----------------------------------------------
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 视频处理或文件操作失败: {e}")
|
||||
# 在失败时,也尝试清理文件
|
||||
|
||||
finally:
|
||||
# 3. 清理临时文件 (非常重要!)
|
||||
if os.path.exists(temp_mp4_path):
|
||||
os.remove(temp_mp4_path)
|
||||
logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
|
||||
|
||||
if 'temp_gif_path' in locals() and os.path.exists(temp_gif_path):
|
||||
os.remove(temp_gif_path)
|
||||
logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
|
||||
|
||||
# 3. 上传所有结果到 MinIO
|
||||
|
||||
try:
|
||||
# 上传原始 MP4
|
||||
self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
|
||||
|
||||
# 上传生成的 GIF
|
||||
self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
|
||||
|
||||
# 上传第一帧图片
|
||||
self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
|
||||
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||
|
||||
# 推送消息
|
||||
if not settings.DEBUG:
|
||||
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||
|
||||
return "\n🎉 所有任务完成!"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
# --- 辅助函数:提交任务到队列 ---
|
||||
@staticmethod
|
||||
def queue_prompt(prompt, client_id):
|
||||
"""向 ComfyUI 提交工作流提示。"""
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
|
||||
# 提交任务到 /prompt 端点
|
||||
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
# print(f"-------------{response.text}")
|
||||
# print(f"------------{client_id}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning(f"提交任务失败,状态码: {response.status_code}")
|
||||
logger.warning(response.text)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def poll_history(prompt_id, interval_seconds=5):
|
||||
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
|
||||
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||
|
||||
while True:
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
# 任务未完成时,ComfyUI可能会返回404或空响应,我们只关注成功响应
|
||||
if response.status_code == 200:
|
||||
history_data = response.json()
|
||||
|
||||
# ComfyUI 返回的历史记录结构是 {prompt_id: {outputs: ...}}
|
||||
if prompt_id in history_data:
|
||||
logger.info("🎉 任务已完成!")
|
||||
return history_data[prompt_id]['outputs']
|
||||
|
||||
logger.info("⏳ 任务仍在执行或等待中...")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
# 处理可能的连接错误,但通常不会在内部轮询中发生
|
||||
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
|
||||
"""
|
||||
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||
|
||||
参数:
|
||||
- filename: 视频文件名 (例如: 'ComfyUI_00002_.mp4')
|
||||
- subfolder: 存储子文件夹 (例如: 'ComfyUI_2025-10-31')
|
||||
- file_type: 文件类型 (通常是 'output')
|
||||
|
||||
返回:
|
||||
- 视频文件的二进制内容 (bytes) 或 None。
|
||||
"""
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
|
||||
params = {
|
||||
"filename": filename,
|
||||
"subfolder": subfolder,
|
||||
"type": file_type
|
||||
}
|
||||
|
||||
logger.info(f"📡 正在从 ComfyUI 下载视频: {filename}")
|
||||
try:
|
||||
# 使用 requests.get 下载文件
|
||||
response = requests.get(url, params=params, stream=True)
|
||||
response.raise_for_status() # 检查 HTTP 错误
|
||||
|
||||
# 返回文件的完整二进制内容
|
||||
return response.content
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
|
||||
return None
|
||||
|
||||
def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
|
||||
"""从内存流上传数据到 MinIO。"""
|
||||
logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
|
||||
try:
|
||||
|
||||
data_stream = io.BytesIO(video_bytes)
|
||||
|
||||
result = self.minio_client.put_object(
|
||||
bucket_name='aida-users',
|
||||
object_name=object_name,
|
||||
data=data_stream,
|
||||
length=len(video_bytes),
|
||||
content_type=content_type
|
||||
)
|
||||
logger.info(f"✅ MinIO 上传成功: {result.object_name}")
|
||||
return True
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO 上传失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = ComfyuiFLF2VModel(
|
||||
tasks_id="202511051619-89111",
|
||||
start_image_url="test/start.png",
|
||||
end_image_url="test/end.png",
|
||||
prompt="Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||
)
|
||||
|
||||
server = ComfyUIServerFLF2V(request_data)
|
||||
print(server.get_result())
|
||||
622
app/service/comfyui_I2V/i2v_server.py
Normal file
622
app/service/comfyui_I2V/i2v_server.py
Normal file
@@ -0,0 +1,622 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
from minio import Minio, S3Error
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from app.core.config import PS_RABBITMQ_QUEUES, settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiI2VModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
# 图 + 文字 = 视频 工作流
|
||||
workflow_json = {
|
||||
"84": {
|
||||
"inputs": {
|
||||
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||
"type": "wan",
|
||||
"device": "default"
|
||||
},
|
||||
"class_type": "CLIPLoader",
|
||||
"_meta": {
|
||||
"title": "加载CLIP"
|
||||
}
|
||||
},
|
||||
"85": {
|
||||
"inputs": {
|
||||
"add_noise": "disable",
|
||||
"noise_seed": 0,
|
||||
"steps": 4,
|
||||
"cfg": 1,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 2,
|
||||
"end_at_step": 4,
|
||||
"return_with_leftover_noise": "disable",
|
||||
"model": [
|
||||
"103",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"98",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"98",
|
||||
1
|
||||
],
|
||||
"latent_image": [
|
||||
"86",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "K采样器(高级)"
|
||||
}
|
||||
},
|
||||
"86": {
|
||||
"inputs": {
|
||||
"add_noise": "enable",
|
||||
"noise_seed": 823962998672127,
|
||||
"steps": 4,
|
||||
"cfg": 1,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 0,
|
||||
"end_at_step": 2,
|
||||
"return_with_leftover_noise": "enable",
|
||||
"model": [
|
||||
"104",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"98",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"98",
|
||||
1
|
||||
],
|
||||
"latent_image": [
|
||||
"98",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "K采样器(高级)"
|
||||
}
|
||||
},
|
||||
"87": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"85",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"90",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE解码"
|
||||
}
|
||||
},
|
||||
"89": {
|
||||
"inputs": {
|
||||
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
"clip": [
|
||||
"84",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Negative Prompt)"
|
||||
}
|
||||
},
|
||||
"90": {
|
||||
"inputs": {
|
||||
"vae_name": "wan_2.1_vae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "加载VAE"
|
||||
}
|
||||
},
|
||||
"93": {
|
||||
"inputs": {
|
||||
"text": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots.",
|
||||
"clip": [
|
||||
"84",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Positive Prompt)"
|
||||
}
|
||||
},
|
||||
"94": {
|
||||
"inputs": {
|
||||
"fps": 16,
|
||||
"images": [
|
||||
"87",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CreateVideo",
|
||||
"_meta": {
|
||||
"title": "创建视频"
|
||||
}
|
||||
},
|
||||
"95": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "UNet加载器"
|
||||
}
|
||||
},
|
||||
"96": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "UNet加载器"
|
||||
}
|
||||
},
|
||||
"97": {
|
||||
"inputs": {
|
||||
"image": "start (1).png"
|
||||
},
|
||||
"class_type": "LoadImage",
|
||||
"_meta": {
|
||||
"title": "加载图像"
|
||||
}
|
||||
},
|
||||
"98": {
|
||||
"inputs": {
|
||||
"width": 480,
|
||||
"height": 848,
|
||||
"length": 81,
|
||||
"batch_size": 1,
|
||||
"positive": [
|
||||
"93",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"89",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"90",
|
||||
0
|
||||
],
|
||||
"start_image": [
|
||||
"97",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "WanImageToVideo",
|
||||
"_meta": {
|
||||
"title": "Wan图像到视频"
|
||||
}
|
||||
},
|
||||
"101": {
|
||||
"inputs": {
|
||||
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
|
||||
"strength_model": 1.0000000000000002,
|
||||
"model": [
|
||||
"95",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "LoraLoaderModelOnly",
|
||||
"_meta": {
|
||||
"title": "LoRA加载器(仅模型)"
|
||||
}
|
||||
},
|
||||
"102": {
|
||||
"inputs": {
|
||||
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
|
||||
"strength_model": 1.0000000000000002,
|
||||
"model": [
|
||||
"96",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "LoraLoaderModelOnly",
|
||||
"_meta": {
|
||||
"title": "LoRA加载器(仅模型)"
|
||||
}
|
||||
},
|
||||
"103": {
|
||||
"inputs": {
|
||||
"shift": 5.000000000000001,
|
||||
"model": [
|
||||
"102",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "采样算法(SD3)"
|
||||
}
|
||||
},
|
||||
"104": {
|
||||
"inputs": {
|
||||
"shift": 5.000000000000001,
|
||||
"model": [
|
||||
"101",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "采样算法(SD3)"
|
||||
}
|
||||
},
|
||||
"108": {
|
||||
"inputs": {
|
||||
"filename_prefix": "video/ComfyUI",
|
||||
"format": "auto",
|
||||
"codec": "auto",
|
||||
"video-preview": "",
|
||||
"video": [
|
||||
"94",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveVideo",
|
||||
"_meta": {
|
||||
"title": "保存视频"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class ComfyUIServerI2V:
|
||||
def __init__(self, request_data):
|
||||
self.pose_transform_data = None
|
||||
self.image_url = request_data.image_url
|
||||
self.prompt = request_data.prompt
|
||||
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
def get_result(self):
|
||||
workflow_json['93']['inputs']['text'] = self.prompt
|
||||
workflow_json['86']['inputs']["noise_seed"] = random.randint(0, 10 ** 18)
|
||||
|
||||
if self.image_url:
|
||||
# 下载图片 上传 comfyui server
|
||||
in_memory_file, object_name = self.download_from_minio_in_memory(self.image_url)
|
||||
# TODO 设置视频宽度为480,高度自适应
|
||||
workflow_json['98']['inputs']["width"] = 480
|
||||
workflow_json['98']['inputs']["height"] = 848
|
||||
if in_memory_file and object_name:
|
||||
# 上传图片到comfyui server
|
||||
filename = self.upload_in_memory_file_to_comfyui(in_memory_file, object_name)
|
||||
workflow_json['97']['inputs']['image'] = filename
|
||||
|
||||
# 1. 提交任务
|
||||
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||
if not prompt_response:
|
||||
return None
|
||||
prompt_id = prompt_response.get("prompt_id")
|
||||
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||
outputs = self.poll_history(prompt_id)
|
||||
file_list = {}
|
||||
for node_id, node_output in outputs.items():
|
||||
# 检查当前节点输出中是否包含 'images' 列表
|
||||
if 'images' in node_output and isinstance(node_output['images'], list):
|
||||
|
||||
# 'images' 列表中的每个元素都是一个文件对象
|
||||
for file_info in node_output['images']:
|
||||
# 确保关键字段存在
|
||||
if all(key in file_info for key in ['filename', 'subfolder', 'type']):
|
||||
file_list = {
|
||||
'filename': file_info['filename'],
|
||||
'subfolder': file_info['subfolder'],
|
||||
'type': file_info['type']
|
||||
}
|
||||
logger.info(file_list)
|
||||
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||
return None
|
||||
|
||||
def download_from_minio_in_memory(self, image_url):
|
||||
bucket = image_url.split('/')[0]
|
||||
object_name = image_url[image_url.find('/') + 1:]
|
||||
|
||||
try:
|
||||
# get_object 返回一个 ResponseStream 对象
|
||||
response_stream = self.minio_client.get_object(
|
||||
bucket,
|
||||
object_name,
|
||||
)
|
||||
|
||||
# 读取整个流到内存 (BytesIO),避免写入本地文件
|
||||
image_bytes = response_stream.read()
|
||||
|
||||
response_stream.close()
|
||||
response_stream.release_conn()
|
||||
|
||||
in_memory_file = io.BytesIO(image_bytes)
|
||||
|
||||
# print(f"✅ 图片已下载到内存 ({len(image_bytes)} 字节)。")
|
||||
return in_memory_file, object_name.rsplit('/')[-1]
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
|
||||
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
|
||||
data = {
|
||||
"overwrite": "true",
|
||||
"type": "input"
|
||||
}
|
||||
|
||||
# 构建 multipart/form-data: (文件名, 内存文件对象, MIME 类型)
|
||||
# MIME 类型可以根据实际图片类型修改,这里使用常见的 png/jpeg
|
||||
mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
|
||||
|
||||
files = {
|
||||
'image': (filename, in_memory_file, mime_type)
|
||||
}
|
||||
|
||||
# print(f"⬆️ 正在上传图片 ({filename}) 到 ComfyUI...")
|
||||
try:
|
||||
comfyui_response = requests.post(upload_url, data=data, files=files)
|
||||
comfyui_response.raise_for_status()
|
||||
|
||||
result = comfyui_response.json()
|
||||
uploaded_name = result.get('name')
|
||||
|
||||
# print(f"🎉 ComfyUI 上传成功! 服务器文件名: {uploaded_name}")
|
||||
return uploaded_name
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"❌ ComfyUI 上传失败: {e}")
|
||||
logger.error(f" 响应内容: {comfyui_response.text}")
|
||||
return None
|
||||
|
||||
def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str, ):
|
||||
"""
|
||||
完整的自动化流程:获取 ComfyUI 视频 -> 转换 GIF 并提取帧 -> 上传所有结果到 MinIO。
|
||||
"""
|
||||
# 1. 从 ComfyUI 获取视频二进制数据
|
||||
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||
if not mp4_bytes:
|
||||
return None
|
||||
|
||||
# 2. 准备进行视频处理
|
||||
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||
# 为了避免写磁盘,我们将使用 BytesIO,但 MoviePy 内部依赖 FFmpeg,有时需要一个可寻址的本地文件路径。
|
||||
# 最可靠且避免写本地的方案是在内存中操作,然后将结果上传。
|
||||
|
||||
# ⚠️ 关键点:将 mp4_bytes 写入 BytesIO 以模拟文件,供 moviepy 读取
|
||||
|
||||
# 定义输出对象名
|
||||
|
||||
output_base_name = uuid.uuid4().hex
|
||||
MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
|
||||
GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
|
||||
FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
|
||||
|
||||
# --- 视频处理和帧提取 ---
|
||||
try:
|
||||
# 1. 创建一个临时的 MP4 文件路径
|
||||
# delete=False 确保文件在关闭后仍然存在,直到我们手动删除
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||
tmp_file.write(mp4_bytes) # 将内存数据写入磁盘
|
||||
temp_mp4_path = tmp_file.name # 记录文件路径
|
||||
|
||||
# print(f"临时文件已写入: {temp_mp4_path}")
|
||||
|
||||
# 2. 使用 moviepy 打开临时文件 (传入文件路径字符串)
|
||||
clip = VideoFileClip(temp_mp4_path)
|
||||
|
||||
# --- 在这里进行所有的视频处理和提取操作 ---
|
||||
|
||||
# 提取第一帧 (保持原尺寸)
|
||||
frame_array = clip.get_frame(t=0.0)
|
||||
image = Image.fromarray(frame_array)
|
||||
|
||||
frame_stream = io.BytesIO()
|
||||
image.save(frame_stream, 'JPEG')
|
||||
frame_bytes = frame_stream.getvalue()
|
||||
|
||||
logger.info("✅ 成功提取第一帧图片。")
|
||||
|
||||
# 视频转 GIF (使用另一个临时文件来保存 GIF)
|
||||
temp_gif_path = ""
|
||||
with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
|
||||
temp_gif_path = tmp_file.name
|
||||
|
||||
target_fps = int(round(clip.fps)) if clip.fps else 24
|
||||
clip.write_gif(temp_gif_path, fps=target_fps)
|
||||
|
||||
with open(temp_gif_path, 'rb') as f:
|
||||
gif_bytes = f.read()
|
||||
|
||||
logger.info("✅ 成功生成 GIF。")
|
||||
|
||||
# 返回结果 (例如: 上传到 MinIO)
|
||||
# return mp4_bytes, gif_bytes, frame_bytes
|
||||
|
||||
# -----------------------------------------------
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 视频处理或文件操作失败: {e}")
|
||||
# 在失败时,也尝试清理文件
|
||||
|
||||
finally:
|
||||
# 3. 清理临时文件 (非常重要!)
|
||||
if os.path.exists(temp_mp4_path):
|
||||
os.remove(temp_mp4_path)
|
||||
logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
|
||||
|
||||
if 'temp_gif_path' in locals() and os.path.exists(temp_gif_path):
|
||||
os.remove(temp_gif_path)
|
||||
logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
|
||||
|
||||
# 3. 上传所有结果到 MinIO
|
||||
|
||||
try:
|
||||
# 上传原始 MP4
|
||||
self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
|
||||
|
||||
# 上传生成的 GIF
|
||||
self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
|
||||
|
||||
# 上传第一帧图片
|
||||
self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
|
||||
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||
|
||||
# 推送消息
|
||||
if not settings.DEBUG:
|
||||
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||
logger.info(
|
||||
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||
|
||||
return "\n🎉 所有任务完成!"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
# --- 辅助函数:提交任务到队列 ---
|
||||
@staticmethod
|
||||
def queue_prompt(prompt, client_id):
|
||||
"""向 ComfyUI 提交工作流提示。"""
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
|
||||
# 提交任务到 /prompt 端点
|
||||
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
# print(f"-------------{response.text}")
|
||||
# print(f"------------{client_id}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning(f"提交任务失败,状态码: {response.status_code}")
|
||||
logger.warning(response.text)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def poll_history(prompt_id, interval_seconds=5):
|
||||
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
|
||||
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||
|
||||
while True:
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
# 任务未完成时,ComfyUI可能会返回404或空响应,我们只关注成功响应
|
||||
if response.status_code == 200:
|
||||
history_data = response.json()
|
||||
|
||||
# ComfyUI 返回的历史记录结构是 {prompt_id: {outputs: ...}}
|
||||
if prompt_id in history_data:
|
||||
logger.info("🎉 任务已完成!")
|
||||
return history_data[prompt_id]['outputs']
|
||||
|
||||
logger.info("⏳ 任务仍在执行或等待中...")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
# 处理可能的连接错误,但通常不会在内部轮询中发生
|
||||
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
|
||||
"""
|
||||
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||
|
||||
参数:
|
||||
- filename: 视频文件名 (例如: 'ComfyUI_00002_.mp4')
|
||||
- subfolder: 存储子文件夹 (例如: 'ComfyUI_2025-10-31')
|
||||
- file_type: 文件类型 (通常是 'output')
|
||||
|
||||
返回:
|
||||
- 视频文件的二进制内容 (bytes) 或 None。
|
||||
"""
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
|
||||
params = {
|
||||
"filename": filename,
|
||||
"subfolder": subfolder,
|
||||
"type": file_type
|
||||
}
|
||||
|
||||
logger.info(f"📡 正在从 ComfyUI 下载视频: {filename}")
|
||||
try:
|
||||
# 使用 requests.get 下载文件
|
||||
response = requests.get(url, params=params, stream=True)
|
||||
response.raise_for_status() # 检查 HTTP 错误
|
||||
|
||||
# 返回文件的完整二进制内容
|
||||
return response.content
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
|
||||
return None
|
||||
|
||||
def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
|
||||
"""从内存流上传数据到 MinIO。"""
|
||||
logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
|
||||
try:
|
||||
|
||||
data_stream = io.BytesIO(video_bytes)
|
||||
|
||||
result = self.minio_client.put_object(
|
||||
bucket_name='aida-users',
|
||||
object_name=object_name,
|
||||
data=data_stream,
|
||||
length=len(video_bytes),
|
||||
content_type=content_type
|
||||
)
|
||||
logger.info(f"✅ MinIO 上传成功: {result.object_name}")
|
||||
return True
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO 上传失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = ComfyuiI2VModel(
|
||||
tasks_id="12222515151123-89111",
|
||||
image_url="aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
|
||||
prompt="Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||
)
|
||||
|
||||
server = ComfyUIServerI2V(request_data)
|
||||
print(server.get_result())
|
||||
745
app/service/comfyui_I2V/pose2v_server.py
Normal file
745
app/service/comfyui_I2V/pose2v_server.py
Normal file
@@ -0,0 +1,745 @@
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import redis
|
||||
import requests
|
||||
from PIL import Image
|
||||
from minio import Minio, S3Error
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from app.core.config import settings
|
||||
from app.schemas.comfyui_i2v import ComfyuiPose2VModel
|
||||
from app.service.generate_image.utils.mq import publish_status
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
# 图 + 骨架 = 视频 工作流
|
||||
workflow_json = {
|
||||
"162": {
|
||||
"inputs": {
|
||||
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||
"clip": [
|
||||
"167",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Negative Prompt)"
|
||||
}
|
||||
},
|
||||
"163": {
|
||||
"inputs": {
|
||||
"fps": 24,
|
||||
"images": [
|
||||
"192",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CreateVideo",
|
||||
"_meta": {
|
||||
"title": "创建视频"
|
||||
}
|
||||
},
|
||||
"164": {
|
||||
"inputs": {
|
||||
"samples": [
|
||||
"175",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"168",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "VAEDecode",
|
||||
"_meta": {
|
||||
"title": "VAE解码"
|
||||
}
|
||||
},
|
||||
"165": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_fun_control_high_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "UNet加载器"
|
||||
}
|
||||
},
|
||||
"166": {
|
||||
"inputs": {
|
||||
"unet_name": "wan2.2_fun_control_low_noise_14B_fp8_scaled.safetensors",
|
||||
"weight_dtype": "default"
|
||||
},
|
||||
"class_type": "UNETLoader",
|
||||
"_meta": {
|
||||
"title": "UNet加载器"
|
||||
}
|
||||
},
|
||||
"167": {
|
||||
"inputs": {
|
||||
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||
"type": "wan",
|
||||
"device": "default"
|
||||
},
|
||||
"class_type": "CLIPLoader",
|
||||
"_meta": {
|
||||
"title": "加载CLIP"
|
||||
}
|
||||
},
|
||||
"168": {
|
||||
"inputs": {
|
||||
"vae_name": "wan_2.1_vae.safetensors"
|
||||
},
|
||||
"class_type": "VAELoader",
|
||||
"_meta": {
|
||||
"title": "加载VAE"
|
||||
}
|
||||
},
|
||||
"169": {
|
||||
"inputs": {
|
||||
"add_noise": "enable",
|
||||
"noise_seed": 8860422635573,
|
||||
"steps": 4,
|
||||
"cfg": 1,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 0,
|
||||
"end_at_step": 2,
|
||||
"return_with_leftover_noise": "enable",
|
||||
"model": [
|
||||
"176",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"180",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"180",
|
||||
1
|
||||
],
|
||||
"latent_image": [
|
||||
"180",
|
||||
2
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "K采样器(高级)"
|
||||
}
|
||||
},
|
||||
"170": {
|
||||
"inputs": {
|
||||
"filename_prefix": "video/wan2.2_fun_control",
|
||||
"format": "auto",
|
||||
"codec": "auto",
|
||||
"video-preview": "",
|
||||
"video": [
|
||||
"163",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "SaveVideo",
|
||||
"_meta": {
|
||||
"title": "保存视频"
|
||||
}
|
||||
},
|
||||
"171": {
|
||||
"inputs": {
|
||||
"video": [
|
||||
"174",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "GetVideoComponents",
|
||||
"_meta": {
|
||||
"title": "获取视频组件"
|
||||
}
|
||||
},
|
||||
"174": {
|
||||
"inputs": {
|
||||
"file": "skeleton_3.mp4"
|
||||
},
|
||||
"class_type": "LoadVideo",
|
||||
"_meta": {
|
||||
"title": "加载视频"
|
||||
}
|
||||
},
|
||||
"175": {
|
||||
"inputs": {
|
||||
"add_noise": "disable",
|
||||
"noise_seed": 0,
|
||||
"steps": 4,
|
||||
"cfg": 1,
|
||||
"sampler_name": "euler",
|
||||
"scheduler": "simple",
|
||||
"start_at_step": 2,
|
||||
"end_at_step": 4,
|
||||
"return_with_leftover_noise": "disable",
|
||||
"model": [
|
||||
"177",
|
||||
0
|
||||
],
|
||||
"positive": [
|
||||
"180",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"180",
|
||||
1
|
||||
],
|
||||
"latent_image": [
|
||||
"169",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "KSamplerAdvanced",
|
||||
"_meta": {
|
||||
"title": "K采样器(高级)"
|
||||
}
|
||||
},
|
||||
"176": {
|
||||
"inputs": {
|
||||
"shift": 8.000000000000002,
|
||||
"model": [
|
||||
"181",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "采样算法(SD3)"
|
||||
}
|
||||
},
|
||||
"177": {
|
||||
"inputs": {
|
||||
"shift": 8.000000000000002,
|
||||
"model": [
|
||||
"182",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ModelSamplingSD3",
|
||||
"_meta": {
|
||||
"title": "采样算法(SD3)"
|
||||
}
|
||||
},
|
||||
"178": {
|
||||
"inputs": {
|
||||
"image": "296f5fd6-c5e4-4003-9798-f378a4f08411-0-89.png"
|
||||
},
|
||||
"class_type": "LoadImage",
|
||||
"_meta": {
|
||||
"title": "加载图像"
|
||||
}
|
||||
},
|
||||
"179": {
|
||||
"inputs": {
|
||||
"text": "The model is catwalking at the fashion show.",
|
||||
"clip": [
|
||||
"167",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "CLIPTextEncode",
|
||||
"_meta": {
|
||||
"title": "CLIP Text Encode (Positive Prompt)"
|
||||
}
|
||||
},
|
||||
"180": {
|
||||
"inputs": {
|
||||
"width": 480,
|
||||
"height": 720,
|
||||
"length": 121,
|
||||
"batch_size": 1,
|
||||
"positive": [
|
||||
"179",
|
||||
0
|
||||
],
|
||||
"negative": [
|
||||
"162",
|
||||
0
|
||||
],
|
||||
"vae": [
|
||||
"168",
|
||||
0
|
||||
],
|
||||
"ref_image": [
|
||||
"178",
|
||||
0
|
||||
],
|
||||
"control_video": [
|
||||
"171",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "Wan22FunControlToVideo",
|
||||
"_meta": {
|
||||
"title": "Wan22FunControlToVideo"
|
||||
}
|
||||
},
|
||||
"181": {
|
||||
"inputs": {
|
||||
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
|
||||
"strength_model": 1,
|
||||
"model": [
|
||||
"165",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "LoraLoaderModelOnly",
|
||||
"_meta": {
|
||||
"title": "LoRA加载器(仅模型)"
|
||||
}
|
||||
},
|
||||
"182": {
|
||||
"inputs": {
|
||||
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
|
||||
"strength_model": 1,
|
||||
"model": [
|
||||
"166",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "LoraLoaderModelOnly",
|
||||
"_meta": {
|
||||
"title": "LoRA加载器(仅模型)"
|
||||
}
|
||||
},
|
||||
"189": {
|
||||
"inputs": {
|
||||
"images": [
|
||||
"171",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "PreviewImage",
|
||||
"_meta": {
|
||||
"title": "预览图像"
|
||||
}
|
||||
},
|
||||
"190": {
|
||||
"inputs": {
|
||||
"images": [
|
||||
"192",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "PreviewImage",
|
||||
"_meta": {
|
||||
"title": "预览图像"
|
||||
}
|
||||
},
|
||||
"192": {
|
||||
"inputs": {
|
||||
"batch_index": 4,
|
||||
"length": 117,
|
||||
"image": [
|
||||
"164",
|
||||
0
|
||||
]
|
||||
},
|
||||
"class_type": "ImageFromBatch",
|
||||
"_meta": {
|
||||
"title": "从批次获取图像"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# 骨架映射
|
||||
video_map = {
|
||||
"1": "input_pose_video/1.mp4",
|
||||
"2": "input_pose_video/2.mp4",
|
||||
"3": "input_pose_video/3.mp4",
|
||||
"4": "input_pose_video/4.mp4",
|
||||
"5": "input_pose_video/5.mp4",
|
||||
"6": "input_pose_video/6.mp4"
|
||||
}
|
||||
|
||||
|
||||
class ComfyUIServerPose2V:
|
||||
def __init__(self, request_data):
|
||||
self.image_url = request_data.image_url
|
||||
self.pose_num = request_data.pose_id
|
||||
self.tasks_id = request_data.tasks_id
|
||||
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||
self.redis_client = redis.StrictRedis(host=settings.REDIS_HOST, port=settings.REDIS_PORT, db=settings.REDIS_DB, decode_responses=True)
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||
self.redis_client.expire(self.tasks_id, 600)
|
||||
self.minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
def get_result(self):
|
||||
workflow_json['174']['inputs']['file'] = video_map[self.pose_num]
|
||||
workflow_json['169']['inputs']['noise_seed'] = random.randint(0, 10 ** 18)
|
||||
|
||||
# 下载图片 上传 comfyui server
|
||||
in_memory_file, object_name = self.download_from_minio_in_memory()
|
||||
if in_memory_file and object_name:
|
||||
uploaded_filename = self.upload_in_memory_file_to_comfyui(in_memory_file, object_name)
|
||||
workflow_json['178']['inputs']['image'] = uploaded_filename
|
||||
# 1. 提交任务
|
||||
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||
if not prompt_response:
|
||||
return None
|
||||
|
||||
prompt_id = prompt_response.get("prompt_id")
|
||||
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||
|
||||
outputs = self.poll_history(prompt_id)
|
||||
file_list = {}
|
||||
for node_id, node_output in outputs.items():
|
||||
# 检查当前节点输出中是否包含 'images' 列表
|
||||
if 'images' in node_output and isinstance(node_output['images'], list):
|
||||
|
||||
# 'images' 列表中的每个元素都是一个文件对象
|
||||
for file_info in node_output['images']:
|
||||
# 确保关键字段存在
|
||||
if all(key in file_info for key in ['filename', 'subfolder', 'type']):
|
||||
file_list = {
|
||||
'filename': file_info['filename'],
|
||||
'subfolder': file_info['subfolder'],
|
||||
'type': file_info['type']
|
||||
}
|
||||
logger.info(file_list)
|
||||
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||
return None
|
||||
|
||||
def read_tasks_status(self):
|
||||
status_data = self.redis_client.get(self.tasks_id)
|
||||
return json.loads(status_data), status_data
|
||||
|
||||
def download_from_minio_in_memory(self):
|
||||
bucket = self.image_url.split('/')[0]
|
||||
object_name = self.image_url[self.image_url.find('/') + 1:]
|
||||
# print("🚀 正在连接 MinIO 客户端...")
|
||||
|
||||
try:
|
||||
# get_object 返回一个 ResponseStream 对象
|
||||
response_stream = self.minio_client.get_object(
|
||||
bucket,
|
||||
object_name,
|
||||
)
|
||||
|
||||
# 读取整个流到内存 (BytesIO),避免写入本地文件
|
||||
image_bytes = response_stream.read()
|
||||
|
||||
response_stream.close()
|
||||
response_stream.release_conn()
|
||||
|
||||
in_memory_file = io.BytesIO(image_bytes)
|
||||
|
||||
# print(f"✅ 图片已下载到内存 ({len(image_bytes)} 字节)。")
|
||||
return in_memory_file, object_name.rsplit('/')[-1]
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
|
||||
return None, None
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||
return None, None
|
||||
|
||||
def upload_video_to_minio(self, BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH):
|
||||
"""使用 fput_object 从本地路径上传 MP4 文件"""
|
||||
try:
|
||||
# 使用 fput_object 上传文件
|
||||
# content_type 对于视频流播放非常重要,MP4 文件应使用 'video/mp4'
|
||||
result = self.minio_client.fput_object(
|
||||
bucket_name=BUCKET_NAME,
|
||||
object_name=OBJECT_NAME,
|
||||
file_path=LOCAL_FILE_PATH,
|
||||
content_type="video/mp4" # 设置正确的内容类型
|
||||
)
|
||||
|
||||
# print(f"✅ 文件 '{LOCAL_FILE_PATH}' 已成功上传至 MinIO:")
|
||||
# print(f" 对象名: {result.object_name}")
|
||||
# print(f" Etag: {result.etag}")
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO 操作失败: {e}")
|
||||
except FileNotFoundError:
|
||||
logger.error(f"❌ 找不到本地文件: {LOCAL_FILE_PATH}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 发生未知错误: {e}")
|
||||
|
||||
def upload_gif_to_minio(self, BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH):
|
||||
"""使用 fput_object 从本地路径上传 MP4 文件"""
|
||||
try:
|
||||
# 使用 fput_object 上传文件
|
||||
# content_type 对于视频流播放非常重要,MP4 文件应使用 'video/mp4'
|
||||
result = self.minio_client.fput_object(
|
||||
bucket_name=BUCKET_NAME,
|
||||
object_name=OBJECT_NAME,
|
||||
file_path=LOCAL_FILE_PATH,
|
||||
content_type="video/mp4" # 设置正确的内容类型
|
||||
)
|
||||
|
||||
# print(f"✅ 文件 '{LOCAL_FILE_PATH}' 已成功上传至 MinIO:")
|
||||
# print(f" 对象名: {result.object_name}")
|
||||
# print(f" Etag: {result.etag}")
|
||||
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO 操作失败: {e}")
|
||||
except FileNotFoundError:
|
||||
logger.error(f"❌ 找不到本地文件: {LOCAL_FILE_PATH}")
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 发生未知错误: {e}")
|
||||
|
||||
@staticmethod
|
||||
def upload_in_memory_file_to_comfyui(in_memory_file, filename):
|
||||
upload_url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||
|
||||
data = {
|
||||
"overwrite": "true",
|
||||
"type": "input"
|
||||
}
|
||||
|
||||
# 构建 multipart/form-data: (文件名, 内存文件对象, MIME 类型)
|
||||
# MIME 类型可以根据实际图片类型修改,这里使用常见的 png/jpeg
|
||||
mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
|
||||
|
||||
files = {
|
||||
'image': (filename, in_memory_file, mime_type)
|
||||
}
|
||||
|
||||
# print(f"⬆️ 正在上传图片 ({filename}) 到 ComfyUI...")
|
||||
try:
|
||||
comfyui_response = requests.post(upload_url, data=data, files=files)
|
||||
comfyui_response.raise_for_status()
|
||||
|
||||
result = comfyui_response.json()
|
||||
uploaded_name = result.get('name')
|
||||
|
||||
# print(f"🎉 ComfyUI 上传成功! 服务器文件名: {uploaded_name}")
|
||||
return uploaded_name
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"❌ ComfyUI 上传失败: {e}")
|
||||
logger.error(f" 响应内容: {comfyui_response.text}")
|
||||
return None
|
||||
|
||||
def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str, ):
|
||||
"""
|
||||
完整的自动化流程:获取 ComfyUI 视频 -> 转换 GIF 并提取帧 -> 上传所有结果到 MinIO。
|
||||
"""
|
||||
# 1. 从 ComfyUI 获取视频二进制数据
|
||||
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||
if not mp4_bytes:
|
||||
return None
|
||||
|
||||
# 2. 准备进行视频处理
|
||||
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||
# 为了避免写磁盘,我们将使用 BytesIO,但 MoviePy 内部依赖 FFmpeg,有时需要一个可寻址的本地文件路径。
|
||||
# 最可靠且避免写本地的方案是在内存中操作,然后将结果上传。
|
||||
|
||||
# ⚠️ 关键点:将 mp4_bytes 写入 BytesIO 以模拟文件,供 moviepy 读取
|
||||
|
||||
# 定义输出对象名
|
||||
|
||||
output_base_name = uuid.uuid4().hex
|
||||
MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
|
||||
GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
|
||||
FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
|
||||
|
||||
# --- 视频处理和帧提取 ---
|
||||
try:
|
||||
# 1. 创建一个临时的 MP4 文件路径
|
||||
# delete=False 确保文件在关闭后仍然存在,直到我们手动删除
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||
tmp_file.write(mp4_bytes) # 将内存数据写入磁盘
|
||||
temp_mp4_path = tmp_file.name # 记录文件路径
|
||||
|
||||
# print(f"临时文件已写入: {temp_mp4_path}")
|
||||
|
||||
# 2. 使用 moviepy 打开临时文件 (传入文件路径字符串)
|
||||
clip = VideoFileClip(temp_mp4_path)
|
||||
|
||||
# --- 在这里进行所有的视频处理和提取操作 ---
|
||||
|
||||
# 提取第一帧 (保持原尺寸)
|
||||
frame_array = clip.get_frame(t=0.0)
|
||||
image = Image.fromarray(frame_array)
|
||||
|
||||
frame_stream = io.BytesIO()
|
||||
image.save(frame_stream, 'JPEG')
|
||||
frame_bytes = frame_stream.getvalue()
|
||||
|
||||
logger.info("✅ 成功提取第一帧图片。")
|
||||
|
||||
# 视频转 GIF (使用另一个临时文件来保存 GIF)
|
||||
temp_gif_path = ""
|
||||
with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
|
||||
temp_gif_path = tmp_file.name
|
||||
|
||||
target_fps = int(round(clip.fps)) if clip.fps else 24
|
||||
clip.write_gif(temp_gif_path, fps=target_fps)
|
||||
|
||||
with open(temp_gif_path, 'rb') as f:
|
||||
gif_bytes = f.read()
|
||||
|
||||
logger.info("✅ 成功生成 GIF。")
|
||||
|
||||
# 返回结果 (例如: 上传到 MinIO)
|
||||
# return mp4_bytes, gif_bytes, frame_bytes
|
||||
|
||||
# -----------------------------------------------
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 视频处理或文件操作失败: {e}")
|
||||
# 在失败时,也尝试清理文件
|
||||
|
||||
finally:
|
||||
# 3. 清理临时文件 (非常重要!)
|
||||
if os.path.exists(temp_mp4_path):
|
||||
os.remove(temp_mp4_path)
|
||||
logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
|
||||
|
||||
if 'temp_gif_path' in locals() and os.path.exists(temp_gif_path):
|
||||
os.remove(temp_gif_path)
|
||||
logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
|
||||
|
||||
# 3. 上传所有结果到 MinIO
|
||||
|
||||
try:
|
||||
# 上传原始 MP4
|
||||
self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
|
||||
|
||||
# 上传生成的 GIF
|
||||
self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
|
||||
|
||||
# 上传第一帧图片
|
||||
self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
|
||||
|
||||
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||
|
||||
# 推送消息
|
||||
if not settings.DEBUG:
|
||||
publish_status(json.dumps(self.pose_transform_data), settings.COMFYUI_SERVER_ADDRESS)
|
||||
logger.info(
|
||||
f" [x] Sent to: {settings.COMFYUI_SERVER_ADDRESS} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||
|
||||
return "\n🎉 所有任务完成!"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
return None
|
||||
|
||||
# --- 辅助函数:提交任务到队列 ---
|
||||
@staticmethod
|
||||
def queue_prompt(prompt, client_id):
|
||||
"""向 ComfyUI 提交工作流提示。"""
|
||||
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||
data = json.dumps(p).encode('utf-8')
|
||||
|
||||
# 提交任务到 /prompt 端点
|
||||
# noinspection HttpUrlsUsage
|
||||
response = requests.post(f"http://{settings.COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||
# print(f"-------------{response.text}")
|
||||
# print(f"------------{client_id}")
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
logger.warning(f"提交任务失败,状态码: {response.status_code}")
|
||||
logger.warning(response.text)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def poll_history(prompt_id, interval_seconds=5):
|
||||
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||
|
||||
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||
|
||||
while True:
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
try:
|
||||
response = requests.get(url)
|
||||
# 任务未完成时,ComfyUI可能会返回404或空响应,我们只关注成功响应
|
||||
if response.status_code == 200:
|
||||
history_data = response.json()
|
||||
|
||||
# ComfyUI 返回的历史记录结构是 {prompt_id: {outputs: ...}}
|
||||
if prompt_id in history_data:
|
||||
logger.info("🎉 任务已完成!")
|
||||
return history_data[prompt_id]['outputs']
|
||||
|
||||
logger.info("⏳ 任务仍在执行或等待中...")
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
# 处理可能的连接错误,但通常不会在内部轮询中发生
|
||||
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_comfyui_video_bytes(filename: str, subfolder: str, file_type: str = "output"):
|
||||
"""
|
||||
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||
|
||||
参数:
|
||||
- filename: 视频文件名 (例如: 'ComfyUI_00002_.mp4')
|
||||
- subfolder: 存储子文件夹 (例如: 'ComfyUI_2025-10-31')
|
||||
- file_type: 文件类型 (通常是 'output')
|
||||
|
||||
返回:
|
||||
- 视频文件的二进制内容 (bytes) 或 None。
|
||||
"""
|
||||
url = f"http://{settings.COMFYUI_SERVER_ADDRESS}/view"
|
||||
params = {
|
||||
"filename": filename,
|
||||
"subfolder": subfolder,
|
||||
"type": file_type
|
||||
}
|
||||
|
||||
logger.info(f"📡 正在从 ComfyUI 下载视频: {filename}")
|
||||
try:
|
||||
# 使用 requests.get 下载文件
|
||||
response = requests.get(url, params=params, stream=True)
|
||||
response.raise_for_status() # 检查 HTTP 错误
|
||||
|
||||
# 返回文件的完整二进制内容
|
||||
return response.content
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
|
||||
return None
|
||||
|
||||
def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
|
||||
"""从内存流上传数据到 MinIO。"""
|
||||
logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
|
||||
try:
|
||||
|
||||
data_stream = io.BytesIO(video_bytes)
|
||||
|
||||
result = self.minio_client.put_object(
|
||||
bucket_name='aida-users',
|
||||
object_name=object_name,
|
||||
data=data_stream,
|
||||
length=len(video_bytes),
|
||||
content_type=content_type
|
||||
)
|
||||
logger.info(f"✅ MinIO 上传成功: {result.object_name}")
|
||||
return True
|
||||
except S3Error as e:
|
||||
logger.error(f"❌ MinIO 上传失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
request_data = ComfyuiPose2VModel(
|
||||
tasks_id="122522251123-89111",
|
||||
image_url="aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
|
||||
pose_id="6"
|
||||
)
|
||||
|
||||
server = ComfyUIServerPose2V(request_data)
|
||||
print(server.get_result())
|
||||
@@ -1,116 +0,0 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
def show(img, win_name="temp"):
|
||||
cv2.imshow(win_name, img)
|
||||
cv2.waitKey(0)
|
||||
|
||||
|
||||
def crop(img):
|
||||
mid_point_h, mid_point_w = int(img.shape[0] / 2 + 30), int(img.shape[1] / 2)
|
||||
img_roi = img[mid_point_h - 520: mid_point_h + 520, mid_point_w - 340: mid_point_w + 340]
|
||||
return img_roi
|
||||
|
||||
|
||||
class Layer(object):
|
||||
def __init__(self):
|
||||
self._layer = []
|
||||
|
||||
@property
|
||||
def layer(self):
|
||||
return self._layer
|
||||
|
||||
def insert(self, layer_instance):
|
||||
if layer_instance['name'] == 'body':
|
||||
self._body = layer_instance
|
||||
self._layer.append(layer_instance)
|
||||
|
||||
def sort(self, priority):
|
||||
self._layer.sort(key=lambda x: priority[x['name']])
|
||||
|
||||
# def merge(self, cfg):
|
||||
# """
|
||||
# opencv shape order (height, width, channel)
|
||||
# image coordinate system:
|
||||
# |------------->x (width)
|
||||
# |
|
||||
# |
|
||||
# |
|
||||
# y (height)
|
||||
# Returns:
|
||||
#
|
||||
#
|
||||
# """
|
||||
# base_image = Image.new('RGBA', self._layer[1]['image'].size, (0, 0, 0, 0))
|
||||
# for layer in self._layer:
|
||||
# y, x = layer['position']
|
||||
# base_image.paste(layer['image'], (x, y), layer['image'])
|
||||
# # base_image.show()
|
||||
#
|
||||
# for x in self._layer:
|
||||
# if np.all(x['mask'] == 0):
|
||||
# continue
|
||||
# # obtain region of interest about roi(roi) and item-image(roi_image, roi_mask)
|
||||
# roi, roi_mask, roi_image, signal = self.get_roi(dst=dst, image=x)
|
||||
# temp_bg = np.expand_dims(cv2.bitwise_not(roi_mask), axis=2).repeat(3, axis=2)
|
||||
# tmp1 = (roi * (temp_bg / 255)).astype(np.uint8)
|
||||
# temp_fg = np.expand_dims(roi_mask, axis=2).repeat(3, axis=2)
|
||||
# tmp2 = (roi_image * (temp_fg / 255)).astype(np.uint8)
|
||||
#
|
||||
# roi[:] = cv2.add(tmp1, tmp2)
|
||||
# # show(cv2.resize(dst, (int(dst.shape[1] * 0.5), int(dst.shape[0] * 0.5)), interpolation=cv2.INTER_AREA),
|
||||
# # win_name=x.get('name'))
|
||||
# # crop image and get the central part
|
||||
# if cfg.get('basic')['self_template'] == False:
|
||||
# dst_roi = crop(dst)
|
||||
# else:
|
||||
# dst_roi = dst
|
||||
# return dst_roi, signal
|
||||
#
|
||||
# @staticmethod
|
||||
# def get_roi(dst, image):
|
||||
# signal = False
|
||||
# dst_y, dst_x = dst.shape[:2]
|
||||
# roi_height, roi_width = image['mask'].shape
|
||||
# roi_y0, roi_x0 = image['position']
|
||||
#
|
||||
# if roi_y0 < 0:
|
||||
# roi_yin = 0
|
||||
# mask_yin = -roi_y0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_yin = roi_y0
|
||||
# mask_yin = 0
|
||||
# if roi_y0 + roi_height > dst_y:
|
||||
# roi_yout = dst_y
|
||||
# mask_yout = dst_y - roi_y0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_yout = roi_height + roi_y0
|
||||
# mask_yout = roi_height
|
||||
# # x part
|
||||
# if roi_x0 < 0:
|
||||
# roi_xin = 0
|
||||
# mask_xin = -roi_x0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_xin = roi_x0
|
||||
# mask_xin = 0
|
||||
# if roi_x0 + roi_width > dst_x:
|
||||
# roi_xout = dst_x
|
||||
# mask_xout = dst_x - roi_x0
|
||||
# signal = True
|
||||
# else:
|
||||
# roi_xout = roi_width + roi_x0
|
||||
# mask_xout = roi_width
|
||||
#
|
||||
# roi = dst[roi_yin: roi_yout, roi_xin: roi_xout]
|
||||
# roi_mask = image['mask'][mask_yin: mask_yout, mask_xin: mask_xout]
|
||||
# roi_image = image['image'][mask_yin: mask_yout, mask_xin: mask_xout]
|
||||
# return roi, roi_mask, roi_image, signal
|
||||
@@ -1,45 +0,0 @@
|
||||
class Priority(object):
|
||||
"""Item layer priority levels.
|
||||
"""
|
||||
|
||||
def __init__(self, item_list):
|
||||
self._priority = dict(
|
||||
earring_front=99,
|
||||
bag_front=98,
|
||||
hairstyle_front=97,
|
||||
outwear_front=20,
|
||||
bottoms_front=19,
|
||||
dress_front=18,
|
||||
blouse_front=17,
|
||||
skirt_front=16,
|
||||
trousers_front=15,
|
||||
tops_front=14,
|
||||
shoes_right=1,
|
||||
shoes_left=1,
|
||||
body=0,
|
||||
tops_back=-14,
|
||||
trousers_back=-15,
|
||||
skirt_back=-16,
|
||||
blouse_back=-17,
|
||||
dress_back=-18,
|
||||
bottoms_back=-19,
|
||||
outwear_back=-20,
|
||||
hairstyle_back=-97,
|
||||
bag_back=-98,
|
||||
earring_back=-99,
|
||||
)
|
||||
self.clothing_start_num = 10
|
||||
if not isinstance(item_list, list):
|
||||
raise ValueError('item_list must be a list!')
|
||||
for cate in item_list:
|
||||
cate = cate.lower()
|
||||
if cate not in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||
raise ValueError(f'Item type error. Cannot recognize {cate}')
|
||||
for i, cate in enumerate(item_list):
|
||||
cate = cate.lower()
|
||||
self._priority[f'{cate}_front'] = self.clothing_start_num - i
|
||||
self._priority[f'{cate}_back'] = -(self.clothing_start_num - i)
|
||||
|
||||
@property
|
||||
def priority(self):
|
||||
return self._priority
|
||||
@@ -1,16 +0,0 @@
|
||||
from .builder import ITEMS, build_item
|
||||
from .clothing import Clothing # 4.0 sec
|
||||
from .body import Body
|
||||
from .top import Top, Blouse, Outwear, Dress
|
||||
from .bottom import Bottom, Trousers, Skirt
|
||||
from .shoes import Shoes
|
||||
from .bag import Bag
|
||||
from .accessories import Hairstyle, Earring
|
||||
|
||||
__all__ = [
|
||||
'ITEMS', 'build_item',
|
||||
'Clothing', 'Body',
|
||||
'Top', 'Blouse', 'Outwear', 'Dress',
|
||||
'Bottom', 'Trousers', 'Skirt',
|
||||
'Shoes', 'Bag', 'Hairstyle', 'Earring'
|
||||
]
|
||||
@@ -1,59 +0,0 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Hairstyle(Clothing):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Hairstyle, self).__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
|
||||
"""
|
||||
align up
|
||||
Args:
|
||||
keypoint_type: string, "head_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
|
||||
Returns:
|
||||
start_point: tuple (x', y')
|
||||
x' = y_body - y1 * scale
|
||||
y' = x_body - x1 * scale
|
||||
"""
|
||||
side_indicator = f'{keypoint_type}_up'
|
||||
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
|
||||
# logging.info(clothes_point[side_indicator])
|
||||
|
||||
start_point = (
|
||||
int(body_point[side_indicator][1] - int(clothes_point[side_indicator].split("_")[1] * scale)),
|
||||
int(body_point[side_indicator][0] - int(clothes_point[side_indicator].split("_")[0] * scale))
|
||||
)
|
||||
return start_point
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Earring(Clothing):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Earring, self).__init__(**kwargs)
|
||||
@@ -1,45 +0,0 @@
|
||||
import random
|
||||
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Bag(Clothing):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Bag, self).__init__(**kwargs)
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point):
|
||||
"""
|
||||
align left
|
||||
Args:
|
||||
keypoint_type: string, "hand_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
|
||||
Returns:
|
||||
start_point: tuple (y', x')
|
||||
x' = y_body - y1 * scale
|
||||
y' = x_body - x1 * scale
|
||||
"""
|
||||
location = random.choice(seq=['left', 'right'])
|
||||
if location == 'left':
|
||||
side_indicator = f'{keypoint_type}_left'
|
||||
else:
|
||||
side_indicator = f'{keypoint_type}_right'
|
||||
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
|
||||
start_point = (body_point[side_indicator][1] - int(int(clothes_point[keypoint_type].split("_")[1]) * scale),
|
||||
body_point[side_indicator][0] - int(int(clothes_point[keypoint_type].split("_")[0]) * scale))
|
||||
return start_point
|
||||
@@ -1,36 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from .builder import ITEMS
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Body(object):
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadBodyImageFromFile', body_path=kwargs['body_path']),
|
||||
# dict(type='ImageShow', key=['body_image', "body_mask"])
|
||||
]
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.result = dict()
|
||||
|
||||
def process(self):
|
||||
self.pipeline(self.result)
|
||||
pass
|
||||
|
||||
def organize(self, layer):
|
||||
body_layer = dict(priority=0,
|
||||
name=type(self).__name__.lower(),
|
||||
image=self.result['body_image'],
|
||||
image_url=self.result['image_url'],
|
||||
mask_image=None,
|
||||
mask_url=None,
|
||||
sacle=1,
|
||||
# mask=self.result['body_mask'],
|
||||
position=(0, 0))
|
||||
layer.insert(body_layer)
|
||||
|
||||
@staticmethod
|
||||
def show(img):
|
||||
cv2.imshow('', img)
|
||||
cv2.waitKey(0)
|
||||
@@ -1,39 +0,0 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Bottom(Clothing):
|
||||
def __init__(self, pipeline, **kwargs):
|
||||
if pipeline is None:
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
# dict(type='Segmentation'),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image', 'print_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Bottom, self).__init__(**kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Trousers(Bottom):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Trousers, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Skirt(Bottom):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Skirt, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Bottoms(Bottom):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Bottoms, self).__init__(pipeline, **kwargs)
|
||||
@@ -1,9 +0,0 @@
|
||||
from mmcv.utils import Registry, build_from_cfg
|
||||
|
||||
ITEMS = Registry('item')
|
||||
PIPELINES = Registry('pipeline')
|
||||
|
||||
|
||||
def build_item(cfg, default_args=None):
|
||||
item = build_from_cfg(cfg, ITEMS, default_args)
|
||||
return item
|
||||
@@ -1,100 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
from .builder import ITEMS
|
||||
from .pipelines import Compose
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Clothing(object):
|
||||
def __init__(self, pipeline, **kwargs):
|
||||
self.pipeline = Compose(pipeline)
|
||||
self.result = dict(name=type(self).__name__.lower(), **kwargs)
|
||||
|
||||
def process(self):
|
||||
self.pipeline(self.result)
|
||||
|
||||
def apply_scale(self, img):
|
||||
scale = self.result['scale']
|
||||
height, width = img.shape[0: 2]
|
||||
if len(img.shape) > 2:
|
||||
height, width = img.shape[0: 2]
|
||||
scaled_img = cv2.resize(img, (int(width * scale), int(height * scale)), interpolation=cv2.INTER_AREA)
|
||||
return scaled_img
|
||||
|
||||
def organize(self, layer):
|
||||
start_point = self.calculate_start_point(self.result['keypoint'], self.result['scale'], self.result['clothes_keypoint'], self.result['body_point_test'], self.result["offset"], self.result["resize_scale"])
|
||||
|
||||
front_layer = dict(priority=self.result.get("priority", None) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_front', None),
|
||||
name=f'{type(self).__name__.lower()}_front',
|
||||
image=self.result["front_image"],
|
||||
# mask_image=self.result['front_mask_image'],
|
||||
image_url=self.result['front_image_url'],
|
||||
mask_url=self.result['mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=start_point,
|
||||
resize_scale=self.result["resize_scale"],
|
||||
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
|
||||
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
|
||||
pattern_image_url=self.result['pattern_image_url'],
|
||||
pattern_image=self.result['pattern_image']
|
||||
|
||||
)
|
||||
layer.insert(front_layer)
|
||||
|
||||
back_layer = dict(priority=-self.result.get("priority", 0) if self.result.get("layer_order", False) else PRIORITY_DICT.get(f'{type(self).__name__.lower()}_back', None),
|
||||
name=f'{type(self).__name__.lower()}_back',
|
||||
image=self.result["back_image"],
|
||||
# mask_image=self.result['back_mask_image'],
|
||||
image_url=self.result['back_image_url'],
|
||||
mask_url=self.result['mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=start_point,
|
||||
resize_scale=self.result["resize_scale"],
|
||||
mask=cv2.resize(self.result['mask'], self.result["front_image"].size),
|
||||
gradient_string=self.result['gradient_string'] if 'gradient_string' in self.result.keys() else "",
|
||||
pattern_image_url=self.result['pattern_image_url'],
|
||||
)
|
||||
layer.insert(back_layer)
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offset, resize_scale):
|
||||
"""
|
||||
Align left
|
||||
Args:
|
||||
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
|
||||
Returns:
|
||||
start_point: tuple (x', y')
|
||||
x' = y_body - y1 * scale + offset
|
||||
y' = x_body - x1 * scale + offset
|
||||
|
||||
"""
|
||||
|
||||
side_indicator = f'{keypoint_type}_left'
|
||||
|
||||
# if keypoint_type == "ear_point":
|
||||
# start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
|
||||
# body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
|
||||
# else:
|
||||
# start_point = (
|
||||
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
|
||||
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
|
||||
# )
|
||||
|
||||
# milvus_DB_keypoint_cache:
|
||||
start_point = (
|
||||
int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator][0]) * scale), # y
|
||||
int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator][1]) * scale) # x
|
||||
)
|
||||
# start_point = (
|
||||
# int(body_point[side_indicator][1] + offset[1] - int(clothes_point[side_indicator].split("_")[0]) * scale), # y
|
||||
# int(body_point[side_indicator][0] + offset[0] - int(clothes_point[side_indicator].split("_")[1]) * scale) # x
|
||||
# )
|
||||
|
||||
return start_point
|
||||
@@ -1,19 +0,0 @@
|
||||
from .compose import Compose
|
||||
from .loading import LoadImageFromFile, LoadBodyImageFromFile, ImageShow
|
||||
from .keypoints import KeypointDetection
|
||||
from .segmentation import Segmentation
|
||||
from .painting import Painting, PrintPainting
|
||||
from .scale import Scaling
|
||||
from .contour_detection import ContourDetection
|
||||
from .split import Split
|
||||
|
||||
__all__ = [
|
||||
'Compose',
|
||||
'LoadImageFromFile', 'LoadBodyImageFromFile', 'ImageShow',
|
||||
'KeypointDetection',
|
||||
'Segmentation',
|
||||
'Painting', 'PrintPainting',
|
||||
'Scaling',
|
||||
'ContourDetection',
|
||||
'split',
|
||||
]
|
||||
@@ -1,36 +0,0 @@
|
||||
import collections
|
||||
|
||||
from mmcv.utils import build_from_cfg
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Compose(object):
|
||||
def __init__(self, transforms):
|
||||
assert isinstance(transforms, collections.abc.Sequence)
|
||||
self.transforms = []
|
||||
for transform in transforms:
|
||||
if isinstance(transform, dict):
|
||||
transform = build_from_cfg(transform, PIPELINES)
|
||||
self.transforms.append(transform)
|
||||
elif callable(transform):
|
||||
self.transforms.append(transform)
|
||||
else:
|
||||
raise TypeError('transform must be callable or a dict')
|
||||
|
||||
def __call__(self, data):
|
||||
"""Call function to apply transforms sequentially.
|
||||
|
||||
Args:
|
||||
data (dict): A result dict contains the data to transform.
|
||||
|
||||
Returns:
|
||||
dict: Transformed data.
|
||||
"""
|
||||
|
||||
for t in self.transforms:
|
||||
data = t(data)
|
||||
if data is None:
|
||||
return None
|
||||
return data
|
||||
@@ -1,59 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ContourDetection(object):
|
||||
def __init__(self):
|
||||
# logging.info("ContourDetection run ")
|
||||
pass
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
# shoe diff
|
||||
if result['name'] == 'shoes':
|
||||
Contour = self.get_contours(result['image'])
|
||||
Mask = np.zeros(result['image'].shape[:2], np.uint8)
|
||||
for i in range(2):
|
||||
Max_contour = Contour[i]
|
||||
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
|
||||
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
|
||||
cv2.drawContours(Mask, [Approx], -1, 255, -1)
|
||||
if result['pre_mask'] is None:
|
||||
result['mask'] = Mask
|
||||
else:
|
||||
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
|
||||
else:
|
||||
Contour = self.get_contours(result['image'])
|
||||
Mask = np.zeros(result['image'].shape[:2], np.uint8)
|
||||
if len(Contour):
|
||||
Max_contour = Contour[0]
|
||||
Epsilon = 0.001 * cv2.arcLength(Max_contour, True)
|
||||
Approx = cv2.approxPolyDP(Max_contour, Epsilon, True)
|
||||
cv2.drawContours(Mask, [Approx], -1, 255, -1)
|
||||
else:
|
||||
Mask = np.ones(result['image'].shape[:2], np.uint8) * 255
|
||||
# TODO 修复部分图片出现透明的情况 下版本上线
|
||||
# img2gray = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
# ret, Mask = cv2.threshold(img2gray, 126, 255, cv2.THRESH_BINARY)
|
||||
# Mask = cv2.bitwise_not(Mask)
|
||||
if result['pre_mask'] is None:
|
||||
result['mask'] = Mask
|
||||
else:
|
||||
result['mask'] = cv2.bitwise_and(Mask, result['pre_mask'])
|
||||
result['front_mask'] = result['mask']
|
||||
result['back_mask'] = result['mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_contours(image):
|
||||
gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
||||
Edge = cv2.Canny(gray, 10, 150)
|
||||
kernel = np.ones((5, 5), np.uint8)
|
||||
Edge = cv2.dilate(Edge, kernel=kernel, iterations=1)
|
||||
Edge = cv2.erode(Edge, kernel=kernel, iterations=1)
|
||||
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
Contour = sorted(Contour, key=cv2.contourArea, reverse=True)
|
||||
return Contour
|
||||
@@ -1,140 +0,0 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.decorator import RunTime, ClassCallRunTime
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_keypoint_result
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class KeypointDetection(object):
|
||||
"""
|
||||
path here: abstract path
|
||||
"""
|
||||
|
||||
# def __init__(self):
|
||||
# self.client = MilvusClient(
|
||||
# uri="http://10.1.1.240:19530",
|
||||
# token="root:Milvus",
|
||||
# db_name=MILVUS_ALIAS
|
||||
# )
|
||||
|
||||
# def __del__(self):
|
||||
# start_time = time.time()
|
||||
# self.client.close()
|
||||
# print(f"client close time : {time.time() - start_time}")
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
# logging.info("KeypointDetection run ")
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||
|
||||
keypoint_cache = self.keypoint_cache(result, site)
|
||||
# 取消向量查询 直接过模型推理
|
||||
# keypoint_cache = False
|
||||
|
||||
if keypoint_cache is False:
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
|
||||
else:
|
||||
result['clothes_keypoint'] = keypoint_cache
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def infer_keypoint_result(result):
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
start_time = time.time()
|
||||
keypoint_infer_result = get_keypoint_result(result["image"], site) # 推理结果
|
||||
# logging.info(f"infer keypoint time : {time.time() - start_time}")
|
||||
return keypoint_infer_result, site
|
||||
|
||||
@staticmethod
|
||||
# @ RunTime
|
||||
def save_keypoint_cache(keypoint_id, cache, site):
|
||||
if site == "down":
|
||||
zeros = np.zeros(20, dtype=int)
|
||||
result = np.concatenate([zeros, cache.flatten()])
|
||||
else:
|
||||
zeros = np.zeros(4, dtype=int)
|
||||
result = np.concatenate([cache.flatten(), zeros])
|
||||
# 取消向量保存 直接拿结果
|
||||
data = [
|
||||
{"keypoint_id": keypoint_id,
|
||||
"keypoint_site": site,
|
||||
"keypoint_vector": result.tolist()
|
||||
}
|
||||
]
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
# start_time = time.time()
|
||||
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||
client.close()
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logging.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
@staticmethod
|
||||
def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
|
||||
if site == "up":
|
||||
# 需要的是up 即推理出来的是up 那么查询的就是down
|
||||
result = np.concatenate([infer_result.flatten(), search_result[-4:]])
|
||||
else:
|
||||
# 需要的是down 即推理出来的是down 那么查询的就是up
|
||||
result = np.concatenate([search_result[:20], infer_result.flatten()])
|
||||
data = [
|
||||
{"keypoint_id": keypoint_id,
|
||||
"keypoint_site": "all",
|
||||
"keypoint_vector": result.tolist()
|
||||
}
|
||||
]
|
||||
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
# connections.connect(alias=MILVUS_ALIAS, host=MILVUS_DB_HOST, port=MILVUS_PORT)
|
||||
start_time = time.time()
|
||||
# collection = Collection(MILVUS_TABLE_KEYPOINT) # Get an existing collection.
|
||||
# mr = collection.upsert(data)
|
||||
client.upsert(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
data=data
|
||||
)
|
||||
# logging.info(f"save keypoint time : {time.time() - start_time}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
logging.info(f"save keypoint cache milvus error : {e}")
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
|
||||
# @ RunTime
|
||||
def keypoint_cache(self, result, site):
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
keypoint_id = result['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
# ids=[keypoint_id],
|
||||
filter=f"keypoint_id == {keypoint_id}",
|
||||
output_fields=['keypoint_vector', 'keypoint_site']
|
||||
)
|
||||
if len(res) == 0:
|
||||
# 没有结果 直接推理拿结果 并保存
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
|
||||
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
|
||||
# 需要的类型和查询的类型一致,或者查询的类型为all 则直接返回查询的结果
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
|
||||
elif res[0]["keypoint_site"] != site:
|
||||
# 需要的类型和查询到的不一致,则更新类型为all
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
|
||||
except Exception as e:
|
||||
logging.info(f"search keypoint cache milvus error {e}")
|
||||
return False
|
||||
@@ -1,134 +0,0 @@
|
||||
import cv2
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class LoadImageFromFile(object):
|
||||
def __init__(self, path, color=None, print_dict=None):
|
||||
self.path = path
|
||||
self.color = color
|
||||
self.print_dict = print_dict
|
||||
# self.minio_client = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
result['image'], result['pre_mask'] = self.read_image(self.path)
|
||||
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||
result['keypoint'] = self.get_keypoint(result['name'])
|
||||
result['path'] = self.path
|
||||
result['img_shape'] = result['image'].shape
|
||||
result['ori_shape'] = result['image'].shape
|
||||
result['color'] = self.color if self.color is not None else None
|
||||
result['print_dict'] = self.print_dict
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_keypoint(name):
|
||||
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops':
|
||||
keypoint = 'shoulder'
|
||||
elif name == 'trousers' or name == 'skirt' or name == 'bottoms':
|
||||
keypoint = 'waistband'
|
||||
elif name == 'bag':
|
||||
keypoint = 'hand_point'
|
||||
elif name == 'shoes':
|
||||
keypoint = 'toe'
|
||||
elif name == 'hairstyle':
|
||||
keypoint = 'head_point'
|
||||
elif name == 'earring':
|
||||
keypoint = 'ear_point'
|
||||
else:
|
||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||
f"bag, shoes, hairstyle, earring.")
|
||||
return keypoint
|
||||
|
||||
@staticmethod
|
||||
def read_image(image_path):
|
||||
image_mask = None
|
||||
image = oss_get_image(bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
|
||||
if len(image.shape) == 2:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
||||
if image.shape[2] == 4: # 如果是四通道 mask
|
||||
image_mask = image[:, :, 3]
|
||||
image = image[:, :, :3]
|
||||
|
||||
if image.shape[:2] <= (50, 50):
|
||||
# 计算新尺寸
|
||||
new_size = (image.shape[1] * 2, image.shape[0] * 2)
|
||||
# 调整大小
|
||||
image = cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
|
||||
return image, image_mask
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class LoadBodyImageFromFile(object):
|
||||
def __init__(self, body_path):
|
||||
self.body_path = body_path
|
||||
# self.minioClient = Minio(f"{MINIO_URL}", access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
|
||||
# response = self.minioClient.get_object("aida-mannequins", "model_1693218345.2714431.png")
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
result["image_url"] = result['body_path'] = self.body_path
|
||||
result["name"] = "mannequin"
|
||||
# if not result['image_url'].lower().endswith(".png"):
|
||||
# bucket = self.body_path.split("/", 1)[0]
|
||||
# object_name = self.body_path.split("/", 1)[1]
|
||||
# new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
|
||||
# image = self.minioClient.get_object(bucket, object_name)
|
||||
# image = Image.open(io.BytesIO(image.data))
|
||||
# image = image.convert("RGBA")
|
||||
# data = image.getdata()
|
||||
# #
|
||||
# new_data = []
|
||||
# for item in data:
|
||||
# if item[0] >= 230 and item[1] >= 230 and item[2] >= 230:
|
||||
# new_data.append((255, 255, 255, 0))
|
||||
# else:
|
||||
# new_data.append(item)
|
||||
# image.putdata(new_data)
|
||||
# image_data = io.BytesIO()
|
||||
# image.save(image_data, format='PNG')
|
||||
# image_data.seek(0)
|
||||
# image_bytes = image_data.read()
|
||||
# image_path = f"{bucket}/{self.minioClient.put_object(bucket, new_object_name, io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
# self.body_path = image_path
|
||||
# result["image_url"] = result['body_path'] = self.body_path
|
||||
# response = self.minioClient.get_object(self.body_path.split("/", 1)[0], self.body_path.split("/", 1)[1])
|
||||
# put_image_time = time.time()
|
||||
# result['body_image'] = Image.open(io.BytesIO(response.read()))
|
||||
result['body_image'] = oss_get_image(bucket=self.body_path.split("/", 1)[0], object_name=self.body_path.split("/", 1)[1], data_type="PIL")
|
||||
# logging.info(f"Image.open time is : {time.time() - put_image_time}")
|
||||
return result
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class ImageShow(object):
|
||||
def __init__(self, key):
|
||||
self.key = key
|
||||
|
||||
# @ RunTime
|
||||
def __call__(self, result):
|
||||
import matplotlib.pyplot as plt
|
||||
if isinstance(self.key, list):
|
||||
for key in self.key:
|
||||
plt.imshow(result[key])
|
||||
plt.title(key)
|
||||
plt.show()
|
||||
elif isinstance(self.key, str):
|
||||
img = self._resize_img(result[self.key])
|
||||
cv2.imshow(self.key, img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
raise TypeError(f'key should be string but got type {type(self.key)}.')
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _resize_img(img):
|
||||
shape = img.shape
|
||||
if shape[0] > 400 or shape[1] > 400:
|
||||
ratio = min(400 / shape[0], 400 / shape[1])
|
||||
img = cv2.resize(img, (int(ratio * shape[1]), int(ratio * shape[0])))
|
||||
return img
|
||||
@@ -1,605 +0,0 @@
|
||||
import logging
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Painting(object):
|
||||
def __init__(self, painting_flag=True):
|
||||
self.painting_flag = painting_flag
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
if result['name'] not in ['hairstyle', 'earring'] and self.painting_flag and result['color'] != 'none':
|
||||
dim_image_h, dim_image_w = result['image'].shape[0:2]
|
||||
if "gradient" in result.keys() and result['gradient'] != "":
|
||||
bucket_name = result['gradient'].split('/')[0]
|
||||
object_name = result['gradient'][result['gradient'].find('/') + 1:]
|
||||
pattern = self.get_gradient(bucket_name=bucket_name, object_name=object_name)
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
pattern = self.get_pattern(result['color'])
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
||||
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
||||
result['final_image'] = result['pattern_image']
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
result['alpha'] = 100 / 255.0
|
||||
else:
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
get_image_fir = result['image'] * (closed_mo / 255)
|
||||
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
||||
result['final_image'] = result['pattern_image']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def get_gradient(bucket_name, object_name):
|
||||
# image_data = minio_client.get_object(bucket_name, object_name)
|
||||
# image_data = s3.get_object(Bucket=bucket_name, Key=object_name)['Body']
|
||||
|
||||
# 从数据流中读取图像
|
||||
# image_bytes = image_data.read()
|
||||
|
||||
# 将图像数据转换为numpy数组
|
||||
# image_array = np.asarray(bytearray(image_bytes), dtype=np.uint8)
|
||||
|
||||
# 使用OpenCV解码图像数组
|
||||
# image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
||||
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="cv2")
|
||||
if image.shape[2] == 4:
|
||||
image = cv2.cvtColor(image, cv2.COLOR_BGRA2BGR)
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def crop_image(image, image_size_h, image_size_w):
|
||||
x_offset = np.random.randint(low=0, high=int(image_size_h / 5) - 6)
|
||||
y_offset = np.random.randint(low=0, high=int(image_size_w / 5) - 6)
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def get_pattern(single_color):
|
||||
if single_color is None:
|
||||
raise False
|
||||
R, G, B = single_color.split(' ')
|
||||
pattern = np.zeros([1, 1, 3], np.uint8)
|
||||
pattern[0, 0, 0] = int(B)
|
||||
pattern[0, 0, 1] = int(G)
|
||||
pattern[0, 0, 2] = int(R)
|
||||
return pattern
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class PrintPainting(object):
|
||||
def __init__(self, print_flag=True):
|
||||
self.print_flag = print_flag
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
single_print = result['print']['single']
|
||||
overall_print = result['print']['overall']
|
||||
element_print = result['print']['element']
|
||||
result['single_image'] = None
|
||||
result['print_image'] = None
|
||||
if overall_print['print_path_list']:
|
||||
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||
result['print_image'] = result['pattern_image']
|
||||
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
|
||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
|
||||
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
||||
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
||||
|
||||
# resize 到sketch大小
|
||||
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||
else:
|
||||
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
|
||||
result['print_image'] = self.printpaint(result, painting_dict, print_=True)
|
||||
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
|
||||
|
||||
if single_print['print_path_list']:
|
||||
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||
for i in range(len(single_print['print_path_list'])):
|
||||
image, image_mode = self.read_image(single_print['print_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * single_print['print_scale_list'][i]), int(image.height * single_print['print_scale_list'][i]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i], single_print['print_scale_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
image_x = print_background.shape[1]
|
||||
image_y = print_background.shape[0]
|
||||
print_x = rotate_image.shape[1]
|
||||
print_y = rotate_image.shape[0]
|
||||
|
||||
# 有bug
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# #
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
|
||||
# 不能是并行
|
||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# 先挪 再判断 最后裁剪
|
||||
|
||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
if x <= 0:
|
||||
rotate_image = rotate_image[:, -x:]
|
||||
rotate_mask = rotate_mask[:, -x:]
|
||||
start_x = x = 0
|
||||
else:
|
||||
start_x = x
|
||||
|
||||
if y <= 0:
|
||||
rotate_image = rotate_image[-y:, :]
|
||||
rotate_mask = rotate_mask[-y:, :]
|
||||
start_y = y = 0
|
||||
else:
|
||||
start_y = y
|
||||
|
||||
# ------------------
|
||||
# 如果print-size大于image-size 则需要裁剪print
|
||||
|
||||
if x + print_x > image_x:
|
||||
rotate_image = rotate_image[:, :image_x - x]
|
||||
rotate_mask = rotate_mask[:, :image_x - x]
|
||||
|
||||
if y + print_y > image_y:
|
||||
rotate_image = rotate_image[:image_y - y, :]
|
||||
rotate_mask = rotate_mask[:image_y - y, :]
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
|
||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
||||
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
|
||||
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
|
||||
if element_print['element_path_list']:
|
||||
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||
for i in range(len(element_print['element_path_list'])):
|
||||
image, image_mode = self.read_image(element_print['element_path_list'][i])
|
||||
if image_mode == "RGBA":
|
||||
new_size = (int(image.width * element_print['element_scale_list'][i]), int(image.height * element_print['element_scale_list'][i]))
|
||||
|
||||
mask = image.split()[3]
|
||||
resized_source = image.resize(new_size)
|
||||
resized_source_mask = mask.resize(new_size)
|
||||
|
||||
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
|
||||
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
|
||||
|
||||
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||
|
||||
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
|
||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
|
||||
|
||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||
else:
|
||||
mask = self.get_mask_inv(image)
|
||||
mask = np.expand_dims(mask, axis=2)
|
||||
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||
mask = cv2.bitwise_not(mask)
|
||||
# 旋转后的坐标需要重新算
|
||||
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i], element_print['element_scale_list'][i])
|
||||
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
|
||||
|
||||
image_x = print_background.shape[1]
|
||||
image_y = print_background.shape[0]
|
||||
print_x = rotate_image.shape[1]
|
||||
print_y = rotate_image.shape[0]
|
||||
|
||||
# 有bug
|
||||
# if x + print_x > image_x:
|
||||
# rotate_image = rotate_image[:, :x + print_x - image_x]
|
||||
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
|
||||
# #
|
||||
# if y + print_y > image_y:
|
||||
# rotate_image = rotate_image[:y + print_y - image_y]
|
||||
# rotate_mask = rotate_mask[:y + print_y - image_y]
|
||||
|
||||
# 不能是并行
|
||||
# 当前第一轮的if (108以及115)是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话,先裁了右边,再左移,region就会有问题
|
||||
# 先挪 再判断 最后裁剪
|
||||
|
||||
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
|
||||
if x <= 0:
|
||||
rotate_image = rotate_image[:, -x:]
|
||||
rotate_mask = rotate_mask[:, -x:]
|
||||
start_x = x = 0
|
||||
else:
|
||||
start_x = x
|
||||
|
||||
if y <= 0:
|
||||
rotate_image = rotate_image[-y:, :]
|
||||
rotate_mask = rotate_mask[-y:, :]
|
||||
start_y = y = 0
|
||||
else:
|
||||
start_y = y
|
||||
|
||||
# ------------------
|
||||
# 如果print-size大于image-size 则需要裁剪print
|
||||
|
||||
if x + print_x > image_x:
|
||||
rotate_image = rotate_image[:, :image_x - x]
|
||||
rotate_mask = rotate_mask[:, :image_x - x]
|
||||
|
||||
if y + print_y > image_y:
|
||||
rotate_image = rotate_image[:image_y - y, :]
|
||||
rotate_mask = rotate_mask[:image_y - y, :]
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
|
||||
|
||||
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
|
||||
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||
|
||||
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
|
||||
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
|
||||
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||
# TODO element 丢失信息
|
||||
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
||||
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
|
||||
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||
canvas = np.full_like(result['final_image'], 255)
|
||||
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
|
||||
temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
|
||||
temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||
img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
|
||||
ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
|
||||
mask_inv = cv2.bitwise_not(mask_)
|
||||
img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
|
||||
img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
|
||||
print_background = img1_bg + img2_fg
|
||||
return print_background
|
||||
|
||||
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
|
||||
if print_trigger:
|
||||
print_ = self.get_print(print_dict)
|
||||
painting_dict['Trigger'] = not is_single
|
||||
painting_dict['location'] = print_['location']
|
||||
single_mask_inv_print = self.get_mask_inv(print_['image'])
|
||||
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
||||
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
||||
if not is_single:
|
||||
self.random_seed = random.randint(0, 1000)
|
||||
# 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪
|
||||
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||
else:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||
else:
|
||||
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
|
||||
return painting_dict
|
||||
|
||||
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
|
||||
tile = None
|
||||
if not trigger:
|
||||
tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||
else:
|
||||
resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||
if len(pattern.shape) == 2:
|
||||
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
|
||||
if len(pattern.shape) == 3:
|
||||
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
|
||||
tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
|
||||
return tile
|
||||
|
||||
def get_mask_inv(self, print_):
|
||||
if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
|
||||
bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
|
||||
print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
||||
bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
|
||||
bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
|
||||
bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
|
||||
bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
|
||||
lower = np.array([bg_L_low, bg_a_low, bg_b_low])
|
||||
upper = np.array([bg_L_high, bg_a_high, bg_b_high])
|
||||
mask_inv = cv2.inRange(print_tile, lower, upper)
|
||||
return mask_inv
|
||||
else:
|
||||
# bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
|
||||
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
||||
# bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
|
||||
# bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
|
||||
# bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
|
||||
# bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
|
||||
# lower = np.array([bg_L_low, bg_a_low, bg_b_low])
|
||||
# upper = np.array([bg_L_high, bg_a_high, bg_b_high])
|
||||
|
||||
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
||||
# mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
|
||||
|
||||
# mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
|
||||
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
|
||||
return mask_inv
|
||||
|
||||
@staticmethod
|
||||
def printpaint(result, painting_dict, print_=False):
|
||||
|
||||
if print_ and painting_dict['Trigger']:
|
||||
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
|
||||
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
|
||||
else:
|
||||
print_mask = result['mask']
|
||||
img_fg = result['final_image']
|
||||
if print_ and not painting_dict['Trigger']:
|
||||
index_ = None
|
||||
try:
|
||||
index_ = len(painting_dict['location'])
|
||||
except:
|
||||
assert f'there must be parameter of location if choose IfSingle'
|
||||
|
||||
for i in range(index_):
|
||||
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
|
||||
|
||||
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
|
||||
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
|
||||
|
||||
change_region = img_fg[start_h: length_h, start_w: length_w, :]
|
||||
# problem in change_mask
|
||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||
# get real part into change mask
|
||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||
|
||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||
|
||||
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
|
||||
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||
print_image = cv2.add(img_bg, img_fg)
|
||||
return print_image
|
||||
|
||||
@staticmethod
|
||||
def get_print(print_dict):
|
||||
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0] < 0.3:
|
||||
print_dict['scale'] = 0.3
|
||||
else:
|
||||
print_dict['scale'] = print_dict['print_scale_list'][0]
|
||||
|
||||
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
|
||||
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
|
||||
image = oss_get_image(bucket=bucket_name, object_name=object_name, data_type="PIL")
|
||||
# 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑
|
||||
if image.mode == "RGBA":
|
||||
new_background = Image.new('RGB', image.size, (255, 255, 255))
|
||||
new_background.paste(image, mask=image.split()[3])
|
||||
image = new_background
|
||||
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
||||
return print_dict
|
||||
|
||||
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
|
||||
print_w = print_shape[1]
|
||||
print_h = print_shape[0]
|
||||
|
||||
random.seed(self.random_seed)
|
||||
# logging.info(f'overall print location : {location}')
|
||||
# x_offset = random.randint(0, image.shape[0] - image_size_h)
|
||||
# y_offset = random.randint(0, image.shape[1] - image_size_w)
|
||||
|
||||
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||
x_offset = print_w - int(location[0][1] % print_w)
|
||||
y_offset = print_w - int(location[0][0] % print_h)
|
||||
|
||||
# y_offset = int(location[0][0])
|
||||
# x_offset = int(location[0][1])
|
||||
|
||||
if len(image.shape) == 2:
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
|
||||
elif len(image.shape) == 3:
|
||||
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||
return image
|
||||
|
||||
@staticmethod
|
||||
def get_low_high_lab(Lab_value, L=False):
|
||||
if L:
|
||||
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||
else:
|
||||
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||
return high, low
|
||||
|
||||
@staticmethod
|
||||
def img_rotate(image, angel, scale):
|
||||
"""顺时针旋转图像任意角度
|
||||
|
||||
Args:
|
||||
image (np.array): [原始图像]
|
||||
angel (float): [逆时针旋转的角度]
|
||||
|
||||
Returns:
|
||||
[array]: [旋转后的图像]
|
||||
"""
|
||||
|
||||
h, w = image.shape[:2]
|
||||
center = (w // 2, h // 2)
|
||||
# if type(angel) is not int:
|
||||
# angel = 0
|
||||
M = cv2.getRotationMatrix2D(center, -angel, scale)
|
||||
# 调整旋转后的图像长宽
|
||||
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
|
||||
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
|
||||
M[0, 2] += (rotated_w - w) // 2
|
||||
M[1, 2] += (rotated_h - h) // 2
|
||||
# 旋转图像
|
||||
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
|
||||
|
||||
return rotated_img, ((rotated_img.shape[1] - image.shape[1] * scale) // 2, (rotated_img.shape[0] - image.shape[0] * scale) // 2)
|
||||
# return rotated_img, (0, 0)
|
||||
|
||||
@staticmethod
|
||||
def rotate_crop_image(img, angle, crop):
|
||||
"""
|
||||
angle: 旋转的角度
|
||||
crop: 是否需要进行裁剪,布尔向量
|
||||
"""
|
||||
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
|
||||
w, h = img.shape[:2]
|
||||
# 旋转角度的周期是360°
|
||||
angle %= 360
|
||||
# 计算仿射变换矩阵
|
||||
M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
|
||||
# 得到旋转后的图像
|
||||
img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
|
||||
|
||||
# 如果需要去除黑边
|
||||
if crop:
|
||||
# 裁剪角度的等效周期是180°
|
||||
angle_crop = angle % 180
|
||||
if angle > 90:
|
||||
angle_crop = 180 - angle_crop
|
||||
# 转化角度为弧度
|
||||
theta = angle_crop * np.pi / 180
|
||||
# 计算高宽比
|
||||
hw_ratio = float(h) / float(w)
|
||||
# 计算裁剪边长系数的分子项
|
||||
tan_theta = np.tan(theta)
|
||||
numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
|
||||
|
||||
# 计算分母中和高宽比相关的项
|
||||
r = hw_ratio if h > w else 1 / hw_ratio
|
||||
# 计算分母项
|
||||
denominator = r * tan_theta + 1
|
||||
# 最终的边长系数
|
||||
crop_mult = numerator / denominator
|
||||
|
||||
# 得到裁剪区域
|
||||
w_crop = int(crop_mult * w)
|
||||
h_crop = int(crop_mult * h)
|
||||
x0 = int((w - w_crop) / 2)
|
||||
y0 = int((h - h_crop) / 2)
|
||||
|
||||
img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
|
||||
|
||||
return img_rotated
|
||||
|
||||
@staticmethod
|
||||
def read_image(image_url):
|
||||
image = oss_get_image(bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
|
||||
if image.shape[2] == 4:
|
||||
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||
image = Image.fromarray(image_rgb)
|
||||
image_mode = "RGBA"
|
||||
else:
|
||||
image_mode = "RGB"
|
||||
return image, image_mode
|
||||
|
||||
@staticmethod
|
||||
def resize_and_crop(img, target_width, target_height):
|
||||
# 获取原始图像的尺寸
|
||||
original_height, original_width = img.shape[:2]
|
||||
|
||||
# 计算目标尺寸的宽高比
|
||||
target_ratio = target_width / target_height
|
||||
|
||||
# 计算原始图像的宽高比
|
||||
original_ratio = original_width / original_height
|
||||
|
||||
# 调整尺寸
|
||||
if original_ratio > target_ratio:
|
||||
# 原始图像更宽,按高度resize,然后裁剪宽度
|
||||
new_height = target_height
|
||||
new_width = int(original_width * (target_height / original_height))
|
||||
resized_img = cv2.resize(img, (new_width, new_height))
|
||||
# 裁剪宽度
|
||||
start_x = (new_width - target_width) // 2
|
||||
cropped_img = resized_img[:, start_x:start_x + target_width]
|
||||
else:
|
||||
# 原始图像更高,按宽度resize,然后裁剪高度
|
||||
new_width = target_width
|
||||
new_height = int(original_height * (target_width / original_width))
|
||||
resized_img = cv2.resize(img, (new_width, new_height))
|
||||
# 裁剪高度
|
||||
start_y = (new_height - target_height) // 2
|
||||
cropped_img = resized_img[start_y:start_y + target_height, :]
|
||||
|
||||
return cropped_img
|
||||
@@ -1,57 +0,0 @@
|
||||
import math
|
||||
|
||||
import cv2
|
||||
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from ..builder import PIPELINES
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Scaling(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
# @ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
if result['keypoint'] in ['waistband', 'shoulder', 'head_point']:
|
||||
# milvus_db_keypoint_cache
|
||||
distance_clo = math.sqrt(
|
||||
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][0])) ** 2
|
||||
+
|
||||
(int(result['clothes_keypoint'][result['keypoint'] + '_left'][1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'][1])) ** 2)
|
||||
|
||||
distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
|
||||
# distance_clo = math.sqrt(
|
||||
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[0]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[0])) ** 2
|
||||
# +
|
||||
# (int(result['clothes_keypoint'][result['keypoint'] + '_left'].split("_")[1]) - int(result['clothes_keypoint'][result['keypoint'] + '_right'].split("_")[1])) ** 2)
|
||||
#
|
||||
# distance_bdy = math.sqrt((int(result['body_point_test'][result['keypoint'] + '_left'][0]) - int(result['body_point_test'][result['keypoint'] + '_right'][0])) ** 2 + 1)
|
||||
if distance_clo == 0:
|
||||
result['scale'] = 1
|
||||
else:
|
||||
result['scale'] = distance_bdy / distance_clo
|
||||
elif result['keypoint'] == 'toe':
|
||||
distance_bdy = math.sqrt(
|
||||
(int(result['body_point_test']['foot_length'][0]) - int(result['body_point_test']['foot_length'][2])) ** 2
|
||||
+
|
||||
(int(result['body_point_test']['foot_length'][1]) - int(result['body_point_test']['foot_length'][3])) ** 2
|
||||
)
|
||||
|
||||
Blur = cv2.GaussianBlur(result['gray'], (3, 3), 0)
|
||||
Edge = cv2.Canny(Blur, 10, 200)
|
||||
Edge = cv2.dilate(Edge, None)
|
||||
Edge = cv2.erode(Edge, None)
|
||||
Contour, _ = cv2.findContours(Edge, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
Contours = sorted(Contour, key=cv2.contourArea, reverse=True)
|
||||
|
||||
Max_contour = Contours[0]
|
||||
x, y, w, h = cv2.boundingRect(Max_contour)
|
||||
width = w
|
||||
distance_clo = width
|
||||
result['scale'] = distance_bdy / distance_clo
|
||||
elif result['keypoint'] == 'hand_point':
|
||||
result['scale'] = result['scale_bag']
|
||||
elif result['keypoint'] == 'ear_point':
|
||||
result['scale'] = result['scale_earrings']
|
||||
return result
|
||||
@@ -1,71 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.oss_client import oss_get_image
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.design_ensemble import get_seg_result
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Segmentation(object):
|
||||
|
||||
@ClassCallRunTime
|
||||
def __call__(self, result):
|
||||
if "seg_mask_url" in result.keys() and result['seg_mask_url'] != "":
|
||||
seg_mask = oss_get_image(bucket=result['seg_mask_url'].split('/')[0], object_name=result['seg_mask_url'][result['seg_mask_url'].find('/') + 1:], data_type="cv2")
|
||||
seg_mask = cv2.resize(seg_mask, (result['img_shape'][1], result['img_shape'][0]), interpolation=cv2.INTER_NEAREST)
|
||||
# 转换颜色空间为 RGB(OpenCV 默认是 BGR)
|
||||
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
|
||||
|
||||
r, g, b = cv2.split(image_rgb)
|
||||
red_mask = r > g
|
||||
green_mask = g > r
|
||||
|
||||
# 创建红色和绿色掩码
|
||||
result['front_mask'] = np.array(red_mask, dtype=np.uint8) * 255
|
||||
result['back_mask'] = np.array(green_mask, dtype=np.uint8) * 255
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
else:
|
||||
# 本地查询seg 缓存是否存在
|
||||
_, seg_result = self.load_seg_result(result["image_id"])
|
||||
result['seg_result'] = seg_result
|
||||
if not _:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])[0]
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# 处理前片后片
|
||||
temp_front = seg_result == 1.0
|
||||
result['front_mask'] = (255 * (temp_front + 0).astype(np.uint8))
|
||||
temp_back = seg_result == 2.0
|
||||
result['back_mask'] = (255 * (temp_back + 0).astype(np.uint8))
|
||||
result['mask'] = result['front_mask'] + result['back_mask']
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
except Exception as e:
|
||||
logger.error(f"保存失败: {e}")
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
return True, seg_result
|
||||
except FileNotFoundError:
|
||||
# logger.warning("文件不存在")
|
||||
return False, None
|
||||
except Exception as e:
|
||||
logger.error(f"加载失败: {e}")
|
||||
return False, None
|
||||
@@ -1,79 +0,0 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
from ..builder import PIPELINES
|
||||
from ...utils.conversion_image import rgb_to_rgba
|
||||
from ...utils.upload_image import upload_png_mask
|
||||
|
||||
|
||||
@PIPELINES.register_module()
|
||||
class Split(object):
|
||||
"""
|
||||
Split image into front and back layer according to the segmentation result
|
||||
"""
|
||||
|
||||
# @ClassCallRunTime
|
||||
# KNet
|
||||
def __call__(self, result):
|
||||
try:
|
||||
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms'):
|
||||
front_mask = result['front_mask']
|
||||
back_mask = result['back_mask']
|
||||
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
||||
new_size = (int(rgba_image.shape[1] * result["scale"] * result["resize_scale"][0]), int(rgba_image.shape[0] * result["scale"] * result["resize_scale"][1]))
|
||||
rgba_image = cv2.resize(rgba_image, new_size)
|
||||
result_front_image = np.zeros_like(rgba_image)
|
||||
front_mask = cv2.resize(front_mask, new_size)
|
||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||
result['front_image'], result["front_image_url"], _ = upload_png_mask(result_front_image_pil, f'{generate_uuid()}', mask=None)
|
||||
|
||||
height, width = front_mask.shape
|
||||
mask_image = np.zeros((height, width, 3))
|
||||
mask_image[front_mask != 0] = [0, 0, 255]
|
||||
|
||||
if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
else:
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
result['back_image'] = None
|
||||
result["back_image_url"] = None
|
||||
# result["back_mask_url"] = None
|
||||
# result['back_mask_image'] = None
|
||||
# 创建中间图层
|
||||
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
|
||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
||||
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(result_pattern_image_pil, f'{generate_uuid()}')
|
||||
return result
|
||||
except Exception as e:
|
||||
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
||||
@@ -1,121 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
from ..utils.conversion_image import rgb_to_rgba
|
||||
from ..utils.upload_image import upload_png_mask
|
||||
from ...utils.generate_uuid import generate_uuid
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Shoes(Clothing):
|
||||
# TODO location of shoes has little mismatch
|
||||
def __init__(self, **kwargs):
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color']),
|
||||
dict(type='KeypointDetection'),
|
||||
dict(type='ContourDetection'),
|
||||
dict(type='Painting'),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'pattern_image']),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Shoes, self).__init__(**kwargs)
|
||||
|
||||
def organize(self, layer):
|
||||
left_shoe_mask, right_shoe_mask = self.cut()
|
||||
|
||||
left_layer = dict(name=f'{type(self).__name__.lower()}_left',
|
||||
image=self.result['shoes_left'],
|
||||
image_url=self.result['left_image_url'],
|
||||
mask_url=self.result['left_mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=self.calculate_start_point(self.result['keypoint'],
|
||||
self.result['scale'],
|
||||
self.result['clothes_keypoint'],
|
||||
self.result['body_point'],
|
||||
'left'))
|
||||
layer.insert(left_layer)
|
||||
|
||||
right_layer = dict(name=f'{type(self).__name__.lower()}_right',
|
||||
image=self.result['shoes_right'],
|
||||
image_url=self.result['right_image_url'],
|
||||
mask_url=self.result['right_mask_url'],
|
||||
sacle=self.result['scale'],
|
||||
clothes_keypoint=self.result['clothes_keypoint'],
|
||||
position=self.calculate_start_point(self.result['keypoint'],
|
||||
self.result['scale'],
|
||||
self.result['clothes_keypoint'],
|
||||
self.result['body_point'],
|
||||
'right'))
|
||||
|
||||
layer.insert(right_layer)
|
||||
|
||||
def cut(self):
|
||||
"""
|
||||
Cut shoes mask into two pieces
|
||||
Returns:
|
||||
"""
|
||||
contour, _ = cv2.findContours(self.result['mask'], cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
contours = sorted(contour, key=cv2.contourArea, reverse=True)
|
||||
|
||||
bounding_boxes = [cv2.boundingRect(c) for c in contours[:2]]
|
||||
(contours, bounding_boxes) = zip(*sorted(zip(contours[:2], bounding_boxes), key=lambda x: x[1][0], reverse=False))
|
||||
|
||||
epsilon_left = 0.001 * cv2.arcLength(contours[0], True)
|
||||
|
||||
approx_left = cv2.approxPolyDP(contours[0], epsilon_left, True)
|
||||
mask_left = np.zeros(self.result['final_image'].shape[:2], np.uint8)
|
||||
cv2.drawContours(mask_left, [approx_left], -1, 255, -1)
|
||||
item_mask_left = cv2.GaussianBlur(mask_left, (5, 5), 0)
|
||||
|
||||
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_left)
|
||||
result_image = np.zeros_like(rgba_image)
|
||||
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
|
||||
result_left_image_pil = Image.fromarray(result_image, 'RGBA')
|
||||
result_left_image_pil = result_left_image_pil.resize((int(result_left_image_pil.width * self.result["scale"]), int(result_left_image_pil.height * self.result["scale"])), Image.LANCZOS)
|
||||
self.result['shoes_left'], self.result["left_image_url"], self.result["left_mask_url"] = upload_png_mask(result_left_image_pil, f"{generate_uuid()}")
|
||||
|
||||
epsilon_right = 0.001 * cv2.arcLength(contours[1], True)
|
||||
approx_right = cv2.approxPolyDP(contours[1], epsilon_right, True)
|
||||
mask_right = np.zeros(self.result['final_image'].shape[:2], np.uint8)
|
||||
cv2.drawContours(mask_right, [approx_right], -1, 255, -1)
|
||||
item_mask_right = cv2.GaussianBlur(mask_right, (5, 5), 0)
|
||||
|
||||
rgba_image = rgb_to_rgba((self.result['final_image'].shape[0], self.result['final_image'].shape[1]), self.result['final_image'], item_mask_right)
|
||||
result_image = np.zeros_like(rgba_image)
|
||||
result_image[self.result['front_mask'] != 0] = rgba_image[self.result['front_mask'] != 0]
|
||||
result_right_image_pil = Image.fromarray(result_image, 'RGBA')
|
||||
result_right_image_pil = result_right_image_pil.resize((int(result_right_image_pil.width * self.result["scale"]), int(result_right_image_pil.height * self.result["scale"])), Image.LANCZOS)
|
||||
self.result['shoes_right'], self.result["right_image_url"], self.result["right_mask_url"] = upload_png_mask(result_right_image_pil, f"{generate_uuid()}")
|
||||
|
||||
return item_mask_left, item_mask_right
|
||||
|
||||
@staticmethod
|
||||
def calculate_start_point(keypoint_type, scale, clothes_point, body_point, location):
|
||||
"""
|
||||
left shoes align left
|
||||
right shoes align right
|
||||
Args:
|
||||
keypoint_type: string, "toe"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
body_point: dict, containing keypoint data of body figure
|
||||
location: string, indicates whether the start point belongs to right or left shoe
|
||||
|
||||
Returns:
|
||||
start_point: tuple (x', y')
|
||||
x' = y_body - y1 * scale
|
||||
y' = x_body - x1 * scale
|
||||
"""
|
||||
if location not in ['left', 'right']:
|
||||
raise KeyError(f'location value must be left or right but got {location}')
|
||||
side_indicator = f'{keypoint_type}_{location}'
|
||||
# clothes_point = {k: tuple(map(lambda x: int(scale * x), v[0: 2])) for k, v in clothes_point.items()}
|
||||
start_point = (body_point[side_indicator][1] - int(int(clothes_point[side_indicator].split("_")[1]) * scale),
|
||||
body_point[side_indicator][0] - int(int(clothes_point[side_indicator].split("_")[0]) * scale))
|
||||
return start_point
|
||||
@@ -1,46 +0,0 @@
|
||||
from .builder import ITEMS
|
||||
from .clothing import Clothing
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Top(Clothing):
|
||||
def __init__(self, pipeline, **kwargs):
|
||||
if pipeline is None:
|
||||
pipeline = [
|
||||
dict(type='LoadImageFromFile', path=kwargs['path'], color=kwargs['color'], print_dict=kwargs['print']),
|
||||
dict(type='KeypointDetection'),
|
||||
# dict(type='ContourDetection'),
|
||||
dict(type='Segmentation'),
|
||||
dict(type='Painting', painting_flag=True),
|
||||
dict(type='PrintPainting', print_flag=True),
|
||||
# dict(type='ImageShow', key=['image', 'mask', 'seg_visualize', 'pattern_image']),
|
||||
dict(type='Scaling'),
|
||||
dict(type='Split'),
|
||||
]
|
||||
kwargs.update(pipeline=pipeline)
|
||||
super(Top, self).__init__(**kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Blouse(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Blouse, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Outwear(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Outwear, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
@ITEMS.register_module()
|
||||
class Dress(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Dress, self).__init__(pipeline, **kwargs)
|
||||
|
||||
|
||||
# Men's clothing
|
||||
@ITEMS.register_module()
|
||||
class Tops(Top):
|
||||
def __init__(self, pipeline=None, **kwargs):
|
||||
super(Tops, self).__init__(pipeline, **kwargs)
|
||||
@@ -1,197 +0,0 @@
|
||||
import concurrent.futures
|
||||
import io
|
||||
|
||||
import cv2
|
||||
|
||||
from app.core.config import PRIORITY_DICT
|
||||
from app.service.design.core.layer import Layer
|
||||
from app.service.design.items import build_item
|
||||
from app.service.design.utils.redis_utils import Redis
|
||||
from app.service.design.utils.synthesis_item import synthesis, synthesis_single
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def process_item(item, layers):
|
||||
# logging.info("process running.........")
|
||||
item.process()
|
||||
item.organize(layers)
|
||||
if item.result['name'] == "mannequin":
|
||||
return item.result['body_image'].size
|
||||
|
||||
|
||||
def update_progress(process_id, total):
|
||||
r = Redis()
|
||||
progress = r.read(key=process_id)
|
||||
if progress and total != 1:
|
||||
if int(progress) <= 100:
|
||||
r.write(key=process_id, value=int(progress) + int(100 / total))
|
||||
else:
|
||||
r.write(key=process_id, value=99)
|
||||
return progress
|
||||
elif total == 1:
|
||||
r.write(key=process_id, value=100)
|
||||
return progress
|
||||
else:
|
||||
r.write(key=process_id, value=int(100 / total))
|
||||
return progress
|
||||
|
||||
|
||||
def final_progress(process_id):
|
||||
r = Redis()
|
||||
progress = r.read(key=process_id)
|
||||
r.write(key=process_id, value=100)
|
||||
return progress
|
||||
|
||||
|
||||
@RunTime
|
||||
def generate(request_data):
|
||||
return_response = {}
|
||||
return_png_mask = []
|
||||
request_data = request_data.dict()
|
||||
assert "process_id" in request_data.keys(), "Need process_id parameters"
|
||||
|
||||
objects = request_data['objects']
|
||||
# insert_keypoint_cache(objects)
|
||||
process_id = request_data['process_id']
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# 提交每个对象的处理任务
|
||||
futures = {executor.submit(process_object, cfg, process_id, len(objects)): obj for obj, cfg in enumerate(objects)}
|
||||
# 获取处理结果
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
obj = futures[future]
|
||||
return_response[obj] = future.result()[0]
|
||||
return_png_mask.extend(future.result()[1])
|
||||
# upload_results = process_images(return_png_mask)
|
||||
final_progress(process_id)
|
||||
return return_response
|
||||
|
||||
|
||||
def process_object(cfg, process_id, total):
|
||||
uploaded_images = []
|
||||
basic_info = cfg.get('basic')
|
||||
items_response = {
|
||||
'layers': []
|
||||
}
|
||||
if cfg.get('basic')['single_overall'] == 'overall':
|
||||
basic_info['debug'] = False
|
||||
items = [build_item(x, default_args=basic_info) for x in cfg.get('items')]
|
||||
layers = Layer()
|
||||
body_size = None
|
||||
futures = []
|
||||
for item in items:
|
||||
futures = [process_item(item, layers)]
|
||||
for future in futures:
|
||||
if future is not None:
|
||||
body_size = future
|
||||
# 是否自定义排序
|
||||
if basic_info.get('layer_order', False):
|
||||
layers = sorted(layers.layer, key=lambda s: s.get("priority", float('inf')))
|
||||
else:
|
||||
layers = sorted(layers.layer, key=lambda x: PRIORITY_DICT.get(x['name'], float('inf')))
|
||||
# 上传所有图片
|
||||
# for layer in layers:
|
||||
# if 'image' in layer.keys() and layer['image'] is not None:
|
||||
# uploaded_images.append({'image_obj': layer['image'], 'image_url': layer['image_url'], 'image_type': 'image'})
|
||||
# if 'pattern_image' in layer.keys() and layer['pattern_image'] is not None:
|
||||
# uploaded_images.append({'image_obj': layer['pattern_image'], 'image_url': layer['pattern_image_url'], 'image_type': 'pattern_image'})
|
||||
# if 'mask' in layer.keys() and layer['mask'] is not None and layer['mask_url'] is not None:
|
||||
# uploaded_images.append({'image_obj': layer['mask'], 'image_url': layer['mask_url'], 'image_type': 'mask'})
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
# 合成
|
||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic_info)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
'image_category': lay['name'],
|
||||
'position': lay['position'],
|
||||
'priority': lay.get("priority", None),
|
||||
'resize_scale': lay['resize_scale'] if "resize_scale" in lay.keys() else None,
|
||||
'image_size': lay['image'] if lay['image'] is None else lay['image'].size,
|
||||
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
||||
'mask_url': lay['mask_url'],
|
||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
|
||||
|
||||
# 'image': lay['image'],
|
||||
# 'mask_image': lay['mask_image'],
|
||||
})
|
||||
elif cfg.get('basic')['single_overall'] == 'single':
|
||||
assert cfg.get('basic')['switch_category'] in [x['type'] for x in cfg.get('items')], "Lack of switch_category parameters "
|
||||
basic_info['debug'] = False
|
||||
for item in cfg.get('items'):
|
||||
if item['type'] == cfg.get('basic')['switch_category']:
|
||||
item = build_item(item, default_args=cfg.get('basic'))
|
||||
item.process()
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item.result['name']}_front",
|
||||
'image_size': item.result['back_image'].size if item.result['back_image'] else None,
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item.result['front_image_url'],
|
||||
'mask_url': item.result['mask_url'],
|
||||
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
|
||||
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item.result['name']}_back",
|
||||
'image_size': item.result['front_image'].size if item.result['front_image'] else None,
|
||||
'position': None,
|
||||
'priority': 0,
|
||||
'image_url': item.result['back_image_url'],
|
||||
'mask_url': item.result['mask_url'],
|
||||
"gradient_string": item.result['gradient_string'] if 'gradient_string' in item.result.keys() else "",
|
||||
'pattern_image_url': item.result['pattern_image_url'] if 'pattern_image_url' in item.result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis_single(item.result['front_image'], item.result['back_image'])
|
||||
break
|
||||
update_progress(process_id, total)
|
||||
return items_response, uploaded_images
|
||||
|
||||
|
||||
@RunTime
|
||||
def process_images(images):
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
results = list(executor.map(upload_images, images))
|
||||
# results = []
|
||||
# for image in images:
|
||||
# results.append(upload_images(image))
|
||||
return results
|
||||
|
||||
|
||||
# @RunTime
|
||||
def upload_images(image_obj):
|
||||
bucket_name = image_obj['image_url'].split("/", 1)[0]
|
||||
object_name = image_obj['image_url'].split("/", 1)[1]
|
||||
if image_obj['image_type'] == 'image' or image_obj['image_type'] == 'pattern_image':
|
||||
image_data = io.BytesIO()
|
||||
image_obj['image_obj'].save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return image_obj['image_url']
|
||||
else:
|
||||
mask_inverted = cv2.bitwise_not(image_obj['image_obj'])
|
||||
# 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明
|
||||
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
|
||||
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
return image_obj['image_url']
|
||||
|
||||
|
||||
def update_base_size_priority(layers, size):
|
||||
# 计算透明背景图片的宽度
|
||||
min_x = min(info['position'][1] for info in layers)
|
||||
x_list = []
|
||||
for info in layers:
|
||||
if info['image'] is not None:
|
||||
x_list.append(info['position'][1] + info['image'].width)
|
||||
max_x = max(x_list)
|
||||
new_width = max_x - min_x
|
||||
new_height = 700
|
||||
# 更新坐标
|
||||
for info in layers:
|
||||
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
|
||||
return layers, (new_width, new_height)
|
||||
@@ -1,31 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :conversion_image.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/21 10:40:29
|
||||
@detail :
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
|
||||
# def rgb_to_rgba(rgb_size, rgb_image, mask):
|
||||
# alpha_channel = np.full(rgb_size, 255, dtype=np.uint8)
|
||||
# # 创建四通道的结果图像
|
||||
# rgba_image = np.dstack((rgb_image, alpha_channel))
|
||||
# alpha_channel = np.where(mask > 0, 255, 0)
|
||||
# # 更新RGBA图像的透明度通道
|
||||
# rgba_image[:, :, 3] = alpha_channel
|
||||
# return rgba_image
|
||||
|
||||
def rgb_to_rgba(rgb_image, mask):
|
||||
# 创建全透明的alpha通道
|
||||
alpha_channel = np.where(mask > 0, 255, 0).astype(np.uint8)
|
||||
# 合并RGB图像和alpha通道
|
||||
rgba_image = np.dstack((rgb_image, alpha_channel))
|
||||
return rgba_image
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
image = open("")
|
||||
@@ -1,143 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :design_ensemble.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/16 19:36:21
|
||||
@detail :发起请求 获取推理结果
|
||||
"""
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import mmcv
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
|
||||
"""
|
||||
keypoint
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
def keypoint_preprocess(img_path):
|
||||
img = mmcv.imread(img_path)
|
||||
img_scale = (256, 256)
|
||||
h, w = img.shape[:2]
|
||||
img = cv2.resize(img, img_scale)
|
||||
w_scale = img_scale[0] / w
|
||||
h_scale = img_scale[1] / h
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img, (w_scale, h_scale)
|
||||
|
||||
|
||||
# @ RunTime
|
||||
# 推理
|
||||
def get_keypoint_result(image, site):
|
||||
keypoint_result = None
|
||||
try:
|
||||
image, scale_factor = keypoint_preprocess(image)
|
||||
client = httpclient.InferenceServerClient(url=DESIGN_MODEL_URL)
|
||||
transformed_img = image.astype(np.float32)
|
||||
inputs = [httpclient.InferInput(f"input", transformed_img.shape, datatype="FP32")]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
outputs = [httpclient.InferRequestedOutput(f"output", binary_data=True)]
|
||||
results = client.infer(model_name=f"keypoint_{site}_ocrnet_hr18", inputs=inputs, outputs=outputs)
|
||||
inference_output = torch.from_numpy(results.as_numpy(f'output'))
|
||||
keypoint_result = keypoint_postprocess(inference_output, scale_factor)
|
||||
except Exception as e:
|
||||
logging.warning(f"get_keypoint_result : {e}")
|
||||
return keypoint_result
|
||||
|
||||
|
||||
def keypoint_postprocess(output, scale_factor):
|
||||
max_indices = torch.argmax(output.view(output.size(0), output.size(1), -1), dim=2).unsqueeze(dim=2)
|
||||
max_coords = torch.cat((max_indices / output.size(3), max_indices % output.size(3)), dim=2)
|
||||
segment_result = max_coords.numpy()
|
||||
scale_factor = [1 / x for x in scale_factor[::-1]]
|
||||
scale_matrix = np.diag(scale_factor)
|
||||
nan = np.isinf(scale_matrix)
|
||||
scale_matrix[nan] = 0
|
||||
return np.ceil(np.dot(segment_result, scale_matrix) * 4)
|
||||
|
||||
|
||||
"""
|
||||
seg
|
||||
预处理 推理 后处理
|
||||
"""
|
||||
|
||||
|
||||
# KNet
|
||||
def seg_preprocess(img_path):
|
||||
img = mmcv.imread(img_path)
|
||||
ori_shape = img.shape[:2]
|
||||
img_scale_w, img_scale_h = ori_shape
|
||||
if ori_shape[0] > 1024:
|
||||
img_scale_w = 1024
|
||||
if ori_shape[1] > 1024:
|
||||
img_scale_h = 1024
|
||||
# 如果图片size任意一边 大于 1024, 则会resize 成1024
|
||||
if ori_shape != (img_scale_w, img_scale_h):
|
||||
# mmcv.imresize(img, img_scale_h, img_scale_w) # 老代码 引以为戒!哈哈哈~ h和w写反了
|
||||
img = cv2.resize(img, (img_scale_h, img_scale_w))
|
||||
img = mmcv.imnormalize(img, mean=np.array([123.675, 116.28, 103.53]), std=np.array([58.395, 57.12, 57.375]), to_rgb=True)
|
||||
preprocessed_img = np.expand_dims(img.transpose(2, 0, 1), axis=0)
|
||||
return preprocessed_img, ori_shape
|
||||
|
||||
|
||||
# @ RunTime
|
||||
def get_seg_result(image_id, image):
|
||||
image, ori_shape = seg_preprocess(image)
|
||||
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
|
||||
transformed_img = image.astype(np.float32)
|
||||
# 输入集
|
||||
inputs = [
|
||||
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
# 输出集
|
||||
outputs = [
|
||||
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
||||
]
|
||||
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
|
||||
# 推理
|
||||
# 取结果
|
||||
inference_output1 = results.as_numpy(SEGMENTATION['output'])
|
||||
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
|
||||
return seg_result
|
||||
|
||||
|
||||
# no cache
|
||||
def seg_postprocess(image_id, output, ori_shape):
|
||||
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||
seg_pred = seg_logit.cpu().numpy()
|
||||
return seg_pred[0]
|
||||
|
||||
|
||||
def key_point_show(image_path, key_point_result=None):
|
||||
img = cv2.imread(image_path)
|
||||
points_list = key_point_result
|
||||
point_size = 1
|
||||
point_color = (0, 0, 255) # BGR
|
||||
thickness = 4 # 可以为 0 、4、8
|
||||
for point in points_list:
|
||||
cv2.circle(img, point[::-1], point_size, point_color, thickness)
|
||||
cv2.imshow("0", img)
|
||||
cv2.waitKey(0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
image = cv2.imread("9070101c-e5be-49b5-9602-4113a968969b.png")
|
||||
a = get_keypoint_result(image, "up")
|
||||
new_list = []
|
||||
print(list)
|
||||
for i in a[0]:
|
||||
new_list.append((int(i[0]), int(i[1])))
|
||||
key_point_show("9070101c-e5be-49b5-9602-4113a968969b.png", new_list)
|
||||
# a = get_seg_result(1, image)
|
||||
print(a)
|
||||
@@ -1,99 +0,0 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
|
||||
"""
|
||||
redis数据库操作
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_r():
|
||||
host = REDIS_HOST
|
||||
port = REDIS_PORT
|
||||
db = 0
|
||||
r = redis.StrictRedis(host, port, db)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def write(cls, key, value, expire=None):
|
||||
"""
|
||||
写入键值对
|
||||
"""
|
||||
# 判断是否有过期时间,没有就设置默认值
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.set(key, value, ex=expire_in_seconds)
|
||||
|
||||
@classmethod
|
||||
def read(cls, key):
|
||||
"""
|
||||
读取键值对内容
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.get(key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hset(cls, name, key, value):
|
||||
"""
|
||||
写入hash表
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hset(name, key, value)
|
||||
|
||||
@classmethod
|
||||
def hget(cls, name, key):
|
||||
"""
|
||||
读取指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.hget(name, key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hgetall(cls, name):
|
||||
"""
|
||||
获取指定hash表所有的值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
return r.hgetall(name)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, *names):
|
||||
"""
|
||||
删除一个或者多个
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.delete(*names)
|
||||
|
||||
@classmethod
|
||||
def hdel(cls, name, key):
|
||||
"""
|
||||
删除指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hdel(name, key)
|
||||
|
||||
@classmethod
|
||||
def expire(cls, name, expire=None):
|
||||
"""
|
||||
设置过期时间
|
||||
"""
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.expire(name, expire_in_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
redis_client = Redis()
|
||||
# print(redis_client.write(key="1230", value=0))
|
||||
redis_client.write(key="1230", value=10)
|
||||
# print(redis_client.read(key="1230"))
|
||||
@@ -1,181 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: UTF-8 -*-
|
||||
"""
|
||||
@Project :trinity_client
|
||||
@File :synthesis_item.py
|
||||
@Author :周成融
|
||||
@Date :2023/8/26 14:13:04
|
||||
@detail :
|
||||
"""
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
|
||||
|
||||
def positioning(all_mask_shape, mask_shape, offset):
|
||||
all_start = 0
|
||||
all_end = 0
|
||||
mask_start = 0
|
||||
mask_end = 0
|
||||
if offset == 0:
|
||||
all_start = 0
|
||||
all_end = min(all_mask_shape, mask_shape)
|
||||
|
||||
mask_start = 0
|
||||
mask_end = min(all_mask_shape, mask_shape)
|
||||
elif offset > 0:
|
||||
all_start = min(offset, all_mask_shape)
|
||||
all_end = min(offset + mask_shape, all_mask_shape)
|
||||
|
||||
mask_start = 0
|
||||
mask_end = 0 if offset > all_mask_shape else min(all_mask_shape - offset, mask_shape)
|
||||
elif offset < 0:
|
||||
if abs(offset) > mask_shape:
|
||||
all_start = 0
|
||||
all_end = 0
|
||||
else:
|
||||
all_start = 0
|
||||
if mask_shape - abs(offset) > all_mask_shape:
|
||||
all_end = min(mask_shape - abs(offset), all_mask_shape)
|
||||
else:
|
||||
all_end = mask_shape - abs(offset)
|
||||
|
||||
if abs(offset) > mask_shape:
|
||||
mask_start = mask_shape
|
||||
mask_end = mask_shape
|
||||
else:
|
||||
mask_start = abs(offset)
|
||||
if mask_shape - abs(offset) >= all_mask_shape:
|
||||
mask_end = all_mask_shape + abs(offset)
|
||||
else:
|
||||
mask_end = mask_shape
|
||||
return all_start, all_end, mask_start, mask_end
|
||||
|
||||
|
||||
# @RunTime
|
||||
def synthesis(data, size, basic_info):
|
||||
# 创建底图
|
||||
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
try:
|
||||
all_mask_shape = (size[1], size[0])
|
||||
body_mask = None
|
||||
for d in data:
|
||||
if d['name'] == 'body':
|
||||
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
|
||||
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
|
||||
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值,所以使用下标获取position
|
||||
body_mask = np.array(transparent_image.split()[3])
|
||||
|
||||
# 根据新的坐标获取新的肩点
|
||||
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
|
||||
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
|
||||
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
|
||||
top_outer_mask = np.array(binary_body_mask)
|
||||
bottom_outer_mask = np.array(binary_body_mask)
|
||||
|
||||
top = True
|
||||
bottom = True
|
||||
i = len(data)
|
||||
while i:
|
||||
i -= 1
|
||||
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
|
||||
top = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['adaptive_position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||
background = np.zeros_like(top_outer_mask)
|
||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
top_outer_mask = background + top_outer_mask
|
||||
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
|
||||
bottom = False
|
||||
mask_shape = data[i]['mask'].shape
|
||||
y_offset, x_offset = data[i]['adaptive_position']
|
||||
# 初始化叠加区域的起始和结束位置
|
||||
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
|
||||
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
|
||||
# 将叠加区域赋值为相应的像素值
|
||||
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||
background = np.zeros_like(top_outer_mask)
|
||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||
bottom_outer_mask = background + bottom_outer_mask
|
||||
elif bottom is False and top is False:
|
||||
break
|
||||
|
||||
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
|
||||
|
||||
for layer in data:
|
||||
if layer['image'] is not None:
|
||||
if layer['name'] != "body":
|
||||
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
|
||||
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
|
||||
mask_alpha = Image.fromarray(mask_data)
|
||||
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
|
||||
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
|
||||
else:
|
||||
base_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image'])
|
||||
|
||||
result_image = base_image
|
||||
|
||||
image_data = io.BytesIO()
|
||||
result_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
|
||||
# oss upload
|
||||
image_bytes = image_data.read()
|
||||
bucket_name = "aida-results"
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
|
||||
# object_name = f'result_{generate_uuid()}.png'
|
||||
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
|
||||
# object_url = f"aida-results/{object_name}"
|
||||
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
# return object_url
|
||||
# else:
|
||||
# return ""
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f"synthesis runtime exception : {e}")
|
||||
|
||||
|
||||
def synthesis_single(front_image, back_image):
|
||||
result_image = None
|
||||
if front_image:
|
||||
result_image = front_image
|
||||
if back_image:
|
||||
result_image.paste(back_image, (0, 0), back_image)
|
||||
|
||||
# with io.BytesIO() as output:
|
||||
# result_image.save(output, format='PNG')
|
||||
# data = output.getvalue()
|
||||
# object_name = f'result_{generate_uuid()}.png'
|
||||
# response = s3.put_object(Bucket="aida-results", Key=object_name, Body=data, ContentType='image/png')
|
||||
# object_url = f"aida-results/{object_name}"
|
||||
# if response['ResponseMetadata']['HTTPStatusCode'] == 200:
|
||||
# return object_url
|
||||
# else:
|
||||
# return ""
|
||||
image_data = io.BytesIO()
|
||||
result_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
# oss upload
|
||||
bucket_name = 'aida-results'
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
@@ -4,20 +4,20 @@ import threading
|
||||
from celery import Celery
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, AccessoriesItem
|
||||
from app.core.config import settings
|
||||
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, OthersItem
|
||||
from app.service.design_batch.utils.MQ import publish_status
|
||||
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_accessories
|
||||
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_others
|
||||
from app.service.design_batch.utils.save_json import oss_upload_json
|
||||
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
||||
|
||||
id_lock = threading.Lock()
|
||||
celery_app = Celery('tasks', broker=f'amqp://rabbit:123456@18.167.251.121:5672//', backend='rpc://', BROKER_CONNECTION_RETRY_ON_STARTUP=True)
|
||||
celery_app = Celery('tasks', broker=f'amqp://{settings.MQ_USERNAME}:{settings.MQ_PASSWORD}@{settings.MQ_HOST}:{settings.MQ_PORT}//', backend='rpc://')
|
||||
celery_app.conf.worker_log_format = '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'
|
||||
celery_app.conf.worker_hijack_root_logger = False
|
||||
logging.getLogger('pika').setLevel(logging.WARNING)
|
||||
logger = logging.getLogger()
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
print("start")
|
||||
|
||||
@@ -33,8 +33,8 @@ def process_item(item, basic):
|
||||
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
||||
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = bottom_server.process()
|
||||
elif item['type'].lower() in ['accessories']:
|
||||
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
|
||||
elif item['type'].lower() in ['others']:
|
||||
bottom_server = OthersItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = bottom_server.process()
|
||||
else:
|
||||
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
||||
@@ -47,14 +47,16 @@ def process_layer(item, layers):
|
||||
body_layer = organize_body(item)
|
||||
layers.append(body_layer)
|
||||
return item['body_image'].size
|
||||
elif item['name'] == 'accessories':
|
||||
front_layer, back_layer = organize_accessories(item)
|
||||
elif item['name'] == 'others':
|
||||
front_layer, back_layer = organize_others(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
else:
|
||||
front_layer, back_layer = organize_clothing(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
|
||||
|
||||
@celery_app.task
|
||||
@@ -76,12 +78,11 @@ def batch_design(objects_data, tasks_id, json_name):
|
||||
for item in object['items']:
|
||||
item_results.append(process_item(item, basic))
|
||||
layers = []
|
||||
body_size = None
|
||||
for item in item_results:
|
||||
body_size = process_layer(item, layers)
|
||||
process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
layers, new_size = update_base_size_priority(layers)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
|
||||
@@ -9,10 +9,10 @@ class BaseItem:
|
||||
self.result.update(basic)
|
||||
|
||||
|
||||
class AccessoriesItem(BaseItem):
|
||||
class OthersItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.Accessories_pipeline = [
|
||||
self.Others_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
# KeyPoint(),
|
||||
ContourDetection(),
|
||||
@@ -25,7 +25,7 @@ class AccessoriesItem(BaseItem):
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.Accessories_pipeline:
|
||||
for item in self.Others_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
|
||||
@@ -18,11 +18,11 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = file_path
|
||||
return result
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
seg_result = result['seg_result']
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
|
||||
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
|
||||
back_sketch = result['image'].copy()
|
||||
@@ -34,7 +34,8 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
|
||||
return result
|
||||
|
||||
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
|
||||
@staticmethod
|
||||
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
|
||||
mask = mask.astype(np.uint8) * 255
|
||||
# 查找轮廓
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
@@ -48,9 +49,9 @@ class BackPerspective:
|
||||
# 在空白图像上绘制白色的轮廓
|
||||
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
|
||||
# 找到轮廓的中心(可以用重心等方法近似)
|
||||
M = cv2.moments(contour)
|
||||
cx = int(M['m10'] / M['m00'])
|
||||
cy = int(M['m01'] / M['m00'])
|
||||
m = cv2.moments(contour)
|
||||
cx = int(m['m10'] / m['m00'])
|
||||
cy = int(m['m01'] / m['m00'])
|
||||
# 进行距离变换,离中心越近的值越小
|
||||
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
|
||||
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留
|
||||
|
||||
@@ -79,9 +79,9 @@ class Color:
|
||||
def get_pattern(single_color):
|
||||
if single_color is None:
|
||||
raise False
|
||||
R, G, B = single_color.split(' ')
|
||||
r, g, b = single_color.split(' ')
|
||||
pattern = np.zeros([1, 1, 3], np.uint8)
|
||||
pattern[0, 0, 0] = int(B)
|
||||
pattern[0, 0, 1] = int(G)
|
||||
pattern[0, 0, 2] = int(R)
|
||||
pattern[0, 0, 0] = int(b)
|
||||
pattern[0, 0, 1] = int(g)
|
||||
pattern[0, 0, 2] = int(r)
|
||||
return pattern
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import numpy as np
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
|
||||
from app.service.design_fast.utils.design_ensemble import get_keypoint_result
|
||||
from app.service.utils.decorator import ClassCallRunTime, RunTime
|
||||
|
||||
@@ -21,12 +21,12 @@ class KeyPoint:
|
||||
def __call__(self, result):
|
||||
if result['name'] in ['blouse', 'skirt', 'dress', 'outwear', 'trousers', 'tops', 'bottoms']: # 查询是否有数据 且类别相同 相同则直接读 不同则推理后更新
|
||||
# result['clothes_keypoint'] = self.infer_keypoint_result(result)
|
||||
site = 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# 'up' if result['name'] in ['blouse', 'outwear', 'dress', 'tops'] else 'down'
|
||||
# keypoint_cache = search_keypoint_cache(result["image_id"], site)
|
||||
# keypoint_cache = self.keypoint_cache(result, site)
|
||||
keypoint_cache = False
|
||||
# 取消向量查询 直接过模型推理
|
||||
if keypoint_cache is False:
|
||||
if not keypoint_cache:
|
||||
keypoint_infer_result, site = self.infer_keypoint_result(result)
|
||||
result['clothes_keypoint'] = self.save_keypoint_cache(result["image_id"], keypoint_infer_result, site)
|
||||
else:
|
||||
@@ -55,8 +55,8 @@ class KeyPoint:
|
||||
}
|
||||
]
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
res = client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
|
||||
client.close()
|
||||
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
|
||||
except Exception as e:
|
||||
@@ -79,7 +79,7 @@ class KeyPoint:
|
||||
]
|
||||
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
client.upsert(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
data=data
|
||||
@@ -92,7 +92,7 @@ class KeyPoint:
|
||||
@RunTime
|
||||
def keypoint_cache(self, result, site):
|
||||
try:
|
||||
client = MilvusClient(uri=MILVUS_URL, token=MILVUS_TOKEN, db_name=MILVUS_ALIAS)
|
||||
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
|
||||
keypoint_id = result['image_id']
|
||||
res = client.query(
|
||||
collection_name=MILVUS_TABLE_KEYPOINT,
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
@@ -74,8 +71,8 @@ class LoadImage:
|
||||
keypoint = 'head_point'
|
||||
elif name == 'earring':
|
||||
keypoint = 'ear_point'
|
||||
elif name == 'accessories':
|
||||
keypoint = "accessories"
|
||||
elif name == 'others':
|
||||
keypoint = "others"
|
||||
else:
|
||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||
f"bag, shoes, hairstyle, earring.")
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.service.utils.new_oss_client import oss_get_image
|
||||
|
||||
class PrintPainting:
|
||||
def __init__(self, minio_client):
|
||||
self.random_seed = None
|
||||
self.minio_client = minio_client
|
||||
|
||||
def __call__(self, result):
|
||||
@@ -408,7 +409,7 @@ class PrintPainting:
|
||||
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||
# get real part into change mask
|
||||
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||
|
||||
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||
|
||||
@@ -46,7 +46,7 @@ class Scaling:
|
||||
result['scale'] = result['scale_bag']
|
||||
elif result['keypoint'] == 'ear_point':
|
||||
result['scale'] = result['scale_earrings']
|
||||
elif result['keypoint'] == 'accessories':
|
||||
elif result['keypoint'] == 'others':
|
||||
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
|
||||
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
|
||||
distance_clo = result['img_shape'][1]
|
||||
|
||||
@@ -4,7 +4,7 @@ import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
from app.core.config import SEG_CACHE_PATH
|
||||
from app.core.config import settings
|
||||
from app.service.design_fast.utils.design_ensemble import get_seg_result
|
||||
from app.service.utils.decorator import ClassCallRunTime
|
||||
from app.service.utils.new_oss_client import oss_get_image
|
||||
@@ -36,11 +36,11 @@ class Segmentation:
|
||||
# preview 过模型 不缓存
|
||||
if "preview_submit" in result.keys() and result['preview_submit'] == "preview":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
# submit 过模型 缓存
|
||||
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit":
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
# null 正常流程 加载本地缓存 无缓存则过模型
|
||||
else:
|
||||
@@ -49,7 +49,7 @@ class Segmentation:
|
||||
# 判断缓存和实际图片size是否相同
|
||||
if not _ or result["image"].shape[:2] != seg_result.shape:
|
||||
# 推理获得seg 结果
|
||||
seg_result = get_seg_result(result["image_id"], result['image'])
|
||||
seg_result = get_seg_result(result['image'])
|
||||
self.save_seg_result(seg_result, result['image_id'])
|
||||
result['seg_result'] = seg_result
|
||||
|
||||
@@ -63,7 +63,7 @@ class Segmentation:
|
||||
|
||||
@staticmethod
|
||||
def save_seg_result(seg_result, image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
|
||||
try:
|
||||
np.save(file_path, seg_result)
|
||||
logger.debug(f"保存成功 :{os.path.abspath(file_path)}")
|
||||
@@ -72,7 +72,7 @@ class Segmentation:
|
||||
|
||||
@staticmethod
|
||||
def load_seg_result(image_id):
|
||||
file_path = f"{SEG_CACHE_PATH}{image_id}.npy"
|
||||
file_path = f"{settings.SEG_CACHE_PATH}{image_id}.npy"
|
||||
# logger.info(f"load seg file name is :{SEG_CACHE_PATH}{image_id}.npy")
|
||||
try:
|
||||
seg_result = np.load(file_path)
|
||||
|
||||
@@ -4,9 +4,7 @@ import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from cv2 import cvtColor, COLOR_BGR2RGBA
|
||||
|
||||
from app.core.config import AIDA_CLOTHING
|
||||
from app.service.design_fast.utils.conversion_image import rgb_to_rgba
|
||||
from app.service.design_fast.utils.transparent import sketch_to_transparent
|
||||
from app.service.design_fast.utils.upload_image import upload_png_mask
|
||||
@@ -21,7 +19,7 @@ class Split(object):
|
||||
def __call__(self, result):
|
||||
try:
|
||||
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'):
|
||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
|
||||
|
||||
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
|
||||
front_mask = result['front_mask']
|
||||
@@ -40,7 +38,7 @@ class Split(object):
|
||||
result_front_image = np.zeros_like(rgba_image)
|
||||
front_mask = cv2.resize(front_mask, new_size)
|
||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
|
||||
if 'transparent' in result.keys():
|
||||
# 用户自选区域transparent
|
||||
transparent = result['transparent']
|
||||
@@ -98,21 +96,21 @@ class Split(object):
|
||||
result_back_image = np.zeros_like(rgba_image)
|
||||
back_mask = cv2.resize(back_mask, new_size)
|
||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
|
||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||
mask_image[back_mask != 0] = [0, 255, 0]
|
||||
|
||||
rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask)
|
||||
mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
|
||||
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
|
||||
image_data = io.BytesIO()
|
||||
mask_pil.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||
# 创建中间图层
|
||||
result_pattern_image_rgba = rgb_to_rgba(result['pattern_image'], result['mask'])
|
||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
||||
result_pattern_image_pil = Image.fromarray(cv2.cvtColor(result_pattern_image_rgba, cv2.COLOR_BGR2RGBA))
|
||||
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,16 +2,17 @@ import json
|
||||
|
||||
import pika
|
||||
|
||||
from app.core.config import RABBITMQ_PARAMS, BATCH_DESIGN_RABBITMQ_QUEUES
|
||||
from app.core.config import settings
|
||||
from app.core.rabbit_mq_config import RABBITMQ_PARAMS
|
||||
|
||||
|
||||
def publish_status(task_id, progress, result):
|
||||
connection = pika.BlockingConnection(pika.ConnectionParameters(**RABBITMQ_PARAMS))
|
||||
channel = connection.channel()
|
||||
channel.queue_declare(queue=BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
|
||||
channel.queue_declare(queue=settings.BATCH_DESIGN_RABBITMQ_QUEUES, durable=True)
|
||||
message = {'task_id': task_id, 'progress': progress, "result": result}
|
||||
channel.basic_publish(exchange='',
|
||||
routing_key=BATCH_DESIGN_RABBITMQ_QUEUES,
|
||||
routing_key=settings.BATCH_DESIGN_RABBITMQ_QUEUES,
|
||||
body=json.dumps(message),
|
||||
properties=pika.BasicProperties(
|
||||
delivery_mode=2,
|
||||
|
||||
@@ -16,7 +16,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
import tritonclient.http as httpclient
|
||||
|
||||
from app.core.config import *
|
||||
from app.core.config import DESIGN_MODEL_URL, DESIGN_MODEL_NAME
|
||||
|
||||
"""
|
||||
keypoint
|
||||
@@ -91,29 +91,29 @@ def seg_preprocess(img_path):
|
||||
|
||||
|
||||
# @ RunTime
|
||||
def get_seg_result(image_id, image):
|
||||
def get_seg_result(image):
|
||||
image, ori_shape = seg_preprocess(image)
|
||||
client = httpclient.InferenceServerClient(url=f"{DESIGN_MODEL_URL}")
|
||||
transformed_img = image.astype(np.float32)
|
||||
# 输入集
|
||||
inputs = [
|
||||
httpclient.InferInput(SEGMENTATION['input'], transformed_img.shape, datatype="FP32")
|
||||
httpclient.InferInput(DESIGN_MODEL_NAME, transformed_img.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(transformed_img, binary_data=True)
|
||||
# 输出集
|
||||
outputs = [
|
||||
httpclient.InferRequestedOutput(SEGMENTATION['output'], binary_data=True),
|
||||
httpclient.InferRequestedOutput("seg_input__0", binary_data=True),
|
||||
]
|
||||
results = client.infer(model_name=SEGMENTATION['new_model_name'], inputs=inputs, outputs=outputs)
|
||||
results = client.infer(model_name=DESIGN_MODEL_NAME, inputs=inputs, outputs=outputs)
|
||||
# 推理
|
||||
# 取结果
|
||||
inference_output1 = results.as_numpy(SEGMENTATION['output'])
|
||||
seg_result = seg_postprocess(int(image_id), inference_output1, ori_shape)
|
||||
inference_output1 = results.as_numpy("seg_input__0")
|
||||
seg_result = seg_postprocess(inference_output1, ori_shape)
|
||||
return seg_result
|
||||
|
||||
|
||||
# no cache
|
||||
def seg_postprocess(image_id, output, ori_shape):
|
||||
def seg_postprocess(output, ori_shape):
|
||||
seg_logit = F.interpolate(torch.tensor(output).float(), size=ori_shape, scale_factor=None, mode='bilinear', align_corners=False)
|
||||
seg_pred = seg_logit.cpu().numpy()
|
||||
return seg_pred[0]
|
||||
|
||||
@@ -55,7 +55,7 @@ def organize_clothing(layer):
|
||||
return front_layer, back_layer
|
||||
|
||||
|
||||
def organize_accessories(layer):
|
||||
def organize_others(layer):
|
||||
# 起始坐标
|
||||
start_point = (0, 0)
|
||||
# 前片数据
|
||||
@@ -98,6 +98,8 @@ def calculate_start_point(keypoint_type, scale, clothes_point, body_point, offse
|
||||
"""
|
||||
Align left
|
||||
Args:
|
||||
offset:
|
||||
resize_scale:
|
||||
keypoint_type: string, "waistband" | "shoulder" | "ear_point"
|
||||
scale: float
|
||||
clothes_point: dict{'left': [x1, y1, z1], 'right': [x2, y2, z2]}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from app.service.design_fast.utils.redis_utils import Redis
|
||||
from app.service.utils.redis_utils import Redis
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import redis
|
||||
|
||||
from app.core.config import REDIS_HOST, REDIS_PORT
|
||||
|
||||
|
||||
class Redis(object):
|
||||
"""
|
||||
redis数据库操作
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _get_r():
|
||||
host = REDIS_HOST
|
||||
port = REDIS_PORT
|
||||
db = 0
|
||||
r = redis.StrictRedis(host, port, db)
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def write(cls, key, value, expire=None):
|
||||
"""
|
||||
写入键值对
|
||||
"""
|
||||
# 判断是否有过期时间,没有就设置默认值
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.set(key, value, ex=expire_in_seconds)
|
||||
|
||||
@classmethod
|
||||
def read(cls, key):
|
||||
"""
|
||||
读取键值对内容
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.get(key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hset(cls, name, key, value):
|
||||
"""
|
||||
写入hash表
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hset(name, key, value)
|
||||
|
||||
@classmethod
|
||||
def hget(cls, name, key):
|
||||
"""
|
||||
读取指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
value = r.hget(name, key)
|
||||
return value.decode('utf-8') if value else value
|
||||
|
||||
@classmethod
|
||||
def hgetall(cls, name):
|
||||
"""
|
||||
获取指定hash表所有的值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
return r.hgetall(name)
|
||||
|
||||
@classmethod
|
||||
def delete(cls, *names):
|
||||
"""
|
||||
删除一个或者多个
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.delete(*names)
|
||||
|
||||
@classmethod
|
||||
def hdel(cls, name, key):
|
||||
"""
|
||||
删除指定hash表的键值
|
||||
"""
|
||||
r = cls._get_r()
|
||||
r.hdel(name, key)
|
||||
|
||||
@classmethod
|
||||
def expire(cls, name, expire=None):
|
||||
"""
|
||||
设置过期时间
|
||||
"""
|
||||
if expire:
|
||||
expire_in_seconds = expire
|
||||
else:
|
||||
expire_in_seconds = 100
|
||||
r = cls._get_r()
|
||||
r.expire(name, expire_in_seconds)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
redis_client = Redis()
|
||||
# print(redis_client.write(key="1230", value=0))
|
||||
redis_client.write(key="1230", value=10)
|
||||
# print(redis_client.read(key="1230"))
|
||||
@@ -13,9 +13,12 @@ import logging
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from minio import Minio
|
||||
from app.core.config import settings
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.oss_client import oss_upload_image
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
def positioning(all_mask_shape, mask_shape, offset):
|
||||
@@ -136,7 +139,7 @@ def synthesis(data, size, basic_info):
|
||||
image_bytes = image_data.read()
|
||||
bucket_name = "aida-results"
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
# return f"aida-results/{minio_client.put_object('aida-results', f'result_{generate_uuid()}.png', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
|
||||
@@ -177,11 +180,11 @@ def synthesis_single(front_image, back_image):
|
||||
# oss upload
|
||||
bucket_name = 'aida-results'
|
||||
object_name = f'result_{generate_uuid()}.png'
|
||||
req = oss_upload_image(bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
|
||||
return f"{bucket_name}/{object_name}"
|
||||
|
||||
|
||||
def update_base_size_priority(layers, size):
|
||||
def update_base_size_priority(layers):
|
||||
# 计算透明背景图片的宽度
|
||||
min_x = min(info['position'][1] for info in layers)
|
||||
x_list = []
|
||||
|
||||
@@ -12,7 +12,6 @@ import logging
|
||||
|
||||
import cv2
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
|
||||
@@ -25,15 +24,15 @@ def upload_png_mask(minio_client, front_image, object_name, mask=None):
|
||||
# 将掩模的3通道转换为4通道,白色部分不透明,黑色部分透明
|
||||
rgba_image = cv2.cvtColor(mask_inverted, cv2.COLOR_BGR2BGRA)
|
||||
rgba_image[rgba_image[:, :, 0] == 0] = [0, 0, 0, 0]
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
mask_url = f"{AIDA_CLOTHING}/mask/mask_{object_name}.png"
|
||||
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"mask/mask_{object_name}.png", image_bytes=cv2.imencode('.png', rgba_image)[1])
|
||||
mask_url = f"aida-clothing/mask/mask_{object_name}.png"
|
||||
|
||||
image_data = io.BytesIO()
|
||||
front_image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
req = oss_upload_image(oss_client=minio_client, bucket=AIDA_CLOTHING, object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
|
||||
image_url = f"{AIDA_CLOTHING}/image/image_{object_name}.png"
|
||||
req = oss_upload_image(oss_client=minio_client, bucket="aida-clothing", object_name=f"image/image_{object_name}.png", image_bytes=image_bytes)
|
||||
image_url = f"aida-clothing/image/image_{object_name}.png"
|
||||
return front_image, image_url, mask_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
@@ -5,36 +5,60 @@ import time
|
||||
import requests
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import *
|
||||
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, AccessoriesItem
|
||||
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_accessories
|
||||
from app.core.config import settings
|
||||
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem, TopMergeItem, BottomMergeItem, OthersMergeItem
|
||||
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
|
||||
from app.service.design_fast.utils.progress import final_progress, update_progress
|
||||
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
|
||||
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority, merge
|
||||
from app.service.utils.decorator import RunTime
|
||||
|
||||
id_lock = threading.Lock()
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
def process_item(item, basic):
|
||||
# 处理project中单个item
|
||||
if item['type'] == "Body":
|
||||
body_server = BodyItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = body_server.process()
|
||||
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
top_server = TopItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = top_server.process()
|
||||
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
||||
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = bottom_server.process()
|
||||
elif item['type'].lower() in ['accessories']:
|
||||
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = bottom_server.process()
|
||||
def process_item(item, basic, design_type):
|
||||
# 1. 定义映射配置
|
||||
# key 为 item_type 的小写,value 为对应的处理类
|
||||
DESIGN_MAP = {
|
||||
'body': BodyItem,
|
||||
'blouse': TopItem, 'outwear': TopItem,
|
||||
'dress': TopItem, 'tops': TopItem,
|
||||
'skirt': BottomItem, 'trousers': BottomItem,
|
||||
'bottoms': BottomItem,
|
||||
'others': OthersItem
|
||||
}
|
||||
|
||||
MERGE_MAP = {
|
||||
'body_merge': BodyItem,
|
||||
'blouse_merge': TopMergeItem, 'outwear_merge': TopMergeItem,
|
||||
'dress_merge': TopMergeItem, 'tops_merge': TopMergeItem,
|
||||
'skirt_merge': BottomMergeItem, 'trousers_merge': BottomMergeItem,
|
||||
'bottoms_merge': BottomMergeItem,
|
||||
'others_merge': OthersMergeItem
|
||||
}
|
||||
|
||||
# 2. 根据 design_type 选择映射表
|
||||
mapping = MERGE_MAP if design_type == 'merge' else DESIGN_MAP
|
||||
|
||||
if design_type == 'merge':
|
||||
item_type_key = f"{item['type'].lower()}_merge"
|
||||
elif design_type == 'default':
|
||||
item_type_key = item['type'].lower()
|
||||
else:
|
||||
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
||||
item_type_key = item['type'].lower()
|
||||
|
||||
handler_class = mapping.get(item_type_key)
|
||||
|
||||
if not handler_class:
|
||||
raise NotImplementedError(f"Item type {item['type']} not implemented for design_type={design_type}")
|
||||
|
||||
# 4. 统一实例化并执行
|
||||
# 注意:这里假设所有 Item 类构造函数签名一致
|
||||
server = handler_class(data=item, basic=basic, minio_client=minio_client)
|
||||
item_data = server.process()
|
||||
return item_data
|
||||
|
||||
|
||||
@@ -44,14 +68,16 @@ def process_layer(item, layers):
|
||||
body_layer = organize_body(item)
|
||||
layers.append(body_layer)
|
||||
return item['body_image'].size
|
||||
elif item['name'] == 'accessories':
|
||||
front_layer, back_layer = organize_accessories(item)
|
||||
elif item['name'] in ['others', 'others_merge']:
|
||||
front_layer, back_layer = organize_others(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
else:
|
||||
front_layer, back_layer = organize_clothing(item)
|
||||
layers.append(front_layer)
|
||||
layers.append(back_layer)
|
||||
return None
|
||||
|
||||
|
||||
@RunTime
|
||||
@@ -68,18 +94,18 @@ def design_generate(request_data):
|
||||
nonlocal active_threads
|
||||
basic = object['basic']
|
||||
items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""}
|
||||
design_type = basic.get('design_type', "default")
|
||||
if basic['single_overall'] == "overall":
|
||||
item_results = []
|
||||
for item in object['items']:
|
||||
item_results.append(process_item(item, basic))
|
||||
item_results.append(process_item(item, basic, design_type))
|
||||
layers = []
|
||||
body_size = None
|
||||
for item in item_results:
|
||||
body_size = process_layer(item, layers)
|
||||
process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
|
||||
layers, new_size = update_base_size_priority(layers)
|
||||
# pattern_overall_image_url 、 pattern_print_image_url
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
|
||||
@@ -90,10 +116,19 @@ def design_generate(request_data):
|
||||
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
||||
'mask_url': lay['mask_url'],
|
||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
|
||||
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
|
||||
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
|
||||
'transpose': lay.get('transpose', None),
|
||||
'rotate': lay.get('rotate', None),
|
||||
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||
if basic.get('design_type') == 'default':
|
||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||
elif basic.get('design_type') == 'merge':
|
||||
items_response['synthesis_url'] = merge(layers, new_size, basic)
|
||||
else:
|
||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||
|
||||
else:
|
||||
item_result = process_item(object['items'][0], basic)
|
||||
items_response['layers'].append({
|
||||
@@ -104,7 +139,9 @@ def design_generate(request_data):
|
||||
'image_url': item_result['front_image_url'],
|
||||
'mask_url': item_result['mask_url'],
|
||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item_result['name']}_back",
|
||||
@@ -114,7 +151,9 @@ def design_generate(request_data):
|
||||
'image_url': item_result['back_image_url'],
|
||||
'mask_url': item_result['mask_url'],
|
||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||
update_progress(process_id, total)
|
||||
@@ -139,10 +178,11 @@ def design_generate(request_data):
|
||||
@RunTime
|
||||
def design_generate_v2(request_data):
|
||||
objects_data = request_data.dict()['objects']
|
||||
callback_url = request_data.callback_url
|
||||
request_id = request_data.requestId
|
||||
threads = []
|
||||
|
||||
def process_object(step, object):
|
||||
def process_object(object, callback_url):
|
||||
basic = object['basic']
|
||||
items_response = {
|
||||
'layers': [],
|
||||
@@ -154,12 +194,11 @@ def design_generate_v2(request_data):
|
||||
for item in object['items']:
|
||||
item_results.append(process_item(item, basic))
|
||||
layers = []
|
||||
body_size = None
|
||||
for item in item_results:
|
||||
body_size = process_layer(item, layers)
|
||||
process_layer(item, layers)
|
||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||
|
||||
layers, new_size = update_base_size_priority(layers, body_size)
|
||||
layers, new_size = update_base_size_priority(layers)
|
||||
|
||||
for lay in layers:
|
||||
items_response['layers'].append({
|
||||
@@ -171,7 +210,9 @@ def design_generate_v2(request_data):
|
||||
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
||||
'mask_url': lay['mask_url'],
|
||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
|
||||
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
|
||||
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
|
||||
|
||||
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||
@@ -185,7 +226,9 @@ def design_generate_v2(request_data):
|
||||
'image_url': item_result['front_image_url'],
|
||||
'mask_url': item_result['mask_url'],
|
||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['layers'].append({
|
||||
'image_category': f"{item_result['name']}_back",
|
||||
@@ -195,16 +238,14 @@ def design_generate_v2(request_data):
|
||||
'image_url': item_result['back_image_url'],
|
||||
'mask_url': item_result['mask_url'],
|
||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
||||
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||
|
||||
})
|
||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||
# 发送结果给java端
|
||||
url = JAVA_STREAM_API_URL
|
||||
# xu_pei_test_url = "https://cd21b9110505.ngrok-free.app/api/third/party/receiveDesignResults"
|
||||
tianxaing_test_url = "https://c2ae520723c9.ngrok-free.app/api/third/party/receiveDesignResults"
|
||||
url = callback_url
|
||||
logger.info(f"java 回调 -> {url}")
|
||||
# logger.info(f"xupei java 回调 -> {xu_pei_test_url}")
|
||||
logger.info(f"tianxiang java 回调 -> {tianxaing_test_url}")
|
||||
|
||||
headers = {
|
||||
'Accept': "*/*",
|
||||
@@ -219,16 +260,8 @@ def design_generate_v2(request_data):
|
||||
# 打印结果
|
||||
logger.info(response.text)
|
||||
|
||||
# test_response = post_request(xu_pei_test_url, json_data=items_response, headers=headers)
|
||||
test_response = post_request(tianxaing_test_url, json_data=items_response, headers=headers)
|
||||
|
||||
if test_response:
|
||||
# 打印结果
|
||||
# logger.info(f"xupei test response : {test_response.text}")
|
||||
logger.info(f"tianxiang test response : {test_response.text}")
|
||||
|
||||
for step, object in enumerate(objects_data):
|
||||
t = threading.Thread(target=process_object, args=(step, object))
|
||||
t = threading.Thread(target=process_object, args=(object, callback_url))
|
||||
threads.append(t)
|
||||
t.start()
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
|
||||
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection, NoSegPrintPainting
|
||||
|
||||
|
||||
class BaseItem:
|
||||
@@ -7,25 +7,21 @@ class BaseItem:
|
||||
self.result['name'] = data['type'].lower()
|
||||
self.result.pop("type")
|
||||
self.result.update(basic)
|
||||
self.result['design_type'] = basic.get('design_type', None)
|
||||
|
||||
|
||||
class AccessoriesItem(BaseItem):
|
||||
class OthersItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.Accessories_pipeline = [
|
||||
self.Others_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
# KeyPoint(),
|
||||
# ContourDetection(),
|
||||
Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.Accessories_pipeline:
|
||||
for item in self.Others_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
@@ -39,6 +35,7 @@ class TopItem(BaseItem):
|
||||
Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
NoSegPrintPainting(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
@@ -60,6 +57,7 @@ class BottomItem(BaseItem):
|
||||
Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
NoSegPrintPainting(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
@@ -71,6 +69,65 @@ class BottomItem(BaseItem):
|
||||
return self.result
|
||||
|
||||
|
||||
"""merge"""
|
||||
|
||||
|
||||
class OthersMergeItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.Others_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
# KeyPoint(),
|
||||
# ContourDetection(),
|
||||
Segmentation(minio_client),
|
||||
# BackPerspective(minio_client),
|
||||
Color(minio_client),
|
||||
NoSegPrintPainting(minio_client),
|
||||
PrintPainting(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.Others_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
|
||||
class TopMergeItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.top_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
KeyPoint(),
|
||||
Segmentation(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.top_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
|
||||
class BottomMergeItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
self.bottom_pipeline = [
|
||||
LoadImage(minio_client),
|
||||
KeyPoint(),
|
||||
Segmentation(minio_client),
|
||||
Scaling(),
|
||||
Split(minio_client)
|
||||
]
|
||||
|
||||
def process(self):
|
||||
for item in self.bottom_pipeline:
|
||||
self.result = item(self.result)
|
||||
return self.result
|
||||
|
||||
|
||||
class BodyItem(BaseItem):
|
||||
def __init__(self, data, basic, minio_client):
|
||||
super().__init__(data, basic)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import io
|
||||
|
||||
from app.service.utils.oss_client import oss_get_image, oss_upload_image
|
||||
from minio import Minio
|
||||
from app.core.config import settings
|
||||
|
||||
from app.service.utils.new_oss_client import oss_get_image, oss_upload_image
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
|
||||
|
||||
def model_transpose(image_path):
|
||||
bucket = image_path.split("/", 1)[0]
|
||||
object_name = image_path.split("/", 1)[1]
|
||||
new_object_name = f'{object_name[:object_name.rfind(".")]}.png'
|
||||
image = oss_get_image(bucket=bucket, object_name=object_name, data_type="PIL")
|
||||
image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name, data_type="PIL")
|
||||
image = image.convert("RGBA")
|
||||
data = image.getdata()
|
||||
#
|
||||
@@ -23,6 +28,6 @@ def model_transpose(image_path):
|
||||
image.save(image_data, format='PNG')
|
||||
image_data.seek(0)
|
||||
image_bytes = image_data.read()
|
||||
oss_upload_image(bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
|
||||
oss_upload_image(oss_client=minio_client, bucket=bucket, object_name=new_object_name, image_bytes=image_bytes)
|
||||
image_path = f"{bucket}/{new_object_name}"
|
||||
return image_path
|
||||
@@ -5,6 +5,7 @@ from .keypoint import KeyPoint
|
||||
from .keypoint import KeyPoint
|
||||
from .loading import LoadImage, LoadBodyImage
|
||||
from .print_painting import PrintPainting
|
||||
from .no_seg_print_painting import NoSegPrintPainting
|
||||
from .scale import Scaling
|
||||
from .segmentation import Segmentation
|
||||
from .split import Split
|
||||
@@ -16,6 +17,7 @@ __all__ = [
|
||||
'Segmentation',
|
||||
'BackPerspective',
|
||||
'Color',
|
||||
'NoSegPrintPainting',
|
||||
'PrintPainting',
|
||||
'Scaling',
|
||||
'Split'
|
||||
|
||||
@@ -18,11 +18,11 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = file_path
|
||||
return result
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
elif result['name'] in ['blouse', 'outwear', 'dress', 'tops']:
|
||||
seg_result = result['seg_result']
|
||||
else:
|
||||
seg_result = get_seg_result("1", result['image'])[0]
|
||||
seg_result = get_seg_result(result['image'])[0]
|
||||
|
||||
m = self.thicken_contours_and_display(seg_result, thickness=10, color=(0, 0, 0))
|
||||
back_sketch = result['image'].copy()
|
||||
@@ -34,7 +34,8 @@ class BackPerspective:
|
||||
result['back_perspective_url'] = f"{resp.bucket_name}/{resp.object_name}"
|
||||
return result
|
||||
|
||||
def thicken_contours_and_display(self, mask, thickness=10, color=(0, 0, 0)):
|
||||
@staticmethod
|
||||
def thicken_contours_and_display(mask, thickness=10, color=(0, 0, 0)):
|
||||
mask = mask.astype(np.uint8) * 255
|
||||
# 查找轮廓
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
||||
@@ -48,9 +49,9 @@ class BackPerspective:
|
||||
# 在空白图像上绘制白色的轮廓
|
||||
cv2.drawContours(blank, [contour], -1, 255, thickness=thick)
|
||||
# 找到轮廓的中心(可以用重心等方法近似)
|
||||
M = cv2.moments(contour)
|
||||
cx = int(M['m10'] / M['m00'])
|
||||
cy = int(M['m01'] / M['m00'])
|
||||
m = cv2.moments(contour)
|
||||
# cx = int(m['m10'] / m['m00'])
|
||||
# cy = int(m['m01'] / m['m00'])
|
||||
# 进行距离变换,离中心越近的值越小
|
||||
dist_transform = cv2.distanceTransform(255 - blank, cv2.DIST_L2, 5)
|
||||
# 根据距离变换的值来决定是否保留像素,离中心近的像素更容易被保留
|
||||
|
||||
@@ -22,7 +22,7 @@ class Color:
|
||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
# 无色
|
||||
elif "color" not in result.keys() or result['color'] == "":
|
||||
result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
|
||||
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
|
||||
result['alpha'] = 100 / 255.0
|
||||
return result
|
||||
# 正常颜色
|
||||
@@ -48,7 +48,7 @@ class Color:
|
||||
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
|
||||
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||
gray_mo = np.expand_dims(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), axis=2).repeat(3, axis=2)
|
||||
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
||||
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
||||
result['final_image'] = result['pattern_image']
|
||||
@@ -59,6 +59,8 @@ class Color:
|
||||
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||
result['alpha'] = 100 / 255.0
|
||||
|
||||
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = result['final_image'].copy()
|
||||
return result
|
||||
|
||||
def get_gradient(self, bucket_name, object_name):
|
||||
@@ -79,9 +81,9 @@ class Color:
|
||||
def get_pattern(single_color):
|
||||
if single_color is None:
|
||||
raise False
|
||||
R, G, B = single_color.split(' ')
|
||||
r, g, b = single_color.split(' ')
|
||||
pattern = np.zeros([1, 1, 3], np.uint8)
|
||||
pattern[0, 0, 0] = int(B)
|
||||
pattern[0, 0, 1] = int(G)
|
||||
pattern[0, 0, 2] = int(R)
|
||||
pattern[0, 0, 0] = int(b)
|
||||
pattern[0, 0, 1] = int(g)
|
||||
pattern[0, 0, 2] = int(r)
|
||||
return pattern
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user