Compare commits
42 Commits
fix
...
1be716e414
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1be716e414 | ||
|
|
826bdcf9c1 | ||
|
|
f351184630 | ||
| fac1eab1bc | |||
| 832ca6fd05 | |||
| 673423131a | |||
| 6e15430a83 | |||
| 51068d2215 | |||
| d493d9eff6 | |||
| 7d970a7bba | |||
| 3fc6720bf7 | |||
| efa2e3a4a9 | |||
| c6af01bc51 | |||
|
|
448af4ab6b | ||
|
|
8a9f160cfa | ||
|
|
6e06c8b516 | ||
|
|
322fb9c46b | ||
|
|
30bfd22e3e | ||
|
|
e8d8b715ae | ||
|
|
7d2149dcaf | ||
|
|
fee9334b1f | ||
|
|
85c486c3dc | ||
|
|
0e7ef80eed | ||
|
|
8ccbbe41b1 | ||
|
|
98468ea7aa | ||
|
|
a9d9bdcb71 | ||
|
|
7459583377 | ||
|
|
385ff2d4aa | ||
|
|
02ad5db269 | ||
|
|
1d90963ded | ||
|
|
d1fefceebf | ||
|
|
242ebfc1df | ||
|
|
b8cf3d25b4 | ||
|
|
95647be610 | ||
|
|
e966ed5aa5 | ||
|
|
0d4d464e3f | ||
|
|
4bc79e62ca | ||
|
|
bf1fb8e514 | ||
|
|
d720bf2209 | ||
|
|
8f486867d5 | ||
|
|
1f45fe48a3 | ||
|
|
94c3d1e30d |
44
.gitea/workflows/develop_build_commit.yaml
Normal file
44
.gitea/workflows/develop_build_commit.yaml
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
name: git commit AiDA python develop 分支构建部署
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
push:
|
||||||
|
branches:
|
||||||
|
- develop
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
scheduled_deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: "contains(github.event.head_commit.message, '[run build]')"
|
||||||
|
|
||||||
|
env:
|
||||||
|
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: 1.检出代码
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: 'develop'
|
||||||
|
|
||||||
|
- name: 2.复制文件到服务器
|
||||||
|
uses: appleboy/scp-action@v0.1.7
|
||||||
|
with:
|
||||||
|
host: ${{ secrets.SERVER_HOST }}
|
||||||
|
username: ${{ secrets.SERVER_USER }}
|
||||||
|
password: ${{ secrets.SERVER_PASSWORD }}
|
||||||
|
source: "."
|
||||||
|
target: ${{ env.REMOTE_DEPLOY_PATH }}
|
||||||
|
|
||||||
|
- name: Restart Docker containers
|
||||||
|
uses: appleboy/ssh-action@v0.1.10
|
||||||
|
with:
|
||||||
|
host: ${{ secrets.SERVER_HOST }}
|
||||||
|
username: ${{ secrets.SERVER_USER }}
|
||||||
|
password: ${{ secrets.SERVER_PASSWORD }}
|
||||||
|
script: |
|
||||||
|
# 进入项目目录
|
||||||
|
cd ${{ env.REMOTE_DEPLOY_PATH }}
|
||||||
|
|
||||||
|
docker-compose down 2>&1
|
||||||
|
docker-compose up -d --build --remove-orphans 2>&1
|
||||||
|
|
||||||
|
docker image prune -f 2>&1
|
||||||
40
.gitea/workflows/develop_build_manual.yaml
Normal file
40
.gitea/workflows/develop_build_manual.yaml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
name: 手动 AiDA python develop 分支构建部署
|
||||||
|
on:
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
scheduled_deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
env:
|
||||||
|
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: 1.检出代码
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: 'develop'
|
||||||
|
|
||||||
|
- name: 2.复制文件到服务器
|
||||||
|
uses: appleboy/scp-action@v0.1.7
|
||||||
|
with:
|
||||||
|
host: ${{ secrets.SERVER_HOST }}
|
||||||
|
username: ${{ secrets.SERVER_USER }}
|
||||||
|
password: ${{ secrets.SERVER_PASSWORD }}
|
||||||
|
source: "."
|
||||||
|
target: ${{ env.REMOTE_DEPLOY_PATH }}
|
||||||
|
|
||||||
|
- name: 3.重启docker-compose
|
||||||
|
uses: appleboy/ssh-action@v0.1.10
|
||||||
|
with:
|
||||||
|
host: ${{ secrets.SERVER_HOST }}
|
||||||
|
username: ${{ secrets.SERVER_USER }}
|
||||||
|
password: ${{ secrets.SERVER_PASSWORD }}
|
||||||
|
script: |
|
||||||
|
# 进入项目目录
|
||||||
|
cd ${{ env.REMOTE_DEPLOY_PATH }}
|
||||||
|
|
||||||
|
docker-compose down 2>&1
|
||||||
|
docker-compose up -d --build --remove-orphans 2>&1
|
||||||
|
|
||||||
|
docker image prune -f 2>&1
|
||||||
42
.gitea/workflows/develop_build_scheduled.yaml
Normal file
42
.gitea/workflows/develop_build_scheduled.yaml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
name: 定时 AiDA python develop 分支构建部署
|
||||||
|
on:
|
||||||
|
# 使用 schedule 触发器,遵循标准的 Cron 格式 (分钟 小时-8 日期 月份 星期)
|
||||||
|
schedule:
|
||||||
|
- cron: '30 9 * * *'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
scheduled_deploy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
env:
|
||||||
|
REMOTE_DEPLOY_PATH: /workspace/Trinity/Fastapi_AiDA_Trinity_Dev
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: 1.检出代码
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
ref: 'develop'
|
||||||
|
|
||||||
|
- name: 2.复制文件到服务器
|
||||||
|
uses: appleboy/scp-action@v0.1.7
|
||||||
|
with:
|
||||||
|
host: ${{ secrets.SERVER_HOST }}
|
||||||
|
username: ${{ secrets.SERVER_USER }}
|
||||||
|
password: ${{ secrets.SERVER_PASSWORD }}
|
||||||
|
source: "."
|
||||||
|
target: ${{ env.REMOTE_DEPLOY_PATH }}
|
||||||
|
|
||||||
|
- name: Restart Docker containers
|
||||||
|
uses: appleboy/ssh-action@v0.1.10
|
||||||
|
with:
|
||||||
|
host: ${{ secrets.SERVER_HOST }}
|
||||||
|
username: ${{ secrets.SERVER_USER }}
|
||||||
|
password: ${{ secrets.SERVER_PASSWORD }}
|
||||||
|
script: |
|
||||||
|
# 进入项目目录
|
||||||
|
cd ${{ env.REMOTE_DEPLOY_PATH }}
|
||||||
|
|
||||||
|
docker-compose down 2>&1
|
||||||
|
docker-compose up -d --build --remove-orphans 2>&1
|
||||||
|
|
||||||
|
docker image prune -f 2>&1
|
||||||
@@ -9,7 +9,6 @@ from apscheduler.schedulers.background import BackgroundScheduler
|
|||||||
from apscheduler.triggers.cron import CronTrigger
|
from apscheduler.triggers.cron import CronTrigger
|
||||||
from fastapi import HTTPException, APIRouter
|
from fastapi import HTTPException, APIRouter
|
||||||
|
|
||||||
from app.service.recommend.service import load_resources, matrix_data
|
|
||||||
import pymysql
|
import pymysql
|
||||||
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
from app.core.config import DB_CONFIG, TABLE_CATEGORIES, RECOMMEND_PATH_PREFIX
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|||||||
116
app/api/api_import_sys_sketch.py
Normal file
116
app/api/api_import_sys_sketch.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, HTTPException, Query
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import threading
|
||||||
|
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.recommendation_system.import_sys_sketch_to_milvus import main as import_main
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 使用线程池执行器来运行长时间任务
|
||||||
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
# 用于跟踪任务状态
|
||||||
|
task_status = {"running": False}
|
||||||
|
|
||||||
|
|
||||||
|
def run_import_task(batch_size: int, retry_times: int, limit: Optional[int], offset: int, skip_create_collection: bool):
|
||||||
|
"""在后台线程中运行导入任务"""
|
||||||
|
original_argv = None
|
||||||
|
try:
|
||||||
|
task_status["running"] = True
|
||||||
|
# 保存原始 sys.argv
|
||||||
|
original_argv = sys.argv.copy()
|
||||||
|
|
||||||
|
# 模拟命令行参数
|
||||||
|
sys.argv = [
|
||||||
|
"import_sys_sketch_to_milvus.py",
|
||||||
|
"--batch-size", str(batch_size),
|
||||||
|
"--retry-times", str(retry_times),
|
||||||
|
]
|
||||||
|
if limit is not None:
|
||||||
|
sys.argv.extend(["--limit", str(limit)])
|
||||||
|
if offset > 0:
|
||||||
|
sys.argv.extend(["--offset", str(offset)])
|
||||||
|
if skip_create_collection:
|
||||||
|
sys.argv.append("--skip-create-collection")
|
||||||
|
|
||||||
|
import_main()
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.info("导入任务完成")
|
||||||
|
except Exception as e:
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.error(f"导入任务失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
# 恢复原始 sys.argv
|
||||||
|
if original_argv is not None:
|
||||||
|
sys.argv = original_argv
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import-sys-sketch", response_model=ResponseModel)
|
||||||
|
async def import_sys_sketch(
|
||||||
|
batch_size: int = Query(1000, description="批量处理大小(默认:1000)"),
|
||||||
|
retry_times: int = Query(3, description="失败重试次数(默认:3)"),
|
||||||
|
limit: Optional[int] = Query(None, description="限制处理数量(用于测试,默认:不限制)"),
|
||||||
|
offset: int = Query(0, description="起始偏移量(默认:0)"),
|
||||||
|
skip_create_collection: bool = Query(False, description="跳过创建集合(如果集合已存在)"),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
从 t_sys_file 导入系统图向量到 Milvus
|
||||||
|
|
||||||
|
该接口会异步执行导入任务,任务在后台运行。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查是否有任务正在运行
|
||||||
|
if task_status["running"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="已有导入任务正在运行,请等待完成后再试"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在后台线程中执行任务
|
||||||
|
executor.submit(
|
||||||
|
run_import_task,
|
||||||
|
batch_size,
|
||||||
|
retry_times,
|
||||||
|
limit,
|
||||||
|
offset,
|
||||||
|
skip_create_collection
|
||||||
|
)
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="导入任务已启动,正在后台执行",
|
||||||
|
data={
|
||||||
|
"status": "started",
|
||||||
|
"batch_size": batch_size,
|
||||||
|
"retry_times": retry_times,
|
||||||
|
"limit": limit,
|
||||||
|
"offset": offset,
|
||||||
|
"skip_create_collection": skip_create_collection
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"启动导入任务失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"启动导入任务失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/import-sys-sketch/status", response_model=ResponseModel)
|
||||||
|
async def get_import_status():
|
||||||
|
"""
|
||||||
|
获取导入任务状态
|
||||||
|
"""
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="OK",
|
||||||
|
data={
|
||||||
|
"running": task_status["running"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@@ -1,10 +1,16 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import requests
|
||||||
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
from fastapi import APIRouter, BackgroundTasks, HTTPException
|
||||||
|
|
||||||
|
from app.core.config import COMFYUI_SERVER_ADDRESS
|
||||||
|
from app.schemas.comfyui_i2v import ComfyuiI2VModel, ComfyuiFLF2VModel
|
||||||
from app.schemas.pose_transform import PoseTransformModel
|
from app.schemas.pose_transform import PoseTransformModel
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.comfyui_I2V.flf2v_server import ComfyUIServerFLF2V
|
||||||
|
from app.service.comfyui_I2V.i2v_server import ComfyUIServerI2V
|
||||||
|
from app.service.comfyui_I2V.pose2v_server import ComfyUIServerPose2V
|
||||||
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
|
from app.service.generate_image.service_pose_transform import PoseTransformService, infer_cancel as pose_transform_infer_cancel
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -47,3 +53,116 @@ def pose_transform_cancel(tasks_id: str):
|
|||||||
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
|
logger.warning(f"pose_transform_cancel Run Exception @@@@@@:{e}")
|
||||||
raise HTTPException(status_code=404, detail=str(e))
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
return ResponseModel(data=data['data'])
|
return ResponseModel(data=data['data'])
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
骨架 + 产品图 => 视频
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/comfyui_image_pose_2_video")
|
||||||
|
def comfyui_image_pose_2_video(request_item: PoseTransformModel, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||||
|
- **image_url**: 被生成图片的S3或minio url地址
|
||||||
|
- **pose_id**: 1
|
||||||
|
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"tasks_id": "123-89",
|
||||||
|
"image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png",
|
||||||
|
"pose_id": "1"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"image_pose_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||||
|
service = ComfyUIServerPose2V(request_item)
|
||||||
|
background_tasks.add_task(service.get_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"image_pose_2_video Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel()
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
产品图 + 文 => 视频
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/comfyui_image_2_video")
|
||||||
|
def comfyui_image_2_video(request_item: ComfyuiI2VModel, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||||
|
- **image_url**: 被生成图片的S3或minio url地址
|
||||||
|
- **prompt**: 动作表述
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"tasks_id": "12222515151123-89111",
|
||||||
|
"image_url": "aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
|
||||||
|
"prompt": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"image_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||||
|
service = ComfyUIServerI2V(request_item)
|
||||||
|
background_tasks.add_task(service.get_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"image_2_video Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel()
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
首尾帧 + 文 => 视频
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/comfyui_flf_2_video")
|
||||||
|
def comfyui_flf_2_video(request_item: ComfyuiFLF2VModel, background_tasks: BackgroundTasks):
|
||||||
|
"""
|
||||||
|
创建一个具有以下参数的请求体:
|
||||||
|
- **tasks_id**: 任务id 用于取消生成任务和获取生成结果
|
||||||
|
- **start_image_url**: 首帧
|
||||||
|
- **end_image_url**: 尾帧
|
||||||
|
- **prompt**: 动作描述
|
||||||
|
|
||||||
|
示例参数:
|
||||||
|
{
|
||||||
|
"tasks_id": "202511051619-89111",
|
||||||
|
"start_image_url": "test/start.png",
|
||||||
|
"end_image_url": "test/end.png",
|
||||||
|
"prompt": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
logger.info(f"flf_2_video request item is : @@@@@@:{json.dumps(request_item.dict())}")
|
||||||
|
service = ComfyUIServerFLF2V(request_item)
|
||||||
|
background_tasks.add_task(service.get_result)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"flf_2_video Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/comfyui_i_2_video_cancel/{tasks_id}")
|
||||||
|
def comfyui_i_2_video_cancel(tasks_id: str):
|
||||||
|
try:
|
||||||
|
logger.info(f"comfyui_i_2_video_cancel request item is : @@@@@@:{tasks_id}")
|
||||||
|
response = requests.post(
|
||||||
|
f"http://{COMFYUI_SERVER_ADDRESS}/interrupt",
|
||||||
|
json={"prompt_id": tasks_id}
|
||||||
|
)
|
||||||
|
data = {}
|
||||||
|
if response.status_code == 200:
|
||||||
|
data['data']['message'] = "任务已成功中断"
|
||||||
|
else:
|
||||||
|
data['data']['message'] = f"中断失败:{response.text}"
|
||||||
|
logger.info(f"comfyui_i_2_video_cancel response @@@@@@:{data}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"comfyui_i_2_video_cancel Run Exception @@@@@@:{e}")
|
||||||
|
raise HTTPException(status_code=404, detail=str(e))
|
||||||
|
return ResponseModel(data=data['data'])
|
||||||
|
|||||||
85
app/api/api_precompute.py
Normal file
85
app/api/api_precompute.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
import logging
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
from app.schemas.response_template import ResponseModel
|
||||||
|
from app.service.recommendation_system.precompute import run_precompute
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
# 使用线程池执行器来运行长时间任务
|
||||||
|
executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
# 用于跟踪任务状态
|
||||||
|
task_status = {"running": False}
|
||||||
|
|
||||||
|
|
||||||
|
def run_precompute_task():
|
||||||
|
"""在后台线程中运行预计算任务"""
|
||||||
|
try:
|
||||||
|
task_status["running"] = True
|
||||||
|
logger.info("开始执行预计算任务...")
|
||||||
|
run_precompute()
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.info("预计算任务完成")
|
||||||
|
except Exception as e:
|
||||||
|
task_status["running"] = False
|
||||||
|
logger.error(f"预计算任务失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/precompute", response_model=ResponseModel)
|
||||||
|
async def precompute():
|
||||||
|
"""
|
||||||
|
运行预计算任务
|
||||||
|
|
||||||
|
该接口会异步执行预计算任务,包括:
|
||||||
|
1. 优化数据库表结构
|
||||||
|
2. 历史数据迁移
|
||||||
|
3. 初始用户偏好向量生成
|
||||||
|
|
||||||
|
任务在后台运行。
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 检查是否有任务正在运行
|
||||||
|
if task_status["running"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=409,
|
||||||
|
detail="已有预计算任务正在运行,请等待完成后再试"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 在后台线程中执行任务
|
||||||
|
executor.submit(run_precompute_task)
|
||||||
|
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="预计算任务已启动,正在后台执行",
|
||||||
|
data={
|
||||||
|
"status": "started",
|
||||||
|
"tasks": [
|
||||||
|
"优化数据库表结构",
|
||||||
|
"历史数据迁移",
|
||||||
|
"初始用户偏好向量生成"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"启动预计算任务失败: {e}", exc_info=True)
|
||||||
|
raise HTTPException(status_code=500, detail=f"启动预计算任务失败: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/precompute/status", response_model=ResponseModel)
|
||||||
|
async def get_precompute_status():
|
||||||
|
"""
|
||||||
|
获取预计算任务状态
|
||||||
|
"""
|
||||||
|
return ResponseModel(
|
||||||
|
code=200,
|
||||||
|
msg="OK",
|
||||||
|
data={
|
||||||
|
"running": task_status["running"]
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
@@ -1,204 +1,175 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
from typing import List, Optional
|
||||||
from typing import List
|
from fastapi import HTTPException, APIRouter, Query
|
||||||
import os
|
|
||||||
import json
|
|
||||||
import math
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
from apscheduler.schedulers.background import BackgroundScheduler
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
from apscheduler.triggers.cron import CronTrigger
|
|
||||||
from fastapi import HTTPException, APIRouter
|
|
||||||
|
|
||||||
from app.service.recommend.service import load_resources, matrix_data
|
from app.service.recommendation_system.recommendation_api import get_recommendations as get_new_recommendations
|
||||||
|
from app.service.recommendation_system.incremental_listener import start_background_listener
|
||||||
|
from app.service.recommendation_system.milvus_client import create_collection
|
||||||
|
|
||||||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
|
||||||
logger = logging.getLogger()
|
logger = logging.getLogger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.on_event("startup")
|
# ========== 旧版推荐接口(基于 npy 矩阵,已废弃)==========
|
||||||
|
# @router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
||||||
|
# async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
||||||
|
# """
|
||||||
|
# :param user_id: 4
|
||||||
|
# :param category: female_skirt
|
||||||
|
# :param num_recommendations: 1
|
||||||
|
# :return:
|
||||||
|
# [
|
||||||
|
# "aida-sys-image/images/female/skirt/903000017.jpg"
|
||||||
|
# ]
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# start_time = time.time()
|
||||||
|
# cache_key = (user_id, category)
|
||||||
|
# # === 新增:用户存在性检查 ===
|
||||||
|
# user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
||||||
|
# user_exists_feat = user_id in matrix_data["user_index_feature"]
|
||||||
|
#
|
||||||
|
# # 任一矩阵不存在用户则返回随机推荐
|
||||||
|
# if not (user_exists_inter and user_exists_feat):
|
||||||
|
# logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
||||||
|
# return get_random_recommendations(category, num_recommendations)
|
||||||
|
#
|
||||||
|
# # 检查缓存
|
||||||
|
# if cache_key in matrix_data["cached_scores"]:
|
||||||
|
# processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
||||||
|
# valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
||||||
|
# else:
|
||||||
|
# # 实时计算逻辑(同原代码)
|
||||||
|
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||||
|
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||||
|
#
|
||||||
|
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||||
|
# valid_sketch_idxs_inter = [
|
||||||
|
# idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
||||||
|
# if iid in category_iids
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# # 处理交互分数
|
||||||
|
# raw_inter_scores = []
|
||||||
|
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||||
|
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||||
|
# processed_inter = raw_inter_scores * 0.7
|
||||||
|
#
|
||||||
|
# # 处理特征分数
|
||||||
|
# valid_sketch_idxs_feature = [
|
||||||
|
# idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
||||||
|
# if iid in category_iids
|
||||||
|
# ]
|
||||||
|
# raw_feat_scores = []
|
||||||
|
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||||
|
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||||
|
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||||
|
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||||
|
# processed_feat = raw_feat_scores
|
||||||
|
# else:
|
||||||
|
# processed_feat = np.array([])
|
||||||
|
#
|
||||||
|
# # 更新缓存
|
||||||
|
# matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
||||||
|
# matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
||||||
|
#
|
||||||
|
# # 合并分数
|
||||||
|
# if brand_id is not None:
|
||||||
|
# brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
||||||
|
#
|
||||||
|
# brand_feat_valid = (
|
||||||
|
# matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
||||||
|
# brand_idx_feature is not None and
|
||||||
|
# valid_sketch_idxs_feature # 有可用索引
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# if brand_feat_valid:
|
||||||
|
# raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
||||||
|
# brand_idx_feature, valid_sketch_idxs_feature
|
||||||
|
# ]
|
||||||
|
# raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
||||||
|
# np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
||||||
|
# )
|
||||||
|
# processed_brand_feat = raw_brand_feat_scores
|
||||||
|
#
|
||||||
|
# # 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
||||||
|
# if processed_feat.size == 0:
|
||||||
|
# processed_feat = np.zeros_like(processed_brand_feat)
|
||||||
|
#
|
||||||
|
# final_scores = processed_inter + 0.3 * (
|
||||||
|
# (1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# # brand 信息不可用
|
||||||
|
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||||
|
# else:
|
||||||
|
# final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
||||||
|
#
|
||||||
|
# valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
||||||
|
#
|
||||||
|
# # 概率采样
|
||||||
|
# scores = np.array(final_scores)
|
||||||
|
#
|
||||||
|
# # 调整后的概率转换(带温度控制的softmax)
|
||||||
|
# def calibrated_softmax(scores, temperature=1.0):
|
||||||
|
# scores = scores / temperature
|
||||||
|
# scale = scores - max(scores)
|
||||||
|
# exps = np.exp(scale)
|
||||||
|
# return exps / np.sum(exps)
|
||||||
|
#
|
||||||
|
# probs = calibrated_softmax(scores, 0.09)
|
||||||
|
#
|
||||||
|
# chosen_indices = np.random.choice(
|
||||||
|
# len(valid_sketch_idxs),
|
||||||
|
# size=min(num_recommendations, len(valid_sketch_idxs)),
|
||||||
|
# p=probs,
|
||||||
|
# replace=False
|
||||||
|
# )
|
||||||
|
# recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
||||||
|
#
|
||||||
|
# logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||||
|
# return recommendations
|
||||||
|
# except Exception as e:
|
||||||
|
# logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
||||||
|
# raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# @router.on_event("startup")
|
||||||
async def startup_event():
|
async def startup_event():
|
||||||
# 初始加载
|
"""启动时初始化增量监听任务"""
|
||||||
load_resources()
|
try:
|
||||||
|
# 确保 Milvus 集合已创建(若已存在则直接返回)
|
||||||
|
try:
|
||||||
|
create_collection()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("Milvus 集合创建/检查失败,不影响服务继续启动: %s", exc, exc_info=True)
|
||||||
|
|
||||||
# 配置定时任务
|
# 配置定时任务
|
||||||
scheduler = BackgroundScheduler()
|
scheduler = BackgroundScheduler()
|
||||||
scheduler.add_job(
|
start_background_listener(scheduler)
|
||||||
load_resources,
|
|
||||||
trigger=CronTrigger(hour=0, minute=30),
|
|
||||||
name="每日资源刷新"
|
|
||||||
)
|
|
||||||
scheduler.start()
|
scheduler.start()
|
||||||
logger.info("定时任务已启动")
|
logger.info("增量监听定时任务已启动")
|
||||||
|
|
||||||
def softmax(scores):
|
|
||||||
max_score = max(scores)
|
|
||||||
exp_scores = [math.exp(s - max_score) for s in scores]
|
|
||||||
sum_exp = sum(exp_scores)
|
|
||||||
return [s / sum_exp for s in exp_scores]
|
|
||||||
|
|
||||||
# def get_random_recommendations(category: str, num: int) -> List[str]:
|
|
||||||
# """根据预加载热度向量推荐(冷启动)"""
|
|
||||||
# try:
|
|
||||||
# heat_data = matrix_data.get("heat_data", {})
|
|
||||||
#
|
|
||||||
# if category not in heat_data:
|
|
||||||
# raise ValueError(f"热度数据缺少类别 {category},使用随机推荐")
|
|
||||||
#
|
|
||||||
# heat_dict = heat_data[category] # {url: score}
|
|
||||||
# urls = list(heat_dict.keys())
|
|
||||||
# scores = list(heat_dict.values())
|
|
||||||
#
|
|
||||||
# if not urls:
|
|
||||||
# raise ValueError("该类别下无热度记录,使用随机推荐")
|
|
||||||
#
|
|
||||||
# probs = softmax(scores)
|
|
||||||
# sample_size = min(num, len(urls))
|
|
||||||
# sampled_urls = random.choices(urls, weights=probs, k=sample_size)
|
|
||||||
#
|
|
||||||
# return sampled_urls
|
|
||||||
#
|
|
||||||
# except Exception as e:
|
|
||||||
# # 回退:完全随机推荐
|
|
||||||
# all_iids = list(matrix_data["iid_to_sketch"].keys())
|
|
||||||
# category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
|
||||||
# sample_size = min(num, len(category_iids))
|
|
||||||
# sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
|
||||||
# return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
|
||||||
|
|
||||||
def get_random_recommendations(category: str, num: int) -> List[str]:
|
|
||||||
"""全品类随机推荐"""
|
|
||||||
all_iids = list(matrix_data["iid_to_sketch"].keys())
|
|
||||||
# 优先从当前品类选择
|
|
||||||
category_iids = matrix_data["category_to_iids"].get(category, all_iids)
|
|
||||||
# 确保不超出实际数量
|
|
||||||
sample_size = min(num, len(category_iids))
|
|
||||||
sampled = np.random.choice(category_iids, size=sample_size, replace=False)
|
|
||||||
return [matrix_data["iid_to_sketch"][iid] for iid in sampled]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/recommend/{user_id}/{category}/{num_recommendations}/{brand_id}/{brand_scale}", response_model=List[str])
|
|
||||||
async def get_recommendations(user_id: int, category: str, brand_id: int, brand_scale: float, num_recommendations: int = 10):
|
|
||||||
"""
|
|
||||||
:param user_id: 4
|
|
||||||
:param category: female_skirt
|
|
||||||
:param num_recommendations: 1
|
|
||||||
:return:
|
|
||||||
[
|
|
||||||
"aida-sys-image/images/female/skirt/903000017.jpg"
|
|
||||||
]
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
start_time = time.time()
|
|
||||||
cache_key = (user_id, category)
|
|
||||||
# === 新增:用户存在性检查 ===
|
|
||||||
user_exists_inter = user_id in matrix_data["user_index_interaction"]
|
|
||||||
user_exists_feat = user_id in matrix_data["user_index_feature"]
|
|
||||||
|
|
||||||
# 任一矩阵不存在用户则返回随机推荐
|
|
||||||
if not (user_exists_inter and user_exists_feat):
|
|
||||||
logger.info(f"用户 {user_id} 数据不完整,触发随机推荐")
|
|
||||||
return get_random_recommendations(category, num_recommendations)
|
|
||||||
|
|
||||||
# 检查缓存
|
|
||||||
if cache_key in matrix_data["cached_scores"]:
|
|
||||||
processed_inter, processed_feat = matrix_data["cached_scores"][cache_key]
|
|
||||||
valid_sketch_idxs_inter = matrix_data["cached_valid_idxs"][cache_key]
|
|
||||||
else:
|
|
||||||
# 实时计算逻辑(同原代码)
|
|
||||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
|
||||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
|
||||||
|
|
||||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
|
||||||
valid_sketch_idxs_inter = [
|
|
||||||
idx for iid, idx in matrix_data["sketch_index_interaction"].items()
|
|
||||||
if iid in category_iids
|
|
||||||
]
|
|
||||||
|
|
||||||
# 处理交互分数
|
|
||||||
raw_inter_scores = []
|
|
||||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
|
||||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
|
||||||
processed_inter = raw_inter_scores * 0.7
|
|
||||||
|
|
||||||
# 处理特征分数
|
|
||||||
valid_sketch_idxs_feature = [
|
|
||||||
idx for iid, idx in matrix_data["sketch_index_feature"].items()
|
|
||||||
if iid in category_iids
|
|
||||||
]
|
|
||||||
raw_feat_scores = []
|
|
||||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
|
||||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
|
||||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
|
||||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
|
||||||
processed_feat = raw_feat_scores
|
|
||||||
else:
|
|
||||||
processed_feat = np.array([])
|
|
||||||
|
|
||||||
# 更新缓存
|
|
||||||
matrix_data["cached_scores"][cache_key] = (processed_inter, processed_feat)
|
|
||||||
matrix_data["cached_valid_idxs"][cache_key] = valid_sketch_idxs_inter
|
|
||||||
|
|
||||||
# 合并分数
|
|
||||||
if brand_id is not None:
|
|
||||||
brand_idx_feature = matrix_data["brand_index_map"].get(brand_id)
|
|
||||||
|
|
||||||
brand_feat_valid = (
|
|
||||||
matrix_data["brand_feature_matrix"].size > 0 and # 矩阵非空
|
|
||||||
brand_idx_feature is not None and
|
|
||||||
valid_sketch_idxs_feature # 有可用索引
|
|
||||||
)
|
|
||||||
|
|
||||||
if brand_feat_valid:
|
|
||||||
raw_brand_feat_scores = matrix_data["brand_feature_matrix"][
|
|
||||||
brand_idx_feature, valid_sketch_idxs_feature
|
|
||||||
]
|
|
||||||
raw_brand_feat_scores = (raw_brand_feat_scores - np.min(raw_brand_feat_scores)) / (
|
|
||||||
np.max(raw_brand_feat_scores) - np.min(raw_brand_feat_scores) + 1e-8
|
|
||||||
)
|
|
||||||
processed_brand_feat = raw_brand_feat_scores
|
|
||||||
|
|
||||||
# 如果 processed_feat 是空的,替换为全 0,避免 shape 不一致
|
|
||||||
if processed_feat.size == 0:
|
|
||||||
processed_feat = np.zeros_like(processed_brand_feat)
|
|
||||||
|
|
||||||
final_scores = processed_inter + 0.3 * (
|
|
||||||
(1 - brand_scale) * processed_feat + brand_scale * processed_brand_feat
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# brand 信息不可用
|
|
||||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
|
||||||
else:
|
|
||||||
final_scores = processed_inter + 0.3 * processed_feat if processed_feat.size > 0 else processed_inter
|
|
||||||
|
|
||||||
valid_sketch_idxs = matrix_data["cached_valid_idxs"][cache_key]
|
|
||||||
|
|
||||||
# 概率采样
|
|
||||||
scores = np.array(final_scores)
|
|
||||||
|
|
||||||
# 调整后的概率转换(带温度控制的softmax)
|
|
||||||
def calibrated_softmax(scores, temperature=1.0):
|
|
||||||
scores = scores / temperature
|
|
||||||
scale = scores - max(scores)
|
|
||||||
exps = np.exp(scale)
|
|
||||||
return exps / np.sum(exps)
|
|
||||||
|
|
||||||
probs = calibrated_softmax(scores, 0.09)
|
|
||||||
|
|
||||||
chosen_indices = np.random.choice(
|
|
||||||
len(valid_sketch_idxs),
|
|
||||||
size=min(num_recommendations, len(valid_sketch_idxs)),
|
|
||||||
p=probs,
|
|
||||||
replace=False
|
|
||||||
)
|
|
||||||
recommendations = [matrix_data["iid_to_sketch"][valid_sketch_idxs[idx]] for idx in chosen_indices]
|
|
||||||
|
|
||||||
logger.info(f"推荐生成完成,耗时: {time.time() - start_time:.2f}秒")
|
|
||||||
return recommendations
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"推荐失败: {str(e)}", exc_info=True)
|
logger.error(f"启动增量监听任务失败: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/recommend/{user_id}/{category}", response_model=List[str])
|
||||||
|
async def recommend(
|
||||||
|
user_id: int,
|
||||||
|
category: str,
|
||||||
|
style: Optional[str] = Query(
|
||||||
|
None,
|
||||||
|
description="风格样式(可选):若传入,则在利用分支对同 style 的候选进行加分",
|
||||||
|
),
|
||||||
|
):
|
||||||
|
"""新版推荐接口(Milvus + Redis 偏好向量)。"""
|
||||||
|
try:
|
||||||
|
results = get_new_recommendations(user_id, category, style)
|
||||||
|
path = results[0] if results else ""
|
||||||
|
return [path]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
@@ -10,8 +10,10 @@ from app.api import api_design_pre_processing
|
|||||||
from app.api import api_extraction_project_info
|
from app.api import api_extraction_project_info
|
||||||
from app.api import api_generate_image
|
from app.api import api_generate_image
|
||||||
from app.api import api_image2sketch
|
from app.api import api_image2sketch
|
||||||
|
from app.api import api_import_sys_sketch
|
||||||
from app.api import api_mannequins_edit
|
from app.api import api_mannequins_edit
|
||||||
from app.api import api_pose_transform
|
from app.api import api_pose_transform
|
||||||
|
from app.api import api_precompute
|
||||||
from app.api import api_prompt_generation
|
from app.api import api_prompt_generation
|
||||||
from app.api import api_recommendation
|
from app.api import api_recommendation
|
||||||
from app.api import api_super_resolution
|
from app.api import api_super_resolution
|
||||||
@@ -36,3 +38,5 @@ router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'],
|
|||||||
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
|
||||||
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")
|
||||||
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
router.include_router(api_extraction_project_info.router, tags=['api_extraction_project_info'], prefix="/api")
|
||||||
|
router.include_router(api_import_sys_sketch.router, tags=['api_import_sys_sketch'], prefix="/api")
|
||||||
|
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
|
||||||
|
|||||||
@@ -82,9 +82,9 @@ MILVUS_TABLE_SEG = "seg_cache"
|
|||||||
DB_HOST = '18.167.251.121' # 数据库主机地址
|
DB_HOST = '18.167.251.121' # 数据库主机地址
|
||||||
# DB_PORT = int( 33006)
|
# DB_PORT = int( 33006)
|
||||||
DB_PORT = 33008 # 数据库端口
|
DB_PORT = 33008 # 数据库端口
|
||||||
DB_USERNAME = 'aida_con_python' # 数据库用户名
|
DB_USERNAME = 'aida_con' # 数据库用户名
|
||||||
DB_PASSWORD = '123456' # 数据库密码
|
DB_PASSWORD = '123456' # 数据库密码
|
||||||
DB_NAME = 'aida' # 数据库库名
|
DB_NAME = 'aida_back' # 数据库库名
|
||||||
|
|
||||||
# openai
|
# openai
|
||||||
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
os.environ['SERPAPI_API_KEY'] = "a793513017b0718db7966207c31703d280d12435c982f1e67bbcbffa52e7632c"
|
||||||
@@ -230,3 +230,6 @@ TABLE_CATEGORIES = {
|
|||||||
"male_bottoms": "male/bottoms",
|
"male_bottoms": "male/bottoms",
|
||||||
"male_outwear": "male/outwear"
|
"male_outwear": "male/outwear"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# --- ComfyUI 配置信息 ---
|
||||||
|
COMFYUI_SERVER_ADDRESS = "10.1.2.227:8080" # 替换为您的 ComfyUI 服务器地址
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ from app.api.api_route import router
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
from app.core.record_api_count import count_api_calls
|
from app.core.record_api_count import count_api_calls
|
||||||
from app.schemas.response_template import ResponseModel
|
from app.schemas.response_template import ResponseModel
|
||||||
from app.service.recommend.service import load_resources
|
|
||||||
from logging_env import LOGGER_CONFIG_DICT
|
from logging_env import LOGGER_CONFIG_DICT
|
||||||
|
|
||||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||||
|
|||||||
23
app/schemas/comfyui_i2v.py
Normal file
23
app/schemas/comfyui_i2v.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyuiPose2VModel(BaseModel):
|
||||||
|
# 骨架生成视频
|
||||||
|
image_url: str
|
||||||
|
tasks_id: str
|
||||||
|
pose_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyuiI2VModel(BaseModel):
|
||||||
|
# 图生视频
|
||||||
|
image_url: str
|
||||||
|
prompt: str
|
||||||
|
tasks_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyuiFLF2VModel(BaseModel):
|
||||||
|
# 首尾帧生视频
|
||||||
|
start_image_url: str
|
||||||
|
end_image_url: str
|
||||||
|
prompt: str
|
||||||
|
tasks_id: str
|
||||||
@@ -10,6 +10,7 @@ class DesignStreamModel(BaseModel):
|
|||||||
objects: list[dict]
|
objects: list[dict]
|
||||||
process_id: str
|
process_id: str
|
||||||
requestId: str
|
requestId: str
|
||||||
|
callback_url: str
|
||||||
|
|
||||||
|
|
||||||
class DesignProgressModel(BaseModel):
|
class DesignProgressModel(BaseModel):
|
||||||
|
|||||||
639
app/service/comfyui_I2V/flf2v_server.py
Normal file
639
app/service/comfyui_I2V/flf2v_server.py
Normal file
@@ -0,0 +1,639 @@
|
|||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio, S3Error
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
|
||||||
|
from app.schemas.comfyui_i2v import ComfyuiFLF2VModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
# 首尾帧 + 文字 = 视频 工作流
|
||||||
|
workflow_json = {
|
||||||
|
"6": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "A bearded man with red facial hair wearing a yellow straw hat and dark coat in Van Gogh's self-portrait style, slowly and continuously transforms into a space astronaut. The transformation flows like liquid paint - his beard fades away strand by strand, the yellow hat melts and reforms smoothly into a silver space helmet, dark coat gradually lightens and restructures into a white spacesuit. The background swirling brushstrokes slowly organize and clarify into realistic stars and space, with Earth appearing gradually in the distance. Every change happens in seamless waves, maintaining visual continuity throughout the metamorphosis.\n\nConsistent soft lighting throughout, medium close-up maintaining same framing, central composition stays fixed, gentle color temperature shift from warm to cool, gradual contrast increase, smooth style transition from painterly to photorealistic. Static camera with subtle slow zoom, emphasizing the flowing transformation process without abrupt changes.",
|
||||||
|
"clip": [
|
||||||
|
"38",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Positive Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"7": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
"clip": [
|
||||||
|
"38",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Negative Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"8": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"58",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"39",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE解码"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"37": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "UNet加载器"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"38": {
|
||||||
|
"inputs": {
|
||||||
|
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||||
|
"type": "wan",
|
||||||
|
"device": "default"
|
||||||
|
},
|
||||||
|
"class_type": "CLIPLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载CLIP"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"39": {
|
||||||
|
"inputs": {
|
||||||
|
"vae_name": "wan_2.1_vae.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载VAE"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"54": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 5,
|
||||||
|
"model": [
|
||||||
|
"91",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "采样算法(SD3)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"55": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 5,
|
||||||
|
"model": [
|
||||||
|
"92",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "采样算法(SD3)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"56": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "UNet加载器"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"57": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "enable",
|
||||||
|
"noise_seed": 984937593540091,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 0,
|
||||||
|
"end_at_step": 2,
|
||||||
|
"return_with_leftover_noise": "enable",
|
||||||
|
"model": [
|
||||||
|
"54",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"67",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"67",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"67",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "K采样器(高级)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"58": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "disable",
|
||||||
|
"noise_seed": 0,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 2,
|
||||||
|
"end_at_step": 10000,
|
||||||
|
"return_with_leftover_noise": "disable",
|
||||||
|
"model": [
|
||||||
|
"55",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"67",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"67",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"57",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "K采样器(高级)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"60": {
|
||||||
|
"inputs": {
|
||||||
|
"fps": 16,
|
||||||
|
"images": [
|
||||||
|
"8",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CreateVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "创建视频"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"61": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "video/ComfyUI",
|
||||||
|
"format": "auto",
|
||||||
|
"codec": "auto",
|
||||||
|
"video": [
|
||||||
|
"60",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "保存视频"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"62": {
|
||||||
|
"inputs": {
|
||||||
|
"image": "video_wan2_2_14B_flf2v_start_image.png"
|
||||||
|
},
|
||||||
|
"class_type": "LoadImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载end图像"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"67": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 640,
|
||||||
|
"height": 640,
|
||||||
|
"length": 81,
|
||||||
|
"batch_size": 1,
|
||||||
|
"positive": [
|
||||||
|
"6",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"7",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"39",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"start_image": [
|
||||||
|
"68",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"end_image": [
|
||||||
|
"62",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "WanFirstLastFrameToVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "WanFirstLastFrameToVideo"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"68": {
|
||||||
|
"inputs": {
|
||||||
|
"image": "video_wan2_2_14B_flf2v_end_image.png"
|
||||||
|
},
|
||||||
|
"class_type": "LoadImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载start图像"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"91": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
|
||||||
|
"strength_model": 1,
|
||||||
|
"model": [
|
||||||
|
"37",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoRA加载器(仅模型)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"92": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
|
||||||
|
"strength_model": 1,
|
||||||
|
"model": [
|
||||||
|
"56",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoRA加载器(仅模型)"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyUIServerFLF2V:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
self.start_image_url = request_data.start_image_url
|
||||||
|
self.end_image_url = request_data.end_image_url
|
||||||
|
self.prompt = request_data.prompt
|
||||||
|
self.tasks_id = request_data.tasks_id
|
||||||
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
|
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||||
|
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
workflow_json['6']['inputs']['text'] = self.prompt
|
||||||
|
workflow_json['57']['inputs']["noise_seed"] = random.randint(0, 10 ** 18)
|
||||||
|
|
||||||
|
if self.start_image_url:
|
||||||
|
# 下载图片 上传 comfyui server
|
||||||
|
# TODO 设置视频宽度为480,高度自适应
|
||||||
|
workflow_json['67']['inputs']["width"] = 480
|
||||||
|
workflow_json['67']['inputs']["height"] = 848
|
||||||
|
if self.start_image_url:
|
||||||
|
start_in_memory_file, start_object_name = self.download_from_minio_in_memory(self.start_image_url)
|
||||||
|
# 上传图片到comfyui server
|
||||||
|
filename = self.upload_in_memory_file_to_comfyui(start_in_memory_file, start_object_name)
|
||||||
|
workflow_json['68']['inputs']['image'] = filename
|
||||||
|
else:
|
||||||
|
assert "start_image_url is None"
|
||||||
|
|
||||||
|
if self.end_image_url:
|
||||||
|
end_in_memory_file, end_object_name = self.download_from_minio_in_memory(self.end_image_url)
|
||||||
|
# 上传图片到comfyui server
|
||||||
|
filename = self.upload_in_memory_file_to_comfyui(end_in_memory_file, end_object_name)
|
||||||
|
workflow_json['62']['inputs']['image'] = filename
|
||||||
|
else:
|
||||||
|
assert "end_image_url is None"
|
||||||
|
|
||||||
|
# 1. 提交任务
|
||||||
|
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||||
|
if not prompt_response:
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt_id = prompt_response.get("prompt_id")
|
||||||
|
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||||
|
outputs = self.poll_history(prompt_id)
|
||||||
|
file_list = {}
|
||||||
|
for node_id, node_output in outputs.items():
|
||||||
|
# 检查当前节点输出中是否包含 'images' 列表
|
||||||
|
if 'images' in node_output and isinstance(node_output['images'], list):
|
||||||
|
# 'images' 列表中的每个元素都是一个文件对象
|
||||||
|
for file_info in node_output['images']:
|
||||||
|
# 确保关键字段存在
|
||||||
|
if all(key in file_info for key in ['filename', 'subfolder', 'type']):
|
||||||
|
file_list = {
|
||||||
|
'filename': file_info['filename'],
|
||||||
|
'subfolder': file_info['subfolder'],
|
||||||
|
'type': file_info['type']
|
||||||
|
}
|
||||||
|
logger.info(file_list)
|
||||||
|
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||||
|
|
||||||
|
def download_from_minio_in_memory(self, image_url):
|
||||||
|
bucket = image_url.split('/')[0]
|
||||||
|
object_name = image_url[image_url.find('/') + 1:]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# get_object 返回一个 ResponseStream 对象
|
||||||
|
response_stream = self.minio_client.get_object(
|
||||||
|
bucket,
|
||||||
|
object_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 读取整个流到内存 (BytesIO),避免写入本地文件
|
||||||
|
image_bytes = response_stream.read()
|
||||||
|
|
||||||
|
response_stream.close()
|
||||||
|
response_stream.release_conn()
|
||||||
|
|
||||||
|
in_memory_file = io.BytesIO(image_bytes)
|
||||||
|
|
||||||
|
# print(f"✅ 图片已下载到内存 ({len(image_bytes)} 字节)。")
|
||||||
|
return in_memory_file, object_name.rsplit('/')[-1]
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
|
||||||
|
return None, None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
|
||||||
|
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"overwrite": "true",
|
||||||
|
"type": "input"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建 multipart/form-data: (文件名, 内存文件对象, MIME 类型)
|
||||||
|
# MIME 类型可以根据实际图片类型修改,这里使用常见的 png/jpeg
|
||||||
|
mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
|
||||||
|
|
||||||
|
files = {
|
||||||
|
'image': (filename, in_memory_file, mime_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
# print(f"⬆️ 正在上传图片 ({filename}) 到 ComfyUI...")
|
||||||
|
try:
|
||||||
|
comfyui_response = requests.post(upload_url, data=data, files=files)
|
||||||
|
comfyui_response.raise_for_status()
|
||||||
|
|
||||||
|
result = comfyui_response.json()
|
||||||
|
uploaded_name = result.get('name')
|
||||||
|
|
||||||
|
# print(f"🎉 ComfyUI 上传成功! 服务器文件名: {uploaded_name}")
|
||||||
|
return uploaded_name
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"❌ ComfyUI 上传失败: {e}")
|
||||||
|
logger.error(f" 响应内容: {comfyui_response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str, ):
|
||||||
|
"""
|
||||||
|
完整的自动化流程:获取 ComfyUI 视频 -> 转换 GIF 并提取帧 -> 上传所有结果到 MinIO。
|
||||||
|
"""
|
||||||
|
# 1. 从 ComfyUI 获取视频二进制数据
|
||||||
|
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||||
|
if not mp4_bytes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 准备进行视频处理
|
||||||
|
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||||
|
# 为了避免写磁盘,我们将使用 BytesIO,但 MoviePy 内部依赖 FFmpeg,有时需要一个可寻址的本地文件路径。
|
||||||
|
# 最可靠且避免写本地的方案是在内存中操作,然后将结果上传。
|
||||||
|
|
||||||
|
# ⚠️ 关键点:将 mp4_bytes 写入 BytesIO 以模拟文件,供 moviepy 读取
|
||||||
|
|
||||||
|
# 定义输出对象名
|
||||||
|
|
||||||
|
output_base_name = uuid.uuid4().hex
|
||||||
|
MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
|
||||||
|
GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
|
||||||
|
FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
|
||||||
|
|
||||||
|
# --- 视频处理和帧提取 ---
|
||||||
|
try:
|
||||||
|
# 1. 创建一个临时的 MP4 文件路径
|
||||||
|
# delete=False 确保文件在关闭后仍然存在,直到我们手动删除
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||||
|
tmp_file.write(mp4_bytes) # 将内存数据写入磁盘
|
||||||
|
temp_mp4_path = tmp_file.name # 记录文件路径
|
||||||
|
|
||||||
|
# print(f"临时文件已写入: {temp_mp4_path}")
|
||||||
|
|
||||||
|
# 2. 使用 moviepy 打开临时文件 (传入文件路径字符串)
|
||||||
|
clip = VideoFileClip(temp_mp4_path)
|
||||||
|
|
||||||
|
# --- 在这里进行所有的视频处理和提取操作 ---
|
||||||
|
|
||||||
|
# 提取第一帧 (保持原尺寸)
|
||||||
|
frame_array = clip.get_frame(t=0.0)
|
||||||
|
image = Image.fromarray(frame_array)
|
||||||
|
|
||||||
|
frame_stream = io.BytesIO()
|
||||||
|
image.save(frame_stream, 'JPEG')
|
||||||
|
frame_bytes = frame_stream.getvalue()
|
||||||
|
|
||||||
|
logger.info("✅ 成功提取第一帧图片。")
|
||||||
|
|
||||||
|
# 视频转 GIF (使用另一个临时文件来保存 GIF)
|
||||||
|
temp_gif_path = ""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
|
||||||
|
temp_gif_path = tmp_file.name
|
||||||
|
|
||||||
|
target_fps = int(round(clip.fps)) if clip.fps else 24
|
||||||
|
clip.write_gif(temp_gif_path, fps=target_fps)
|
||||||
|
|
||||||
|
with open(temp_gif_path, 'rb') as f:
|
||||||
|
gif_bytes = f.read()
|
||||||
|
|
||||||
|
logger.info("✅ 成功生成 GIF。")
|
||||||
|
|
||||||
|
# 返回结果 (例如: 上传到 MinIO)
|
||||||
|
# return mp4_bytes, gif_bytes, frame_bytes
|
||||||
|
|
||||||
|
# -----------------------------------------------
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 视频处理或文件操作失败: {e}")
|
||||||
|
# 在失败时,也尝试清理文件
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 3. 清理临时文件 (非常重要!)
|
||||||
|
if os.path.exists(temp_mp4_path):
|
||||||
|
os.remove(temp_mp4_path)
|
||||||
|
logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
|
||||||
|
|
||||||
|
if 'temp_gif_path' in locals() and os.path.exists(temp_gif_path):
|
||||||
|
os.remove(temp_gif_path)
|
||||||
|
logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
|
||||||
|
|
||||||
|
# 3. 上传所有结果到 MinIO
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 上传原始 MP4
|
||||||
|
self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
|
||||||
|
|
||||||
|
# 上传生成的 GIF
|
||||||
|
self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
|
||||||
|
|
||||||
|
# 上传第一帧图片
|
||||||
|
self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
|
||||||
|
|
||||||
|
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||||
|
|
||||||
|
# 推送消息
|
||||||
|
if not DEBUG:
|
||||||
|
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||||
|
logger.info(
|
||||||
|
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||||
|
|
||||||
|
return "\n🎉 所有任务完成!"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- 辅助函数:提交任务到队列 ---
|
||||||
|
def queue_prompt(self, prompt, client_id):
|
||||||
|
"""向 ComfyUI 提交工作流提示。"""
|
||||||
|
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
|
||||||
|
# 提交任务到 /prompt 端点
|
||||||
|
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||||
|
# print(f"-------------{response.text}")
|
||||||
|
# print(f"------------{client_id}")
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
logger.warning(f"提交任务失败,状态码: {response.status_code}")
|
||||||
|
logger.warning(response.text)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def poll_history(self, prompt_id, interval_seconds=5):
|
||||||
|
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||||
|
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||||
|
|
||||||
|
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
time.sleep(interval_seconds)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
# 任务未完成时,ComfyUI可能会返回404或空响应,我们只关注成功响应
|
||||||
|
if response.status_code == 200:
|
||||||
|
history_data = response.json()
|
||||||
|
|
||||||
|
# ComfyUI 返回的历史记录结构是 {prompt_id: {outputs: ...}}
|
||||||
|
if prompt_id in history_data:
|
||||||
|
logger.info("🎉 任务已完成!")
|
||||||
|
return history_data[prompt_id]['outputs']
|
||||||
|
|
||||||
|
logger.info("⏳ 任务仍在执行或等待中...")
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
# 处理可能的连接错误,但通常不会在内部轮询中发生
|
||||||
|
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
|
||||||
|
"""
|
||||||
|
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- filename: 视频文件名 (例如: 'ComfyUI_00002_.mp4')
|
||||||
|
- subfolder: 存储子文件夹 (例如: 'ComfyUI_2025-10-31')
|
||||||
|
- file_type: 文件类型 (通常是 'output')
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 视频文件的二进制内容 (bytes) 或 None。
|
||||||
|
"""
|
||||||
|
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
|
||||||
|
params = {
|
||||||
|
"filename": filename,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": file_type
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"📡 正在从 ComfyUI 下载视频: {filename}")
|
||||||
|
try:
|
||||||
|
# 使用 requests.get 下载文件
|
||||||
|
response = requests.get(url, params=params, stream=True)
|
||||||
|
response.raise_for_status() # 检查 HTTP 错误
|
||||||
|
|
||||||
|
# 返回文件的完整二进制内容
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
|
||||||
|
"""从内存流上传数据到 MinIO。"""
|
||||||
|
logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
|
||||||
|
try:
|
||||||
|
|
||||||
|
data_stream = io.BytesIO(video_bytes)
|
||||||
|
|
||||||
|
result = self.minio_client.put_object(
|
||||||
|
bucket_name='aida-users',
|
||||||
|
object_name=object_name,
|
||||||
|
data=data_stream,
|
||||||
|
length=len(video_bytes),
|
||||||
|
content_type=content_type
|
||||||
|
)
|
||||||
|
logger.info(f"✅ MinIO 上传成功: {result.object_name}")
|
||||||
|
return True
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO 上传失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
request_data = ComfyuiFLF2VModel(
|
||||||
|
tasks_id="202511051619-89111",
|
||||||
|
start_image_url="test/start.png",
|
||||||
|
end_image_url="test/end.png",
|
||||||
|
prompt="Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||||
|
)
|
||||||
|
|
||||||
|
server = ComfyUIServerFLF2V(request_data)
|
||||||
|
print(server.get_result())
|
||||||
616
app/service/comfyui_I2V/i2v_server.py
Normal file
616
app/service/comfyui_I2V/i2v_server.py
Normal file
@@ -0,0 +1,616 @@
|
|||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio, S3Error
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
|
||||||
|
from app.schemas.comfyui_i2v import ComfyuiPose2VModel, ComfyuiI2VModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
# 图 + 文字 = 视频 工作流
|
||||||
|
workflow_json = {
|
||||||
|
"84": {
|
||||||
|
"inputs": {
|
||||||
|
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||||
|
"type": "wan",
|
||||||
|
"device": "default"
|
||||||
|
},
|
||||||
|
"class_type": "CLIPLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载CLIP"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"85": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "disable",
|
||||||
|
"noise_seed": 0,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 2,
|
||||||
|
"end_at_step": 4,
|
||||||
|
"return_with_leftover_noise": "disable",
|
||||||
|
"model": [
|
||||||
|
"103",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"98",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"98",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"86",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "K采样器(高级)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"86": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "enable",
|
||||||
|
"noise_seed": 823962998672127,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 0,
|
||||||
|
"end_at_step": 2,
|
||||||
|
"return_with_leftover_noise": "enable",
|
||||||
|
"model": [
|
||||||
|
"104",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"98",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"98",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"98",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "K采样器(高级)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"87": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"85",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"90",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE解码"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"89": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
"clip": [
|
||||||
|
"84",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Negative Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"90": {
|
||||||
|
"inputs": {
|
||||||
|
"vae_name": "wan_2.1_vae.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载VAE"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"93": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots.",
|
||||||
|
"clip": [
|
||||||
|
"84",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Positive Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"94": {
|
||||||
|
"inputs": {
|
||||||
|
"fps": 16,
|
||||||
|
"images": [
|
||||||
|
"87",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CreateVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "创建视频"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"95": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "UNet加载器"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"96": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "UNet加载器"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"97": {
|
||||||
|
"inputs": {
|
||||||
|
"image": "start (1).png"
|
||||||
|
},
|
||||||
|
"class_type": "LoadImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载图像"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"98": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 480,
|
||||||
|
"height": 848,
|
||||||
|
"length": 81,
|
||||||
|
"batch_size": 1,
|
||||||
|
"positive": [
|
||||||
|
"93",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"89",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"90",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"start_image": [
|
||||||
|
"97",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "WanImageToVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Wan图像到视频"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"101": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
|
||||||
|
"strength_model": 1.0000000000000002,
|
||||||
|
"model": [
|
||||||
|
"95",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoRA加载器(仅模型)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"102": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
|
||||||
|
"strength_model": 1.0000000000000002,
|
||||||
|
"model": [
|
||||||
|
"96",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoRA加载器(仅模型)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"103": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 5.000000000000001,
|
||||||
|
"model": [
|
||||||
|
"102",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "采样算法(SD3)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"104": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 5.000000000000001,
|
||||||
|
"model": [
|
||||||
|
"101",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "采样算法(SD3)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"108": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "video/ComfyUI",
|
||||||
|
"format": "auto",
|
||||||
|
"codec": "auto",
|
||||||
|
"video-preview": "",
|
||||||
|
"video": [
|
||||||
|
"94",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "保存视频"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyUIServerI2V:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
self.image_url = request_data.image_url
|
||||||
|
self.prompt = request_data.prompt
|
||||||
|
|
||||||
|
self.tasks_id = request_data.tasks_id
|
||||||
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
|
self.server_status_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||||
|
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
workflow_json['93']['inputs']['text'] = self.prompt
|
||||||
|
workflow_json['86']['inputs']["noise_seed"] = random.randint(0, 10 ** 18)
|
||||||
|
|
||||||
|
if self.image_url:
|
||||||
|
# 下载图片 上传 comfyui server
|
||||||
|
in_memory_file, object_name = self.download_from_minio_in_memory(self.image_url)
|
||||||
|
# TODO 设置视频宽度为480,高度自适应
|
||||||
|
workflow_json['98']['inputs']["width"] = 480
|
||||||
|
workflow_json['98']['inputs']["height"] = 848
|
||||||
|
if in_memory_file and object_name:
|
||||||
|
# 上传图片到comfyui server
|
||||||
|
filename = self.upload_in_memory_file_to_comfyui(in_memory_file, object_name)
|
||||||
|
workflow_json['97']['inputs']['image'] = filename
|
||||||
|
|
||||||
|
# 1. 提交任务
|
||||||
|
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||||
|
if not prompt_response:
|
||||||
|
return
|
||||||
|
prompt_id = prompt_response.get("prompt_id")
|
||||||
|
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||||
|
outputs = self.poll_history(prompt_id)
|
||||||
|
file_list = {}
|
||||||
|
for node_id, node_output in outputs.items():
|
||||||
|
# 检查当前节点输出中是否包含 'images' 列表
|
||||||
|
if 'images' in node_output and isinstance(node_output['images'], list):
|
||||||
|
|
||||||
|
# 'images' 列表中的每个元素都是一个文件对象
|
||||||
|
for file_info in node_output['images']:
|
||||||
|
# 确保关键字段存在
|
||||||
|
if all(key in file_info for key in ['filename', 'subfolder', 'type']):
|
||||||
|
file_list = {
|
||||||
|
'filename': file_info['filename'],
|
||||||
|
'subfolder': file_info['subfolder'],
|
||||||
|
'type': file_info['type']
|
||||||
|
}
|
||||||
|
logger.info(file_list)
|
||||||
|
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||||
|
|
||||||
|
def download_from_minio_in_memory(self, image_url):
|
||||||
|
bucket = image_url.split('/')[0]
|
||||||
|
object_name = image_url[image_url.find('/') + 1:]
|
||||||
|
|
||||||
|
try:
|
||||||
|
# get_object 返回一个 ResponseStream 对象
|
||||||
|
response_stream = self.minio_client.get_object(
|
||||||
|
bucket,
|
||||||
|
object_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 读取整个流到内存 (BytesIO),避免写入本地文件
|
||||||
|
image_bytes = response_stream.read()
|
||||||
|
|
||||||
|
response_stream.close()
|
||||||
|
response_stream.release_conn()
|
||||||
|
|
||||||
|
in_memory_file = io.BytesIO(image_bytes)
|
||||||
|
|
||||||
|
# print(f"✅ 图片已下载到内存 ({len(image_bytes)} 字节)。")
|
||||||
|
return in_memory_file, object_name.rsplit('/')[-1]
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
|
||||||
|
return None, None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
|
||||||
|
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"overwrite": "true",
|
||||||
|
"type": "input"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建 multipart/form-data: (文件名, 内存文件对象, MIME 类型)
|
||||||
|
# MIME 类型可以根据实际图片类型修改,这里使用常见的 png/jpeg
|
||||||
|
mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
|
||||||
|
|
||||||
|
files = {
|
||||||
|
'image': (filename, in_memory_file, mime_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
# print(f"⬆️ 正在上传图片 ({filename}) 到 ComfyUI...")
|
||||||
|
try:
|
||||||
|
comfyui_response = requests.post(upload_url, data=data, files=files)
|
||||||
|
comfyui_response.raise_for_status()
|
||||||
|
|
||||||
|
result = comfyui_response.json()
|
||||||
|
uploaded_name = result.get('name')
|
||||||
|
|
||||||
|
# print(f"🎉 ComfyUI 上传成功! 服务器文件名: {uploaded_name}")
|
||||||
|
return uploaded_name
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"❌ ComfyUI 上传失败: {e}")
|
||||||
|
logger.error(f" 响应内容: {comfyui_response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str, ):
|
||||||
|
"""
|
||||||
|
完整的自动化流程:获取 ComfyUI 视频 -> 转换 GIF 并提取帧 -> 上传所有结果到 MinIO。
|
||||||
|
"""
|
||||||
|
# 1. 从 ComfyUI 获取视频二进制数据
|
||||||
|
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||||
|
if not mp4_bytes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 准备进行视频处理
|
||||||
|
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||||
|
# 为了避免写磁盘,我们将使用 BytesIO,但 MoviePy 内部依赖 FFmpeg,有时需要一个可寻址的本地文件路径。
|
||||||
|
# 最可靠且避免写本地的方案是在内存中操作,然后将结果上传。
|
||||||
|
|
||||||
|
# ⚠️ 关键点:将 mp4_bytes 写入 BytesIO 以模拟文件,供 moviepy 读取
|
||||||
|
|
||||||
|
# 定义输出对象名
|
||||||
|
|
||||||
|
output_base_name = uuid.uuid4().hex
|
||||||
|
MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
|
||||||
|
GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
|
||||||
|
FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
|
||||||
|
|
||||||
|
# --- 视频处理和帧提取 ---
|
||||||
|
try:
|
||||||
|
# 1. 创建一个临时的 MP4 文件路径
|
||||||
|
# delete=False 确保文件在关闭后仍然存在,直到我们手动删除
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||||
|
tmp_file.write(mp4_bytes) # 将内存数据写入磁盘
|
||||||
|
temp_mp4_path = tmp_file.name # 记录文件路径
|
||||||
|
|
||||||
|
# print(f"临时文件已写入: {temp_mp4_path}")
|
||||||
|
|
||||||
|
# 2. 使用 moviepy 打开临时文件 (传入文件路径字符串)
|
||||||
|
clip = VideoFileClip(temp_mp4_path)
|
||||||
|
|
||||||
|
# --- 在这里进行所有的视频处理和提取操作 ---
|
||||||
|
|
||||||
|
# 提取第一帧 (保持原尺寸)
|
||||||
|
frame_array = clip.get_frame(t=0.0)
|
||||||
|
image = Image.fromarray(frame_array)
|
||||||
|
|
||||||
|
frame_stream = io.BytesIO()
|
||||||
|
image.save(frame_stream, 'JPEG')
|
||||||
|
frame_bytes = frame_stream.getvalue()
|
||||||
|
|
||||||
|
logger.info("✅ 成功提取第一帧图片。")
|
||||||
|
|
||||||
|
# 视频转 GIF (使用另一个临时文件来保存 GIF)
|
||||||
|
temp_gif_path = ""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
|
||||||
|
temp_gif_path = tmp_file.name
|
||||||
|
|
||||||
|
target_fps = int(round(clip.fps)) if clip.fps else 24
|
||||||
|
clip.write_gif(temp_gif_path, fps=target_fps)
|
||||||
|
|
||||||
|
with open(temp_gif_path, 'rb') as f:
|
||||||
|
gif_bytes = f.read()
|
||||||
|
|
||||||
|
logger.info("✅ 成功生成 GIF。")
|
||||||
|
|
||||||
|
# 返回结果 (例如: 上传到 MinIO)
|
||||||
|
# return mp4_bytes, gif_bytes, frame_bytes
|
||||||
|
|
||||||
|
# -----------------------------------------------
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 视频处理或文件操作失败: {e}")
|
||||||
|
# 在失败时,也尝试清理文件
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 3. 清理临时文件 (非常重要!)
|
||||||
|
if os.path.exists(temp_mp4_path):
|
||||||
|
os.remove(temp_mp4_path)
|
||||||
|
logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
|
||||||
|
|
||||||
|
if 'temp_gif_path' in locals() and os.path.exists(temp_gif_path):
|
||||||
|
os.remove(temp_gif_path)
|
||||||
|
logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
|
||||||
|
|
||||||
|
# 3. 上传所有结果到 MinIO
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 上传原始 MP4
|
||||||
|
self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
|
||||||
|
|
||||||
|
# 上传生成的 GIF
|
||||||
|
self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
|
||||||
|
|
||||||
|
# 上传第一帧图片
|
||||||
|
self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
|
||||||
|
|
||||||
|
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||||
|
|
||||||
|
# 推送消息
|
||||||
|
if not DEBUG:
|
||||||
|
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||||
|
logger.info(
|
||||||
|
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||||
|
|
||||||
|
return "\n🎉 所有任务完成!"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- 辅助函数:提交任务到队列 ---
|
||||||
|
def queue_prompt(self, prompt, client_id):
|
||||||
|
"""向 ComfyUI 提交工作流提示。"""
|
||||||
|
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
|
||||||
|
# 提交任务到 /prompt 端点
|
||||||
|
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||||
|
# print(f"-------------{response.text}")
|
||||||
|
# print(f"------------{client_id}")
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
logger.warning(f"提交任务失败,状态码: {response.status_code}")
|
||||||
|
logger.warning(response.text)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def poll_history(self, prompt_id, interval_seconds=5):
|
||||||
|
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||||
|
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||||
|
|
||||||
|
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
time.sleep(interval_seconds)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
# 任务未完成时,ComfyUI可能会返回404或空响应,我们只关注成功响应
|
||||||
|
if response.status_code == 200:
|
||||||
|
history_data = response.json()
|
||||||
|
|
||||||
|
# ComfyUI 返回的历史记录结构是 {prompt_id: {outputs: ...}}
|
||||||
|
if prompt_id in history_data:
|
||||||
|
logger.info("🎉 任务已完成!")
|
||||||
|
return history_data[prompt_id]['outputs']
|
||||||
|
|
||||||
|
logger.info("⏳ 任务仍在执行或等待中...")
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
# 处理可能的连接错误,但通常不会在内部轮询中发生
|
||||||
|
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
|
||||||
|
"""
|
||||||
|
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- filename: 视频文件名 (例如: 'ComfyUI_00002_.mp4')
|
||||||
|
- subfolder: 存储子文件夹 (例如: 'ComfyUI_2025-10-31')
|
||||||
|
- file_type: 文件类型 (通常是 'output')
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 视频文件的二进制内容 (bytes) 或 None。
|
||||||
|
"""
|
||||||
|
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
|
||||||
|
params = {
|
||||||
|
"filename": filename,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": file_type
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"📡 正在从 ComfyUI 下载视频: {filename}")
|
||||||
|
try:
|
||||||
|
# 使用 requests.get 下载文件
|
||||||
|
response = requests.get(url, params=params, stream=True)
|
||||||
|
response.raise_for_status() # 检查 HTTP 错误
|
||||||
|
|
||||||
|
# 返回文件的完整二进制内容
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
|
||||||
|
"""从内存流上传数据到 MinIO。"""
|
||||||
|
logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
|
||||||
|
try:
|
||||||
|
|
||||||
|
data_stream = io.BytesIO(video_bytes)
|
||||||
|
|
||||||
|
result = self.minio_client.put_object(
|
||||||
|
bucket_name='aida-users',
|
||||||
|
object_name=object_name,
|
||||||
|
data=data_stream,
|
||||||
|
length=len(video_bytes),
|
||||||
|
content_type=content_type
|
||||||
|
)
|
||||||
|
logger.info(f"✅ MinIO 上传成功: {result.object_name}")
|
||||||
|
return True
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO 上传失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
request_data = ComfyuiI2VModel(
|
||||||
|
tasks_id="12222515151123-89111",
|
||||||
|
image_url="aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
|
||||||
|
prompt="Model executing a series of poses, dynamic camera movement alternating between detailed close-ups and full shots."
|
||||||
|
)
|
||||||
|
|
||||||
|
server = ComfyUIServerI2V(request_data)
|
||||||
|
print(server.get_result())
|
||||||
739
app/service/comfyui_I2V/pose2v_server.py
Normal file
739
app/service/comfyui_I2V/pose2v_server.py
Normal file
@@ -0,0 +1,739 @@
|
|||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import redis
|
||||||
|
import requests
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio, S3Error
|
||||||
|
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||||
|
|
||||||
|
from app.core.config import REDIS_HOST, REDIS_PORT, REDIS_DB, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, COMFYUI_SERVER_ADDRESS, PS_RABBITMQ_QUEUES, DEBUG
|
||||||
|
from app.schemas.comfyui_i2v import ComfyuiPose2VModel
|
||||||
|
from app.service.generate_image.utils.mq import publish_status
|
||||||
|
|
||||||
|
logger = logging.getLogger()
|
||||||
|
|
||||||
|
# 图 + 骨架 = 视频 工作流
|
||||||
|
workflow_json = {
|
||||||
|
"162": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
|
||||||
|
"clip": [
|
||||||
|
"167",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Negative Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"163": {
|
||||||
|
"inputs": {
|
||||||
|
"fps": 24,
|
||||||
|
"images": [
|
||||||
|
"192",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CreateVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "创建视频"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"164": {
|
||||||
|
"inputs": {
|
||||||
|
"samples": [
|
||||||
|
"175",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"168",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "VAEDecode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "VAE解码"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"165": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_fun_control_high_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "UNet加载器"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"166": {
|
||||||
|
"inputs": {
|
||||||
|
"unet_name": "wan2.2_fun_control_low_noise_14B_fp8_scaled.safetensors",
|
||||||
|
"weight_dtype": "default"
|
||||||
|
},
|
||||||
|
"class_type": "UNETLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "UNet加载器"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"167": {
|
||||||
|
"inputs": {
|
||||||
|
"clip_name": "umt5_xxl_fp8_e4m3fn_scaled.safetensors",
|
||||||
|
"type": "wan",
|
||||||
|
"device": "default"
|
||||||
|
},
|
||||||
|
"class_type": "CLIPLoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载CLIP"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"168": {
|
||||||
|
"inputs": {
|
||||||
|
"vae_name": "wan_2.1_vae.safetensors"
|
||||||
|
},
|
||||||
|
"class_type": "VAELoader",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载VAE"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"169": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "enable",
|
||||||
|
"noise_seed": 8860422635573,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 0,
|
||||||
|
"end_at_step": 2,
|
||||||
|
"return_with_leftover_noise": "enable",
|
||||||
|
"model": [
|
||||||
|
"176",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"180",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"180",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"180",
|
||||||
|
2
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "K采样器(高级)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"170": {
|
||||||
|
"inputs": {
|
||||||
|
"filename_prefix": "video/wan2.2_fun_control",
|
||||||
|
"format": "auto",
|
||||||
|
"codec": "auto",
|
||||||
|
"video-preview": "",
|
||||||
|
"video": [
|
||||||
|
"163",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "SaveVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "保存视频"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"171": {
|
||||||
|
"inputs": {
|
||||||
|
"video": [
|
||||||
|
"174",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "GetVideoComponents",
|
||||||
|
"_meta": {
|
||||||
|
"title": "获取视频组件"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"174": {
|
||||||
|
"inputs": {
|
||||||
|
"file": "skeleton_3.mp4"
|
||||||
|
},
|
||||||
|
"class_type": "LoadVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载视频"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"175": {
|
||||||
|
"inputs": {
|
||||||
|
"add_noise": "disable",
|
||||||
|
"noise_seed": 0,
|
||||||
|
"steps": 4,
|
||||||
|
"cfg": 1,
|
||||||
|
"sampler_name": "euler",
|
||||||
|
"scheduler": "simple",
|
||||||
|
"start_at_step": 2,
|
||||||
|
"end_at_step": 4,
|
||||||
|
"return_with_leftover_noise": "disable",
|
||||||
|
"model": [
|
||||||
|
"177",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"positive": [
|
||||||
|
"180",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"180",
|
||||||
|
1
|
||||||
|
],
|
||||||
|
"latent_image": [
|
||||||
|
"169",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "KSamplerAdvanced",
|
||||||
|
"_meta": {
|
||||||
|
"title": "K采样器(高级)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"176": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 8.000000000000002,
|
||||||
|
"model": [
|
||||||
|
"181",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "采样算法(SD3)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"177": {
|
||||||
|
"inputs": {
|
||||||
|
"shift": 8.000000000000002,
|
||||||
|
"model": [
|
||||||
|
"182",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ModelSamplingSD3",
|
||||||
|
"_meta": {
|
||||||
|
"title": "采样算法(SD3)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"178": {
|
||||||
|
"inputs": {
|
||||||
|
"image": "296f5fd6-c5e4-4003-9798-f378a4f08411-0-89.png"
|
||||||
|
},
|
||||||
|
"class_type": "LoadImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "加载图像"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"179": {
|
||||||
|
"inputs": {
|
||||||
|
"text": "The model is catwalking at the fashion show.",
|
||||||
|
"clip": [
|
||||||
|
"167",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "CLIPTextEncode",
|
||||||
|
"_meta": {
|
||||||
|
"title": "CLIP Text Encode (Positive Prompt)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"180": {
|
||||||
|
"inputs": {
|
||||||
|
"width": 480,
|
||||||
|
"height": 720,
|
||||||
|
"length": 121,
|
||||||
|
"batch_size": 1,
|
||||||
|
"positive": [
|
||||||
|
"179",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"negative": [
|
||||||
|
"162",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"vae": [
|
||||||
|
"168",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"ref_image": [
|
||||||
|
"178",
|
||||||
|
0
|
||||||
|
],
|
||||||
|
"control_video": [
|
||||||
|
"171",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "Wan22FunControlToVideo",
|
||||||
|
"_meta": {
|
||||||
|
"title": "Wan22FunControlToVideo"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"181": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_high_noise.safetensors",
|
||||||
|
"strength_model": 1,
|
||||||
|
"model": [
|
||||||
|
"165",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoRA加载器(仅模型)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"182": {
|
||||||
|
"inputs": {
|
||||||
|
"lora_name": "wan2.2_i2v_lightx2v_4steps_lora_v1_low_noise.safetensors",
|
||||||
|
"strength_model": 1,
|
||||||
|
"model": [
|
||||||
|
"166",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "LoraLoaderModelOnly",
|
||||||
|
"_meta": {
|
||||||
|
"title": "LoRA加载器(仅模型)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"189": {
|
||||||
|
"inputs": {
|
||||||
|
"images": [
|
||||||
|
"171",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "PreviewImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "预览图像"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"190": {
|
||||||
|
"inputs": {
|
||||||
|
"images": [
|
||||||
|
"192",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "PreviewImage",
|
||||||
|
"_meta": {
|
||||||
|
"title": "预览图像"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"192": {
|
||||||
|
"inputs": {
|
||||||
|
"batch_index": 4,
|
||||||
|
"length": 117,
|
||||||
|
"image": [
|
||||||
|
"164",
|
||||||
|
0
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"class_type": "ImageFromBatch",
|
||||||
|
"_meta": {
|
||||||
|
"title": "从批次获取图像"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
# 骨架映射
|
||||||
|
video_map = {
|
||||||
|
"1": "input_pose_video/1.mp4",
|
||||||
|
"2": "input_pose_video/2.mp4",
|
||||||
|
"3": "input_pose_video/3.mp4",
|
||||||
|
"4": "input_pose_video/4.mp4",
|
||||||
|
"5": "input_pose_video/5.mp4",
|
||||||
|
"6": "input_pose_video/6.mp4"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class ComfyUIServerPose2V:
|
||||||
|
def __init__(self, request_data):
|
||||||
|
self.image_url = request_data.image_url
|
||||||
|
self.pose_num = request_data.pose_id
|
||||||
|
self.tasks_id = request_data.tasks_id
|
||||||
|
self.user_id = self.tasks_id[self.tasks_id.rfind('-') + 1:]
|
||||||
|
self.redis_client = redis.StrictRedis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB, decode_responses=True)
|
||||||
|
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'PENDING', 'message': "pending", 'gif_url': '', 'video_url': '', 'image_url': ''}
|
||||||
|
self.redis_client.set(self.tasks_id, json.dumps(self.pose_transform_data))
|
||||||
|
self.redis_client.expire(self.tasks_id, 600)
|
||||||
|
self.minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
def get_result(self):
|
||||||
|
workflow_json['174']['inputs']['file'] = video_map[self.pose_num]
|
||||||
|
workflow_json['169']['inputs']['noise_seed'] = random.randint(0, 10 ** 18)
|
||||||
|
|
||||||
|
# 下载图片 上传 comfyui server
|
||||||
|
in_memory_file, object_name = self.download_from_minio_in_memory()
|
||||||
|
if in_memory_file and object_name:
|
||||||
|
uploaded_filename = self.upload_in_memory_file_to_comfyui(in_memory_file, object_name)
|
||||||
|
workflow_json['178']['inputs']['image'] = uploaded_filename
|
||||||
|
# 1. 提交任务
|
||||||
|
prompt_response = self.queue_prompt(workflow_json, self.tasks_id)
|
||||||
|
if not prompt_response:
|
||||||
|
return
|
||||||
|
|
||||||
|
prompt_id = prompt_response.get("prompt_id")
|
||||||
|
logger.info(f" 任务已提交,Prompt ID: {prompt_id}")
|
||||||
|
|
||||||
|
outputs = self.poll_history(prompt_id)
|
||||||
|
file_list = {}
|
||||||
|
for node_id, node_output in outputs.items():
|
||||||
|
# 检查当前节点输出中是否包含 'images' 列表
|
||||||
|
if 'images' in node_output and isinstance(node_output['images'], list):
|
||||||
|
|
||||||
|
# 'images' 列表中的每个元素都是一个文件对象
|
||||||
|
for file_info in node_output['images']:
|
||||||
|
# 确保关键字段存在
|
||||||
|
if all(key in file_info for key in ['filename', 'subfolder', 'type']):
|
||||||
|
file_list = {
|
||||||
|
'filename': file_info['filename'],
|
||||||
|
'subfolder': file_info['subfolder'],
|
||||||
|
'type': file_info['type']
|
||||||
|
}
|
||||||
|
logger.info(file_list)
|
||||||
|
return self.process_and_upload_comfyui_video(filename=file_list['filename'], subfolder=file_list['subfolder'], prompt_id=prompt_response['prompt_id']), prompt_id
|
||||||
|
|
||||||
|
def read_tasks_status(self):
|
||||||
|
status_data = self.redis_client.get(self.tasks_id)
|
||||||
|
return json.loads(status_data), status_data
|
||||||
|
|
||||||
|
def download_from_minio_in_memory(self):
|
||||||
|
bucket = self.image_url.split('/')[0]
|
||||||
|
object_name = self.image_url[self.image_url.find('/') + 1:]
|
||||||
|
# print("🚀 正在连接 MinIO 客户端...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
# get_object 返回一个 ResponseStream 对象
|
||||||
|
response_stream = self.minio_client.get_object(
|
||||||
|
bucket,
|
||||||
|
object_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 读取整个流到内存 (BytesIO),避免写入本地文件
|
||||||
|
image_bytes = response_stream.read()
|
||||||
|
|
||||||
|
response_stream.close()
|
||||||
|
response_stream.release_conn()
|
||||||
|
|
||||||
|
in_memory_file = io.BytesIO(image_bytes)
|
||||||
|
|
||||||
|
# print(f"✅ 图片已下载到内存 ({len(image_bytes)} 字节)。")
|
||||||
|
return in_memory_file, object_name.rsplit('/')[-1]
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO S3 错误 (例如,对象不存在): {e}")
|
||||||
|
return None, None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ MinIO 下载过程中发生未知错误: {e}")
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
def upload_video_to_minio(self, BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH):
|
||||||
|
"""使用 fput_object 从本地路径上传 MP4 文件"""
|
||||||
|
try:
|
||||||
|
# 使用 fput_object 上传文件
|
||||||
|
# content_type 对于视频流播放非常重要,MP4 文件应使用 'video/mp4'
|
||||||
|
result = self.minio_client.fput_object(
|
||||||
|
bucket_name=BUCKET_NAME,
|
||||||
|
object_name=OBJECT_NAME,
|
||||||
|
file_path=LOCAL_FILE_PATH,
|
||||||
|
content_type="video/mp4" # 设置正确的内容类型
|
||||||
|
)
|
||||||
|
|
||||||
|
# print(f"✅ 文件 '{LOCAL_FILE_PATH}' 已成功上传至 MinIO:")
|
||||||
|
# print(f" 对象名: {result.object_name}")
|
||||||
|
# print(f" Etag: {result.etag}")
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO 操作失败: {e}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"❌ 找不到本地文件: {LOCAL_FILE_PATH}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 发生未知错误: {e}")
|
||||||
|
|
||||||
|
def upload_gif_to_minio(self, BUCKET_NAME, OBJECT_NAME, LOCAL_FILE_PATH):
|
||||||
|
"""使用 fput_object 从本地路径上传 MP4 文件"""
|
||||||
|
try:
|
||||||
|
# 使用 fput_object 上传文件
|
||||||
|
# content_type 对于视频流播放非常重要,MP4 文件应使用 'video/mp4'
|
||||||
|
result = self.minio_client.fput_object(
|
||||||
|
bucket_name=BUCKET_NAME,
|
||||||
|
object_name=OBJECT_NAME,
|
||||||
|
file_path=LOCAL_FILE_PATH,
|
||||||
|
content_type="video/mp4" # 设置正确的内容类型
|
||||||
|
)
|
||||||
|
|
||||||
|
# print(f"✅ 文件 '{LOCAL_FILE_PATH}' 已成功上传至 MinIO:")
|
||||||
|
# print(f" 对象名: {result.object_name}")
|
||||||
|
# print(f" Etag: {result.etag}")
|
||||||
|
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO 操作失败: {e}")
|
||||||
|
except FileNotFoundError:
|
||||||
|
logger.error(f"❌ 找不到本地文件: {LOCAL_FILE_PATH}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 发生未知错误: {e}")
|
||||||
|
|
||||||
|
def upload_in_memory_file_to_comfyui(self, in_memory_file, filename):
|
||||||
|
upload_url = f"http://{COMFYUI_SERVER_ADDRESS}/upload/image"
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"overwrite": "true",
|
||||||
|
"type": "input"
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建 multipart/form-data: (文件名, 内存文件对象, MIME 类型)
|
||||||
|
# MIME 类型可以根据实际图片类型修改,这里使用常见的 png/jpeg
|
||||||
|
mime_type = 'image/png' if filename.lower().endswith('.png') else 'image/jpeg'
|
||||||
|
|
||||||
|
files = {
|
||||||
|
'image': (filename, in_memory_file, mime_type)
|
||||||
|
}
|
||||||
|
|
||||||
|
# print(f"⬆️ 正在上传图片 ({filename}) 到 ComfyUI...")
|
||||||
|
try:
|
||||||
|
comfyui_response = requests.post(upload_url, data=data, files=files)
|
||||||
|
comfyui_response.raise_for_status()
|
||||||
|
|
||||||
|
result = comfyui_response.json()
|
||||||
|
uploaded_name = result.get('name')
|
||||||
|
|
||||||
|
# print(f"🎉 ComfyUI 上传成功! 服务器文件名: {uploaded_name}")
|
||||||
|
return uploaded_name
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"❌ ComfyUI 上传失败: {e}")
|
||||||
|
logger.error(f" 响应内容: {comfyui_response.text}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def process_and_upload_comfyui_video(self, filename: str, subfolder: str, prompt_id: str, ):
|
||||||
|
"""
|
||||||
|
完整的自动化流程:获取 ComfyUI 视频 -> 转换 GIF 并提取帧 -> 上传所有结果到 MinIO。
|
||||||
|
"""
|
||||||
|
# 1. 从 ComfyUI 获取视频二进制数据
|
||||||
|
mp4_bytes = self.get_comfyui_video_bytes(filename, subfolder)
|
||||||
|
if not mp4_bytes:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 准备进行视频处理
|
||||||
|
# moviepy 不支持直接使用 bytes,需要将 bytes 写入一个 BytesIO 或临时文件
|
||||||
|
# 为了避免写磁盘,我们将使用 BytesIO,但 MoviePy 内部依赖 FFmpeg,有时需要一个可寻址的本地文件路径。
|
||||||
|
# 最可靠且避免写本地的方案是在内存中操作,然后将结果上传。
|
||||||
|
|
||||||
|
# ⚠️ 关键点:将 mp4_bytes 写入 BytesIO 以模拟文件,供 moviepy 读取
|
||||||
|
|
||||||
|
# 定义输出对象名
|
||||||
|
|
||||||
|
output_base_name = uuid.uuid4().hex
|
||||||
|
MP4_OBJECT = f"{self.user_id}/pose_transform_video/{prompt_id}/{output_base_name}.mp4"
|
||||||
|
GIF_OBJECT = f"{self.user_id}/pose_transform_gif/{prompt_id}/{output_base_name}.gif"
|
||||||
|
FRAME_OBJECT = f"{self.user_id}/pose_transform_first_img/{prompt_id}/{output_base_name}_frame.jpg"
|
||||||
|
|
||||||
|
# --- 视频处理和帧提取 ---
|
||||||
|
try:
|
||||||
|
# 1. 创建一个临时的 MP4 文件路径
|
||||||
|
# delete=False 确保文件在关闭后仍然存在,直到我们手动删除
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
|
||||||
|
tmp_file.write(mp4_bytes) # 将内存数据写入磁盘
|
||||||
|
temp_mp4_path = tmp_file.name # 记录文件路径
|
||||||
|
|
||||||
|
# print(f"临时文件已写入: {temp_mp4_path}")
|
||||||
|
|
||||||
|
# 2. 使用 moviepy 打开临时文件 (传入文件路径字符串)
|
||||||
|
clip = VideoFileClip(temp_mp4_path)
|
||||||
|
|
||||||
|
# --- 在这里进行所有的视频处理和提取操作 ---
|
||||||
|
|
||||||
|
# 提取第一帧 (保持原尺寸)
|
||||||
|
frame_array = clip.get_frame(t=0.0)
|
||||||
|
image = Image.fromarray(frame_array)
|
||||||
|
|
||||||
|
frame_stream = io.BytesIO()
|
||||||
|
image.save(frame_stream, 'JPEG')
|
||||||
|
frame_bytes = frame_stream.getvalue()
|
||||||
|
|
||||||
|
logger.info("✅ 成功提取第一帧图片。")
|
||||||
|
|
||||||
|
# 视频转 GIF (使用另一个临时文件来保存 GIF)
|
||||||
|
temp_gif_path = ""
|
||||||
|
with tempfile.NamedTemporaryFile(suffix=".gif", delete=False) as tmp_file:
|
||||||
|
temp_gif_path = tmp_file.name
|
||||||
|
|
||||||
|
target_fps = int(round(clip.fps)) if clip.fps else 24
|
||||||
|
clip.write_gif(temp_gif_path, fps=target_fps)
|
||||||
|
|
||||||
|
with open(temp_gif_path, 'rb') as f:
|
||||||
|
gif_bytes = f.read()
|
||||||
|
|
||||||
|
logger.info("✅ 成功生成 GIF。")
|
||||||
|
|
||||||
|
# 返回结果 (例如: 上传到 MinIO)
|
||||||
|
# return mp4_bytes, gif_bytes, frame_bytes
|
||||||
|
|
||||||
|
# -----------------------------------------------
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"❌ 视频处理或文件操作失败: {e}")
|
||||||
|
# 在失败时,也尝试清理文件
|
||||||
|
|
||||||
|
finally:
|
||||||
|
# 3. 清理临时文件 (非常重要!)
|
||||||
|
if os.path.exists(temp_mp4_path):
|
||||||
|
os.remove(temp_mp4_path)
|
||||||
|
logger.info(f"🗑️ 已删除临时 MP4 文件: {temp_mp4_path}")
|
||||||
|
|
||||||
|
if 'temp_gif_path' in locals() and os.path.exists(temp_gif_path):
|
||||||
|
os.remove(temp_gif_path)
|
||||||
|
logger.info(f"🗑️ 已删除临时 GIF 文件: {temp_gif_path}")
|
||||||
|
|
||||||
|
# 3. 上传所有结果到 MinIO
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 上传原始 MP4
|
||||||
|
self.upload_stream_to_minio(mp4_bytes, MP4_OBJECT, "video/mp4")
|
||||||
|
|
||||||
|
# 上传生成的 GIF
|
||||||
|
self.upload_stream_to_minio(gif_bytes, GIF_OBJECT, "image/gif")
|
||||||
|
|
||||||
|
# 上传第一帧图片
|
||||||
|
self.upload_stream_to_minio(frame_bytes, FRAME_OBJECT, "image/jpeg")
|
||||||
|
|
||||||
|
self.pose_transform_data = {'tasks_id': self.tasks_id, 'status': 'SUCCESS', 'message': "success", 'gif_url': f'aida-users/{GIF_OBJECT}', 'video_url': f'aida-users/{MP4_OBJECT}', 'image_url': f'aida-users/{FRAME_OBJECT}'}
|
||||||
|
|
||||||
|
# 推送消息
|
||||||
|
if not DEBUG:
|
||||||
|
publish_status(json.dumps(self.pose_transform_data), PS_RABBITMQ_QUEUES)
|
||||||
|
logger.info(
|
||||||
|
f" [x] Sent to: {PS_RABBITMQ_QUEUES} data:@@@@ {json.dumps(self.pose_transform_data, indent=4)}")
|
||||||
|
|
||||||
|
return "\n🎉 所有任务完成!"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
return None
|
||||||
|
|
||||||
|
# --- 辅助函数:提交任务到队列 ---
|
||||||
|
def queue_prompt(self, prompt, client_id):
|
||||||
|
"""向 ComfyUI 提交工作流提示。"""
|
||||||
|
p = {"prompt": prompt, "client_id": client_id, "prompt_id": client_id}
|
||||||
|
data = json.dumps(p).encode('utf-8')
|
||||||
|
|
||||||
|
# 提交任务到 /prompt 端点
|
||||||
|
response = requests.post(f"http://{COMFYUI_SERVER_ADDRESS}/prompt", data=data)
|
||||||
|
# print(f"-------------{response.text}")
|
||||||
|
# print(f"------------{client_id}")
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
else:
|
||||||
|
logger.warning(f"提交任务失败,状态码: {response.status_code}")
|
||||||
|
logger.warning(response.text)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def poll_history(self, prompt_id, interval_seconds=5):
|
||||||
|
"""步骤 2: 轮询 /history/{prompt_id} 检查任务是否完成"""
|
||||||
|
url = f"http://{COMFYUI_SERVER_ADDRESS}/history/{prompt_id}"
|
||||||
|
|
||||||
|
logger.info(f"⏳ 开始轮询状态 (间隔 {interval_seconds} 秒)...")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
time.sleep(interval_seconds)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(url)
|
||||||
|
# 任务未完成时,ComfyUI可能会返回404或空响应,我们只关注成功响应
|
||||||
|
if response.status_code == 200:
|
||||||
|
history_data = response.json()
|
||||||
|
|
||||||
|
# ComfyUI 返回的历史记录结构是 {prompt_id: {outputs: ...}}
|
||||||
|
if prompt_id in history_data:
|
||||||
|
logger.info("🎉 任务已完成!")
|
||||||
|
return history_data[prompt_id]['outputs']
|
||||||
|
|
||||||
|
logger.info("⏳ 任务仍在执行或等待中...")
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
# 处理可能的连接错误,但通常不会在内部轮询中发生
|
||||||
|
logger.info(f"⚠️ 轮询时发生错误: {e}")
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_comfyui_video_bytes(self, filename: str, subfolder: str, file_type: str = "output"):
|
||||||
|
"""
|
||||||
|
从 ComfyUI 的 /view 端点获取视频文件的二进制数据。
|
||||||
|
|
||||||
|
参数:
|
||||||
|
- filename: 视频文件名 (例如: 'ComfyUI_00002_.mp4')
|
||||||
|
- subfolder: 存储子文件夹 (例如: 'ComfyUI_2025-10-31')
|
||||||
|
- file_type: 文件类型 (通常是 'output')
|
||||||
|
|
||||||
|
返回:
|
||||||
|
- 视频文件的二进制内容 (bytes) 或 None。
|
||||||
|
"""
|
||||||
|
url = f"http://{COMFYUI_SERVER_ADDRESS}/view"
|
||||||
|
params = {
|
||||||
|
"filename": filename,
|
||||||
|
"subfolder": subfolder,
|
||||||
|
"type": file_type
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info(f"📡 正在从 ComfyUI 下载视频: {filename}")
|
||||||
|
try:
|
||||||
|
# 使用 requests.get 下载文件
|
||||||
|
response = requests.get(url, params=params, stream=True)
|
||||||
|
response.raise_for_status() # 检查 HTTP 错误
|
||||||
|
|
||||||
|
# 返回文件的完整二进制内容
|
||||||
|
return response.content
|
||||||
|
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logger.error(f"❌ 从 ComfyUI 获取视频失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upload_stream_to_minio(self, video_bytes: bytes, object_name: str, content_type: str):
|
||||||
|
"""从内存流上传数据到 MinIO。"""
|
||||||
|
logger.info(f"☁️ 正在上传对象到 MinIO: {object_name}")
|
||||||
|
try:
|
||||||
|
|
||||||
|
data_stream = io.BytesIO(video_bytes)
|
||||||
|
|
||||||
|
result = self.minio_client.put_object(
|
||||||
|
bucket_name='aida-users',
|
||||||
|
object_name=object_name,
|
||||||
|
data=data_stream,
|
||||||
|
length=len(video_bytes),
|
||||||
|
content_type=content_type
|
||||||
|
)
|
||||||
|
logger.info(f"✅ MinIO 上传成功: {result.object_name}")
|
||||||
|
return True
|
||||||
|
except S3Error as e:
|
||||||
|
logger.error(f"❌ MinIO 上传失败: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
request_data = ComfyuiPose2VModel(
|
||||||
|
tasks_id="122522251123-89111",
|
||||||
|
image_url="aida-users/89/product_image/a6949500-2393-42ac-8723-440b5d5da2b2-0-89.png",
|
||||||
|
pose_id="6"
|
||||||
|
)
|
||||||
|
|
||||||
|
server = ComfyUIServerPose2V(request_data)
|
||||||
|
print(server.get_result())
|
||||||
@@ -5,7 +5,7 @@ from .top import Top, Blouse, Outwear, Dress
|
|||||||
from .bottom import Bottom, Trousers, Skirt
|
from .bottom import Bottom, Trousers, Skirt
|
||||||
from .shoes import Shoes
|
from .shoes import Shoes
|
||||||
from .bag import Bag
|
from .bag import Bag
|
||||||
from .accessories import Hairstyle, Earring
|
from .others import Hairstyle, Earring
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ITEMS', 'build_item',
|
'ITEMS', 'build_item',
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ from celery import Celery
|
|||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, AccessoriesItem
|
from app.service.design_batch.item import BodyItem, TopItem, BottomItem, OthersItem
|
||||||
from app.service.design_batch.utils.MQ import publish_status
|
from app.service.design_batch.utils.MQ import publish_status
|
||||||
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_accessories
|
from app.service.design_batch.utils.organize import organize_body, organize_clothing, organize_others
|
||||||
from app.service.design_batch.utils.save_json import oss_upload_json
|
from app.service.design_batch.utils.save_json import oss_upload_json
|
||||||
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
from app.service.design_batch.utils.synthesis_item import update_base_size_priority, synthesis, synthesis_single
|
||||||
|
|
||||||
@@ -33,8 +33,8 @@ def process_item(item, basic):
|
|||||||
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
||||||
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
||||||
item_data = bottom_server.process()
|
item_data = bottom_server.process()
|
||||||
elif item['type'].lower() in ['accessories']:
|
elif item['type'].lower() in ['others']:
|
||||||
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
|
bottom_server = OthersItem(data=item, basic=basic, minio_client=minio_client)
|
||||||
item_data = bottom_server.process()
|
item_data = bottom_server.process()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
||||||
@@ -47,8 +47,8 @@ def process_layer(item, layers):
|
|||||||
body_layer = organize_body(item)
|
body_layer = organize_body(item)
|
||||||
layers.append(body_layer)
|
layers.append(body_layer)
|
||||||
return item['body_image'].size
|
return item['body_image'].size
|
||||||
elif item['name'] == 'accessories':
|
elif item['name'] == 'others':
|
||||||
front_layer, back_layer = organize_accessories(item)
|
front_layer, back_layer = organize_others(item)
|
||||||
layers.append(front_layer)
|
layers.append(front_layer)
|
||||||
layers.append(back_layer)
|
layers.append(back_layer)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -9,10 +9,10 @@ class BaseItem:
|
|||||||
self.result.update(basic)
|
self.result.update(basic)
|
||||||
|
|
||||||
|
|
||||||
class AccessoriesItem(BaseItem):
|
class OthersItem(BaseItem):
|
||||||
def __init__(self, data, basic, minio_client):
|
def __init__(self, data, basic, minio_client):
|
||||||
super().__init__(data, basic)
|
super().__init__(data, basic)
|
||||||
self.Accessories_pipeline = [
|
self.Others_pipeline = [
|
||||||
LoadImage(minio_client),
|
LoadImage(minio_client),
|
||||||
# KeyPoint(),
|
# KeyPoint(),
|
||||||
ContourDetection(),
|
ContourDetection(),
|
||||||
@@ -25,7 +25,7 @@ class AccessoriesItem(BaseItem):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
for item in self.Accessories_pipeline:
|
for item in self.Others_pipeline:
|
||||||
self.result = item(self.result)
|
self.result = item(self.result)
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
|
|||||||
@@ -74,8 +74,8 @@ class LoadImage:
|
|||||||
keypoint = 'head_point'
|
keypoint = 'head_point'
|
||||||
elif name == 'earring':
|
elif name == 'earring':
|
||||||
keypoint = 'ear_point'
|
keypoint = 'ear_point'
|
||||||
elif name == 'accessories':
|
elif name == 'others':
|
||||||
keypoint = "accessories"
|
keypoint = "others"
|
||||||
else:
|
else:
|
||||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||||
f"bag, shoes, hairstyle, earring.")
|
f"bag, shoes, hairstyle, earring.")
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class Scaling:
|
|||||||
result['scale'] = result['scale_bag']
|
result['scale'] = result['scale_bag']
|
||||||
elif result['keypoint'] == 'ear_point':
|
elif result['keypoint'] == 'ear_point':
|
||||||
result['scale'] = result['scale_earrings']
|
result['scale'] = result['scale_earrings']
|
||||||
elif result['keypoint'] == 'accessories':
|
elif result['keypoint'] == 'others':
|
||||||
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
|
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
|
||||||
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
|
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
|
||||||
distance_clo = result['img_shape'][1]
|
distance_clo = result['img_shape'][1]
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class Split(object):
|
|||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'):
|
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
|
||||||
|
|
||||||
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
|
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
|
||||||
front_mask = result['front_mask']
|
front_mask = result['front_mask']
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ def organize_clothing(layer):
|
|||||||
return front_layer, back_layer
|
return front_layer, back_layer
|
||||||
|
|
||||||
|
|
||||||
def organize_accessories(layer):
|
def organize_others(layer):
|
||||||
# 起始坐标
|
# 起始坐标
|
||||||
start_point = (0, 0)
|
start_point = (0, 0)
|
||||||
# 前片数据
|
# 前片数据
|
||||||
|
|||||||
@@ -6,8 +6,8 @@ import requests
|
|||||||
from minio import Minio
|
from minio import Minio
|
||||||
|
|
||||||
from app.core.config import *
|
from app.core.config import *
|
||||||
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, AccessoriesItem
|
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem
|
||||||
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_accessories
|
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
|
||||||
from app.service.design_fast.utils.progress import final_progress, update_progress
|
from app.service.design_fast.utils.progress import final_progress, update_progress
|
||||||
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
|
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority
|
||||||
from app.service.utils.decorator import RunTime
|
from app.service.utils.decorator import RunTime
|
||||||
@@ -30,8 +30,8 @@ def process_item(item, basic):
|
|||||||
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']:
|
||||||
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client)
|
||||||
item_data = bottom_server.process()
|
item_data = bottom_server.process()
|
||||||
elif item['type'].lower() in ['accessories']:
|
elif item['type'].lower() in ['others']:
|
||||||
bottom_server = AccessoriesItem(data=item, basic=basic, minio_client=minio_client)
|
bottom_server = OthersItem(data=item, basic=basic, minio_client=minio_client)
|
||||||
item_data = bottom_server.process()
|
item_data = bottom_server.process()
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
raise NotImplementedError(f"Item type {item['type']} not implemented")
|
||||||
@@ -44,8 +44,8 @@ def process_layer(item, layers):
|
|||||||
body_layer = organize_body(item)
|
body_layer = organize_body(item)
|
||||||
layers.append(body_layer)
|
layers.append(body_layer)
|
||||||
return item['body_image'].size
|
return item['body_image'].size
|
||||||
elif item['name'] == 'accessories':
|
elif item['name'] == 'others':
|
||||||
front_layer, back_layer = organize_accessories(item)
|
front_layer, back_layer = organize_others(item)
|
||||||
layers.append(front_layer)
|
layers.append(front_layer)
|
||||||
layers.append(back_layer)
|
layers.append(back_layer)
|
||||||
else:
|
else:
|
||||||
@@ -79,7 +79,7 @@ def design_generate(request_data):
|
|||||||
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
layers = sorted(layers, key=lambda s: s.get("priority", float('inf')))
|
||||||
|
|
||||||
layers, new_size = update_base_size_priority(layers, body_size)
|
layers, new_size = update_base_size_priority(layers, body_size)
|
||||||
|
# pattern_overall_image_url 、 pattern_print_image_url
|
||||||
for lay in layers:
|
for lay in layers:
|
||||||
items_response['layers'].append({
|
items_response['layers'].append({
|
||||||
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
|
'image_category': "body" if lay['name'] == 'mannequin' else lay['name'],
|
||||||
@@ -90,7 +90,9 @@ def design_generate(request_data):
|
|||||||
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
||||||
'mask_url': lay['mask_url'],
|
'mask_url': lay['mask_url'],
|
||||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||||
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
|
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
|
||||||
|
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
|
||||||
|
|
||||||
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
||||||
})
|
})
|
||||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||||
@@ -104,7 +106,9 @@ def design_generate(request_data):
|
|||||||
'image_url': item_result['front_image_url'],
|
'image_url': item_result['front_image_url'],
|
||||||
'mask_url': item_result['mask_url'],
|
'mask_url': item_result['mask_url'],
|
||||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||||
|
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||||
|
|
||||||
})
|
})
|
||||||
items_response['layers'].append({
|
items_response['layers'].append({
|
||||||
'image_category': f"{item_result['name']}_back",
|
'image_category': f"{item_result['name']}_back",
|
||||||
@@ -114,7 +118,9 @@ def design_generate(request_data):
|
|||||||
'image_url': item_result['back_image_url'],
|
'image_url': item_result['back_image_url'],
|
||||||
'mask_url': item_result['mask_url'],
|
'mask_url': item_result['mask_url'],
|
||||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||||
|
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||||
|
|
||||||
})
|
})
|
||||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||||
update_progress(process_id, total)
|
update_progress(process_id, total)
|
||||||
@@ -139,10 +145,11 @@ def design_generate(request_data):
|
|||||||
@RunTime
|
@RunTime
|
||||||
def design_generate_v2(request_data):
|
def design_generate_v2(request_data):
|
||||||
objects_data = request_data.dict()['objects']
|
objects_data = request_data.dict()['objects']
|
||||||
|
callback_url = request_data.callback_url
|
||||||
request_id = request_data.requestId
|
request_id = request_data.requestId
|
||||||
threads = []
|
threads = []
|
||||||
|
|
||||||
def process_object(step, object):
|
def process_object(step, object, callback_url):
|
||||||
basic = object['basic']
|
basic = object['basic']
|
||||||
items_response = {
|
items_response = {
|
||||||
'layers': [],
|
'layers': [],
|
||||||
@@ -171,7 +178,9 @@ def design_generate_v2(request_data):
|
|||||||
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
'gradient_string': lay['gradient_string'] if 'gradient_string' in lay.keys() else "",
|
||||||
'mask_url': lay['mask_url'],
|
'mask_url': lay['mask_url'],
|
||||||
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
|
||||||
'pattern_image_url': lay['pattern_image_url'] if 'pattern_image_url' in lay.keys() else None,
|
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
|
||||||
|
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
|
||||||
|
|
||||||
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
|
||||||
})
|
})
|
||||||
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
|
||||||
@@ -185,7 +194,9 @@ def design_generate_v2(request_data):
|
|||||||
'image_url': item_result['front_image_url'],
|
'image_url': item_result['front_image_url'],
|
||||||
'mask_url': item_result['mask_url'],
|
'mask_url': item_result['mask_url'],
|
||||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||||
|
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||||
|
|
||||||
})
|
})
|
||||||
items_response['layers'].append({
|
items_response['layers'].append({
|
||||||
'image_category': f"{item_result['name']}_back",
|
'image_category': f"{item_result['name']}_back",
|
||||||
@@ -195,16 +206,14 @@ def design_generate_v2(request_data):
|
|||||||
'image_url': item_result['back_image_url'],
|
'image_url': item_result['back_image_url'],
|
||||||
'mask_url': item_result['mask_url'],
|
'mask_url': item_result['mask_url'],
|
||||||
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
"gradient_string": item_result['gradient_string'] if 'gradient_string' in item_result.keys() else "",
|
||||||
'pattern_image_url': item_result['pattern_image_url'] if 'pattern_image_url' in item_result.keys() else None,
|
'pattern_overall_image_url': item_result['pattern_overall_image_url'] if 'pattern_overall_image_url' in item_result.keys() else None,
|
||||||
|
'pattern_print_image_url': item_result['pattern_print_image_url'] if 'pattern_print_image_url' in item_result.keys() else None,
|
||||||
|
|
||||||
})
|
})
|
||||||
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
items_response['synthesis_url'] = synthesis_single(item_result['front_image'], item_result['back_image'])
|
||||||
# 发送结果给java端
|
# 发送结果给java端
|
||||||
url = JAVA_STREAM_API_URL
|
url = callback_url
|
||||||
# xu_pei_test_url = "https://cd21b9110505.ngrok-free.app/api/third/party/receiveDesignResults"
|
|
||||||
tianxaing_test_url = "https://c2ae520723c9.ngrok-free.app/api/third/party/receiveDesignResults"
|
|
||||||
logger.info(f"java 回调 -> {url}")
|
logger.info(f"java 回调 -> {url}")
|
||||||
# logger.info(f"xupei java 回调 -> {xu_pei_test_url}")
|
|
||||||
logger.info(f"tianxiang java 回调 -> {tianxaing_test_url}")
|
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
'Accept': "*/*",
|
'Accept': "*/*",
|
||||||
@@ -219,16 +228,8 @@ def design_generate_v2(request_data):
|
|||||||
# 打印结果
|
# 打印结果
|
||||||
logger.info(response.text)
|
logger.info(response.text)
|
||||||
|
|
||||||
# test_response = post_request(xu_pei_test_url, json_data=items_response, headers=headers)
|
|
||||||
test_response = post_request(tianxaing_test_url, json_data=items_response, headers=headers)
|
|
||||||
|
|
||||||
if test_response:
|
|
||||||
# 打印结果
|
|
||||||
# logger.info(f"xupei test response : {test_response.text}")
|
|
||||||
logger.info(f"tianxiang test response : {test_response.text}")
|
|
||||||
|
|
||||||
for step, object in enumerate(objects_data):
|
for step, object in enumerate(objects_data):
|
||||||
t = threading.Thread(target=process_object, args=(step, object))
|
t = threading.Thread(target=process_object, args=(step, object, callback_url))
|
||||||
threads.append(t)
|
threads.append(t)
|
||||||
t.start()
|
t.start()
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection
|
from app.service.design_fast.pipeline import LoadImage, KeyPoint, Segmentation, Color, PrintPainting, Scaling, Split, LoadBodyImage, ContourDetection, NoSegPrintPainting
|
||||||
|
|
||||||
|
|
||||||
class BaseItem:
|
class BaseItem:
|
||||||
@@ -9,23 +9,24 @@ class BaseItem:
|
|||||||
self.result.update(basic)
|
self.result.update(basic)
|
||||||
|
|
||||||
|
|
||||||
class AccessoriesItem(BaseItem):
|
class OthersItem(BaseItem):
|
||||||
def __init__(self, data, basic, minio_client):
|
def __init__(self, data, basic, minio_client):
|
||||||
super().__init__(data, basic)
|
super().__init__(data, basic)
|
||||||
self.Accessories_pipeline = [
|
self.Others_pipeline = [
|
||||||
LoadImage(minio_client),
|
LoadImage(minio_client),
|
||||||
# KeyPoint(),
|
# KeyPoint(),
|
||||||
# ContourDetection(),
|
# ContourDetection(),
|
||||||
Segmentation(minio_client),
|
Segmentation(minio_client),
|
||||||
# BackPerspective(minio_client),
|
# BackPerspective(minio_client),
|
||||||
Color(minio_client),
|
Color(minio_client),
|
||||||
|
NoSegPrintPainting(minio_client),
|
||||||
PrintPainting(minio_client),
|
PrintPainting(minio_client),
|
||||||
Scaling(),
|
Scaling(),
|
||||||
Split(minio_client)
|
Split(minio_client)
|
||||||
]
|
]
|
||||||
|
|
||||||
def process(self):
|
def process(self):
|
||||||
for item in self.Accessories_pipeline:
|
for item in self.Others_pipeline:
|
||||||
self.result = item(self.result)
|
self.result = item(self.result)
|
||||||
return self.result
|
return self.result
|
||||||
|
|
||||||
@@ -39,6 +40,7 @@ class TopItem(BaseItem):
|
|||||||
Segmentation(minio_client),
|
Segmentation(minio_client),
|
||||||
# BackPerspective(minio_client),
|
# BackPerspective(minio_client),
|
||||||
Color(minio_client),
|
Color(minio_client),
|
||||||
|
NoSegPrintPainting(minio_client),
|
||||||
PrintPainting(minio_client),
|
PrintPainting(minio_client),
|
||||||
Scaling(),
|
Scaling(),
|
||||||
Split(minio_client)
|
Split(minio_client)
|
||||||
@@ -60,6 +62,7 @@ class BottomItem(BaseItem):
|
|||||||
Segmentation(minio_client),
|
Segmentation(minio_client),
|
||||||
# BackPerspective(minio_client),
|
# BackPerspective(minio_client),
|
||||||
Color(minio_client),
|
Color(minio_client),
|
||||||
|
NoSegPrintPainting(minio_client),
|
||||||
PrintPainting(minio_client),
|
PrintPainting(minio_client),
|
||||||
Scaling(),
|
Scaling(),
|
||||||
Split(minio_client)
|
Split(minio_client)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from .keypoint import KeyPoint
|
|||||||
from .keypoint import KeyPoint
|
from .keypoint import KeyPoint
|
||||||
from .loading import LoadImage, LoadBodyImage
|
from .loading import LoadImage, LoadBodyImage
|
||||||
from .print_painting import PrintPainting
|
from .print_painting import PrintPainting
|
||||||
|
from .no_seg_print_painting import NoSegPrintPainting
|
||||||
from .scale import Scaling
|
from .scale import Scaling
|
||||||
from .segmentation import Segmentation
|
from .segmentation import Segmentation
|
||||||
from .split import Split
|
from .split import Split
|
||||||
@@ -16,6 +17,7 @@ __all__ = [
|
|||||||
'Segmentation',
|
'Segmentation',
|
||||||
'BackPerspective',
|
'BackPerspective',
|
||||||
'Color',
|
'Color',
|
||||||
|
'NoSegPrintPainting',
|
||||||
'PrintPainting',
|
'PrintPainting',
|
||||||
'Scaling',
|
'Scaling',
|
||||||
'Split'
|
'Split'
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ class Color:
|
|||||||
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
resize_pattern = cv2.resize(pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||||
# 无色
|
# 无色
|
||||||
elif "color" not in result.keys() or result['color'] == "":
|
elif "color" not in result.keys() or result['color'] == "":
|
||||||
result['no_seg_sketch'] = result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
|
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = result['final_image'] = result['pattern_image'] = result['single_image'] = result['image']
|
||||||
result['alpha'] = 100 / 255.0
|
result['alpha'] = 100 / 255.0
|
||||||
return result
|
return result
|
||||||
# 正常颜色
|
# 正常颜色
|
||||||
@@ -48,7 +48,7 @@ class Color:
|
|||||||
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
|
resize_pattern[mask_3ch] = png_rgb[mask_3ch]
|
||||||
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
resize_pattern = cv2.resize(resize_pattern, (dim_image_w, dim_image_h), interpolation=cv2.INTER_AREA)
|
||||||
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
closed_mo = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||||
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
gray_mo = np.expand_dims(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), axis=2).repeat(3, axis=2)
|
||||||
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
get_image_fir = resize_pattern * (closed_mo / 255) * (gray_mo / 255)
|
||||||
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
result['pattern_image'] = get_image_fir.astype(np.uint8)
|
||||||
result['final_image'] = result['pattern_image']
|
result['final_image'] = result['pattern_image']
|
||||||
@@ -60,7 +60,7 @@ class Color:
|
|||||||
result['single_image'] = cv2.add(tmp1, tmp2)
|
result['single_image'] = cv2.add(tmp1, tmp2)
|
||||||
result['alpha'] = 100 / 255.0
|
result['alpha'] = 100 / 255.0
|
||||||
|
|
||||||
result['no_seg_sketch'] = result['final_image'].copy()
|
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = result['final_image'].copy()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_gradient(self, bucket_name, object_name):
|
def get_gradient(self, bucket_name, object_name):
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -38,12 +39,42 @@ class LoadImage:
|
|||||||
|
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
result['image'], result['pre_mask'] = self.read_image(result['path'])
|
result['image'], result['pre_mask'] = self.read_image(result['path'])
|
||||||
result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
# if 'extract_lines' in result.keys():
|
||||||
|
# if result['extract_lines']:
|
||||||
|
# result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
|
||||||
|
# else:
|
||||||
|
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||||
|
# else:
|
||||||
|
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
|
||||||
|
|
||||||
|
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
|
||||||
result['keypoint'] = self.get_keypoint(result['name'])
|
result['keypoint'] = self.get_keypoint(result['name'])
|
||||||
result['img_shape'] = result['image'].shape
|
result['img_shape'] = result['image'].shape
|
||||||
result['ori_shape'] = result['image'].shape
|
result['ori_shape'] = result['image'].shape
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_lines(self, img, path):
|
||||||
|
binary = cv2.adaptiveThreshold(img, 255,
|
||||||
|
cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
|
||||||
|
cv2.THRESH_BINARY_INV,
|
||||||
|
25, 10)
|
||||||
|
|
||||||
|
# 步骤2:细化边缘(可选,让线条更干净)
|
||||||
|
# kernel = np.ones((1, 1), np.uint8)
|
||||||
|
# clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
|
||||||
|
|
||||||
|
thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
|
||||||
|
mask = thinned > 0
|
||||||
|
result = np.ones_like(img) * 255
|
||||||
|
result[mask] = img[mask]
|
||||||
|
|
||||||
|
# 步骤3:反转回 白底黑线
|
||||||
|
# lines = cv2.bitwise_not(thinned)
|
||||||
|
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Original_{path.replace('/', '-')}.png"), img)
|
||||||
|
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Line_{path.replace('/', '-')}.png"), result)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
def read_image(self, image_path):
|
def read_image(self, image_path):
|
||||||
image_mask = None
|
image_mask = None
|
||||||
image = oss_get_image(oss_client=self.minio_client, bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
|
image = oss_get_image(oss_client=self.minio_client, bucket=image_path.split("/", 1)[0], object_name=image_path.split("/", 1)[1], data_type="cv2")
|
||||||
@@ -74,8 +105,8 @@ class LoadImage:
|
|||||||
keypoint = 'head_point'
|
keypoint = 'head_point'
|
||||||
elif name == 'earring':
|
elif name == 'earring':
|
||||||
keypoint = 'ear_point'
|
keypoint = 'ear_point'
|
||||||
elif name == 'accessories':
|
elif name == 'others':
|
||||||
keypoint = "accessories"
|
keypoint = "others"
|
||||||
else:
|
else:
|
||||||
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "
|
||||||
f"bag, shoes, hairstyle, earring.")
|
f"bag, shoes, hairstyle, earring.")
|
||||||
|
|||||||
422
app/service/design_fast/pipeline/no_seg_print_painting.py
Normal file
422
app/service/design_fast/pipeline/no_seg_print_painting.py
Normal file
@@ -0,0 +1,422 @@
|
|||||||
|
import random
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
from app.service.utils.new_oss_client import oss_get_image
|
||||||
|
|
||||||
|
|
||||||
|
class NoSegPrintPainting:
|
||||||
|
def __init__(self, minio_client):
|
||||||
|
self.minio_client = minio_client
|
||||||
|
|
||||||
|
def __call__(self, result):
|
||||||
|
single_print = result['print']['single']
|
||||||
|
overall_print = result['print']['overall']
|
||||||
|
element_print = result['print']['element']
|
||||||
|
result['single_image'] = None
|
||||||
|
result['print_image'] = None
|
||||||
|
|
||||||
|
if overall_print['print_path_list']:
|
||||||
|
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
|
||||||
|
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0:
|
||||||
|
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True)
|
||||||
|
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
||||||
|
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
|
||||||
|
|
||||||
|
# resize 到sketch大小
|
||||||
|
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||||
|
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
|
||||||
|
else:
|
||||||
|
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
|
||||||
|
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = self.printpaint(result, painting_dict, print_=True)
|
||||||
|
result['pattern_image'] = result['no_seg_sketch_overall']
|
||||||
|
|
||||||
|
if single_print['print_path_list']:
|
||||||
|
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||||
|
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
|
||||||
|
for i in range(len(single_print['print_path_list'])):
|
||||||
|
image, image_mode = self.read_image(single_print['print_path_list'][i])
|
||||||
|
|
||||||
|
if image_mode == "RGB":
|
||||||
|
image_rgba = cv2.cvtColor(image, cv2.COLOR_BGR2RGBA)
|
||||||
|
image = Image.fromarray(image_rgba)
|
||||||
|
|
||||||
|
new_size = (int(result['pattern_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['pattern_image'].shape[0] * single_print['print_scale_list'][i][1]))
|
||||||
|
mask = image.split()[3]
|
||||||
|
resized_source = image.resize(new_size)
|
||||||
|
resized_source_mask = mask.resize(new_size)
|
||||||
|
rotated_resized_source = resized_source.rotate(-single_print['print_angle_list'][i])
|
||||||
|
rotated_resized_source_mask = resized_source_mask.rotate(-single_print['print_angle_list'][i])
|
||||||
|
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||||
|
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||||
|
source_image_pil.paste(rotated_resized_source, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source)
|
||||||
|
source_image_pil_mask.paste(rotated_resized_source_mask, (int(single_print['location'][i][0]), int(single_print['location'][i][1])), rotated_resized_source_mask)
|
||||||
|
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||||
|
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||||
|
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
|
||||||
|
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||||
|
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||||
|
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
|
||||||
|
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||||
|
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||||
|
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8) # 当sketch 图像为灰色时(非纯白) , 印花*灰度图像会导致印花在sketch上颜色变暗
|
||||||
|
# img_fg = (img_fg * (mask_mo / 255) ).astype(np.uint8) # 不过灰度图像
|
||||||
|
|
||||||
|
final_image = cv2.add(img_bg, img_fg)
|
||||||
|
canvas = np.full_like(final_image, 255)
|
||||||
|
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||||
|
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||||
|
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||||
|
tmp2 = (final_image * (temp_fg / 255)).astype(np.uint8)
|
||||||
|
single_image = cv2.add(tmp1, tmp2)
|
||||||
|
result['no_seg_sketch_print'] = single_image
|
||||||
|
|
||||||
|
if element_print['element_path_list']:
|
||||||
|
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||||
|
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
|
||||||
|
for i in range(len(element_print['element_path_list'])):
|
||||||
|
image, image_mode = self.read_image(element_print['element_path_list'][i])
|
||||||
|
if image_mode == "RGBA":
|
||||||
|
new_size = (int(result['final_image'].shape[1] * element_print['element_scale_list'][i][0]), int(result['final_image'].shape[0] * element_print['element_scale_list'][i][1]))
|
||||||
|
|
||||||
|
mask = image.split()[3]
|
||||||
|
resized_source = image.resize(new_size)
|
||||||
|
resized_source_mask = mask.resize(new_size)
|
||||||
|
|
||||||
|
rotated_resized_source = resized_source.rotate(-element_print['element_angle_list'][i])
|
||||||
|
rotated_resized_source_mask = resized_source_mask.rotate(-element_print['element_angle_list'][i])
|
||||||
|
|
||||||
|
source_image_pil = Image.fromarray(cv2.cvtColor(print_background, cv2.COLOR_BGR2RGB))
|
||||||
|
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
|
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source)
|
||||||
|
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0]), int(element_print['location'][i][1])), rotated_resized_source_mask)
|
||||||
|
|
||||||
|
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||||
|
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||||
|
else:
|
||||||
|
mask = self.get_mask_inv(image)
|
||||||
|
mask = np.expand_dims(mask, axis=2)
|
||||||
|
mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
|
||||||
|
mask = cv2.bitwise_not(mask)
|
||||||
|
mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||||
|
image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
|
||||||
|
# 旋转后的坐标需要重新算
|
||||||
|
rotate_mask, _ = self.img_rotate(mask, element_print['element_angle_list'][i])
|
||||||
|
rotate_image, rotated_new_size = self.img_rotate(image, element_print['element_angle_list'][i])
|
||||||
|
# x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
|
||||||
|
x, y = int(element_print['location'][i][0] - rotated_new_size[0]), int(element_print['location'][i][1] - rotated_new_size[1])
|
||||||
|
|
||||||
|
image_x = print_background.shape[1]
|
||||||
|
image_y = print_background.shape[0]
|
||||||
|
print_x = rotate_image.shape[1]
|
||||||
|
print_y = rotate_image.shape[0]
|
||||||
|
|
||||||
|
if x <= 0:
|
||||||
|
rotate_image = rotate_image[:, -x:]
|
||||||
|
rotate_mask = rotate_mask[:, -x:]
|
||||||
|
start_x = x = 0
|
||||||
|
else:
|
||||||
|
start_x = x
|
||||||
|
|
||||||
|
if y <= 0:
|
||||||
|
rotate_image = rotate_image[-y:, :]
|
||||||
|
rotate_mask = rotate_mask[-y:, :]
|
||||||
|
start_y = y = 0
|
||||||
|
else:
|
||||||
|
start_y = y
|
||||||
|
|
||||||
|
if x + print_x > image_x:
|
||||||
|
rotate_image = rotate_image[:, :image_x - x]
|
||||||
|
rotate_mask = rotate_mask[:, :image_x - x]
|
||||||
|
|
||||||
|
if y + print_y > image_y:
|
||||||
|
rotate_image = rotate_image[:image_y - y, :]
|
||||||
|
rotate_mask = rotate_mask[:image_y - y, :]
|
||||||
|
|
||||||
|
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
|
||||||
|
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
|
||||||
|
|
||||||
|
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
|
||||||
|
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
|
||||||
|
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
|
||||||
|
img_bg = cv2.bitwise_and(result['no_seg_sketch_print'], three_channel_image)
|
||||||
|
result['final_image'] = cv2.add(img_bg, img_fg)
|
||||||
|
canvas = np.full_like(result['final_image'], 255)
|
||||||
|
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
|
||||||
|
tmp1 = (canvas * (temp_bg / 255)).astype(np.uint8)
|
||||||
|
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
|
||||||
|
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
|
||||||
|
result['no_seg_sketch_print'] = cv2.add(tmp1, tmp2)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def stack_prin(print_background, pattern_image, rotate_image, start_y, y, start_x, x):
|
||||||
|
temp_print = np.zeros((pattern_image.shape[0], pattern_image.shape[1], 3), dtype=np.uint8)
|
||||||
|
temp_print[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
|
||||||
|
img2gray = cv2.cvtColor(temp_print, cv2.COLOR_BGR2GRAY)
|
||||||
|
ret, mask_ = cv2.threshold(img2gray, 1, 255, cv2.THRESH_BINARY)
|
||||||
|
mask_inv = cv2.bitwise_not(mask_)
|
||||||
|
img1_bg = cv2.bitwise_and(print_background, print_background, mask=mask_inv)
|
||||||
|
img2_fg = cv2.bitwise_and(temp_print, temp_print, mask=mask_)
|
||||||
|
print_background = img1_bg + img2_fg
|
||||||
|
return print_background
|
||||||
|
|
||||||
|
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False):
|
||||||
|
if print_trigger:
|
||||||
|
print_ = self.get_print(print_dict)
|
||||||
|
painting_dict['Trigger'] = not is_single
|
||||||
|
painting_dict['location'] = print_['location']
|
||||||
|
single_mask_inv_print = self.get_mask_inv(print_['image'])
|
||||||
|
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
|
||||||
|
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
|
||||||
|
if not is_single:
|
||||||
|
self.random_seed = random.randint(0, 1000)
|
||||||
|
# 如果print 模式为overall 且 有角度的话 , 组合的print为正方形,方便裁剪
|
||||||
|
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0:
|
||||||
|
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||||
|
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True)
|
||||||
|
else:
|
||||||
|
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||||
|
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
|
||||||
|
else:
|
||||||
|
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||||
|
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
|
||||||
|
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
|
||||||
|
return painting_dict
|
||||||
|
|
||||||
|
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
|
||||||
|
tile = None
|
||||||
|
if not trigger:
|
||||||
|
tile = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||||
|
else:
|
||||||
|
resize_pattern = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
|
||||||
|
if len(pattern.shape) == 2:
|
||||||
|
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4))
|
||||||
|
if len(pattern.shape) == 3:
|
||||||
|
tile = np.tile(resize_pattern, (int((5 + 1) / scale) + 4, int((5 + 1) / scale) + 4, 1))
|
||||||
|
tile = self.crop_image(tile, dim_image_h, dim_image_w, location, resize_pattern.shape)
|
||||||
|
return tile
|
||||||
|
|
||||||
|
def get_mask_inv(self, print_):
|
||||||
|
if print_[0][0][0] == 255 and print_[0][0][1] == 255 and print_[0][0][2] == 255:
|
||||||
|
bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
|
||||||
|
print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
|
||||||
|
bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
|
||||||
|
bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
|
||||||
|
bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
|
||||||
|
bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
|
||||||
|
lower = np.array([bg_L_low, bg_a_low, bg_b_low])
|
||||||
|
upper = np.array([bg_L_high, bg_a_high, bg_b_high])
|
||||||
|
mask_inv = cv2.inRange(print_tile, lower, upper)
|
||||||
|
return mask_inv
|
||||||
|
else:
|
||||||
|
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
|
||||||
|
return mask_inv
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def printpaint(result, painting_dict, print_=False):
|
||||||
|
|
||||||
|
if print_ and painting_dict['Trigger']:
|
||||||
|
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
|
||||||
|
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
|
||||||
|
else:
|
||||||
|
print_mask = result['mask']
|
||||||
|
img_fg = result['final_image']
|
||||||
|
if print_ and not painting_dict['Trigger']:
|
||||||
|
index_ = None
|
||||||
|
try:
|
||||||
|
index_ = len(painting_dict['location'])
|
||||||
|
except:
|
||||||
|
assert f'there must be parameter of location if choose IfSingle'
|
||||||
|
|
||||||
|
for i in range(index_):
|
||||||
|
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
|
||||||
|
|
||||||
|
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
|
||||||
|
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
|
||||||
|
|
||||||
|
change_region = img_fg[start_h: length_h, start_w: length_w, :]
|
||||||
|
# problem in change_mask
|
||||||
|
change_mask = print_mask[start_h: length_h, start_w: length_w]
|
||||||
|
# get real part into change mask
|
||||||
|
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
|
||||||
|
mask = cv2.bitwise_not(painting_dict['mask_inv_print'])
|
||||||
|
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
|
||||||
|
|
||||||
|
clothes_mask_print = cv2.bitwise_not(print_mask)
|
||||||
|
|
||||||
|
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=clothes_mask_print)
|
||||||
|
mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
|
||||||
|
gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
|
||||||
|
img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
|
||||||
|
print_image = cv2.add(img_bg, img_fg)
|
||||||
|
return print_image
|
||||||
|
|
||||||
|
def get_print(self, print_dict):
|
||||||
|
if 'print_scale_list' not in print_dict.keys() or print_dict['print_scale_list'][0][0] < 0.3:
|
||||||
|
print_dict['scale'] = 0.3
|
||||||
|
else:
|
||||||
|
print_dict['scale'] = print_dict['print_scale_list'][0][0]
|
||||||
|
|
||||||
|
bucket_name = print_dict['print_path_list'][0].split("/", 1)[0]
|
||||||
|
object_name = print_dict['print_path_list'][0].split("/", 1)[1]
|
||||||
|
image = oss_get_image(oss_client=self.minio_client, bucket=bucket_name, object_name=object_name, data_type="PIL")
|
||||||
|
# 判断图片格式,如果是RGBA 则贴在一张纯白图片上 防止透明转黑
|
||||||
|
if image.mode == "RGBA":
|
||||||
|
new_background = Image.new('RGB', image.size, (255, 255, 255))
|
||||||
|
new_background.paste(image, mask=image.split()[3])
|
||||||
|
image = new_background
|
||||||
|
print_dict['image'] = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
|
||||||
|
return print_dict
|
||||||
|
|
||||||
|
def crop_image(self, image, image_size_h, image_size_w, location, print_shape):
|
||||||
|
print_w = print_shape[1]
|
||||||
|
print_h = print_shape[0]
|
||||||
|
|
||||||
|
random.seed(self.random_seed)
|
||||||
|
|
||||||
|
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
|
||||||
|
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
|
||||||
|
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
|
||||||
|
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
|
||||||
|
|
||||||
|
# y_offset = int(location[0][0])
|
||||||
|
# x_offset = int(location[0][1])
|
||||||
|
|
||||||
|
if len(image.shape) == 2:
|
||||||
|
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
|
||||||
|
elif len(image.shape) == 3:
|
||||||
|
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
|
||||||
|
return image
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_low_high_lab(Lab_value, L=False):
|
||||||
|
if L:
|
||||||
|
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||||
|
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||||
|
else:
|
||||||
|
high = Lab_value + 30 if Lab_value + 30 < 255 else 255
|
||||||
|
low = Lab_value - 30 if Lab_value - 30 > 0 else 0
|
||||||
|
return high, low
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def img_rotate(image, angel):
|
||||||
|
"""顺时针旋转图像任意角度
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image (np.array): [原始图像]
|
||||||
|
angel (float): [逆时针旋转的角度]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
[array]: [旋转后的图像]
|
||||||
|
"""
|
||||||
|
|
||||||
|
h, w = image.shape[:2]
|
||||||
|
center = (w // 2, h // 2)
|
||||||
|
# if type(angel) is not int:
|
||||||
|
# angel = 0
|
||||||
|
M = cv2.getRotationMatrix2D(center, -angel, 1)
|
||||||
|
# 调整旋转后的图像长宽
|
||||||
|
rotated_h = int((w * np.abs(M[0, 1]) + (h * np.abs(M[0, 0]))))
|
||||||
|
rotated_w = int((h * np.abs(M[0, 1]) + (w * np.abs(M[0, 0]))))
|
||||||
|
M[0, 2] += (rotated_w - w) // 2
|
||||||
|
M[1, 2] += (rotated_h - h) // 2
|
||||||
|
# 旋转图像
|
||||||
|
rotated_img = cv2.warpAffine(image, M, (rotated_w, rotated_h))
|
||||||
|
|
||||||
|
return rotated_img, ((rotated_img.shape[1] - image.shape[1]) // 2, (rotated_img.shape[0] - image.shape[0]) // 2)
|
||||||
|
# return rotated_img, (0, 0)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def rotate_crop_image(img, angle, crop):
|
||||||
|
"""
|
||||||
|
angle: 旋转的角度
|
||||||
|
crop: 是否需要进行裁剪,布尔向量
|
||||||
|
"""
|
||||||
|
if not isinstance(crop, bool):
|
||||||
|
raise ValueError("The 'crop' parameter must be a boolean.")
|
||||||
|
|
||||||
|
crop_image = lambda img, x0, y0, w, h: img[y0:y0 + h, x0:x0 + w]
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
# 旋转角度的周期是360°
|
||||||
|
angle %= 360
|
||||||
|
# 计算仿射变换矩阵
|
||||||
|
M_rotation = cv2.getRotationMatrix2D((w / 2, h / 2), angle, 1)
|
||||||
|
# 得到旋转后的图像
|
||||||
|
img_rotated = cv2.warpAffine(img, M_rotation, (w, h))
|
||||||
|
|
||||||
|
# 如果需要去除黑边
|
||||||
|
if crop:
|
||||||
|
# 裁剪角度的等效周期是180°
|
||||||
|
angle_crop = angle % 180
|
||||||
|
if angle_crop > 90:
|
||||||
|
angle_crop = 180 - angle_crop
|
||||||
|
# 转化角度为弧度
|
||||||
|
theta = angle_crop * np.pi / 180
|
||||||
|
# 计算高宽比
|
||||||
|
hw_ratio = float(h) / float(w)
|
||||||
|
# 计算裁剪边长系数的分子项
|
||||||
|
tan_theta = np.tan(theta)
|
||||||
|
numerator = np.cos(theta) + np.sin(theta) * np.tan(theta)
|
||||||
|
|
||||||
|
# 计算分母中和高宽比相关的项
|
||||||
|
r = hw_ratio if h > w else 1 / hw_ratio
|
||||||
|
# 计算分母项
|
||||||
|
denominator = r * tan_theta + 1
|
||||||
|
# 最终的边长系数
|
||||||
|
crop_mult = numerator / denominator
|
||||||
|
|
||||||
|
# 得到裁剪区域
|
||||||
|
w_crop = int(crop_mult * w)
|
||||||
|
h_crop = int(crop_mult * h)
|
||||||
|
x0 = int((w - w_crop) / 2)
|
||||||
|
y0 = int((h - h_crop) / 2)
|
||||||
|
|
||||||
|
img_rotated = crop_image(img_rotated, x0, y0, w_crop, h_crop)
|
||||||
|
|
||||||
|
return img_rotated
|
||||||
|
|
||||||
|
def read_image(self, image_url):
|
||||||
|
image = oss_get_image(oss_client=self.minio_client, bucket=image_url.split("/", 1)[0], object_name=image_url.split("/", 1)[1], data_type="cv2")
|
||||||
|
if image.shape[2] == 4:
|
||||||
|
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)
|
||||||
|
image = Image.fromarray(image_rgb)
|
||||||
|
image_mode = "RGBA"
|
||||||
|
else:
|
||||||
|
image_mode = "RGB"
|
||||||
|
return image, image_mode
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def resize_and_crop(img, target_width, target_height):
|
||||||
|
# 获取原始图像的尺寸
|
||||||
|
original_height, original_width = img.shape[:2]
|
||||||
|
|
||||||
|
# 计算目标尺寸的宽高比
|
||||||
|
target_ratio = target_width / target_height
|
||||||
|
|
||||||
|
# 计算原始图像的宽高比
|
||||||
|
original_ratio = original_width / original_height
|
||||||
|
|
||||||
|
# 调整尺寸
|
||||||
|
if original_ratio > target_ratio:
|
||||||
|
# 原始图像更宽,按高度resize,然后裁剪宽度
|
||||||
|
new_height = target_height
|
||||||
|
new_width = int(original_width * (target_height / original_height))
|
||||||
|
resized_img = cv2.resize(img, (new_width, new_height))
|
||||||
|
# 裁剪宽度
|
||||||
|
start_x = (new_width - target_width) // 2
|
||||||
|
cropped_img = resized_img[:, start_x:start_x + target_width]
|
||||||
|
else:
|
||||||
|
# 原始图像更高,按宽度resize,然后裁剪高度
|
||||||
|
new_width = target_width
|
||||||
|
new_height = int(original_height * (target_width / original_width))
|
||||||
|
resized_img = cv2.resize(img, (new_width, new_height))
|
||||||
|
# 裁剪高度
|
||||||
|
start_y = (new_height - target_height) // 2
|
||||||
|
cropped_img = resized_img[start_y:start_y + target_height, :]
|
||||||
|
|
||||||
|
return cropped_img
|
||||||
@@ -184,7 +184,7 @@ class PrintPainting:
|
|||||||
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
source_image_pil_mask = Image.fromarray(cv2.cvtColor(mask_background, cv2.COLOR_BGR2RGB))
|
||||||
|
|
||||||
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0] * sketch_resize_scale[0]), int(element_print['location'][i][1] * sketch_resize_scale[1])), rotated_resized_source)
|
source_image_pil.paste(rotated_resized_source, (int(element_print['location'][i][0] * sketch_resize_scale[0]), int(element_print['location'][i][1] * sketch_resize_scale[1])), rotated_resized_source)
|
||||||
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0] * sketch_resize_scale[1]), int(element_print['location'][i][1] * sketch_resize_scale[1])), rotated_resized_source_mask)
|
source_image_pil_mask.paste(rotated_resized_source_mask, (int(element_print['location'][i][0] * sketch_resize_scale[0]), int(element_print['location'][i][1] * sketch_resize_scale[1])), rotated_resized_source_mask)
|
||||||
|
|
||||||
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
|
||||||
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class Scaling:
|
|||||||
result['scale'] = result['scale_bag']
|
result['scale'] = result['scale_bag']
|
||||||
elif result['keypoint'] == 'ear_point':
|
elif result['keypoint'] == 'ear_point':
|
||||||
result['scale'] = result['scale_earrings']
|
result['scale'] = result['scale_earrings']
|
||||||
elif result['keypoint'] == 'accessories':
|
elif result['keypoint'] == 'others':
|
||||||
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
|
# 由于没有识别配饰keypoint的模型 所以统一将配饰的两个关键点设定为 (0,0) (0,img.width)
|
||||||
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
|
# 模特的关键点设定为(0,0) (0,320/2) 距离比例简写为 160 / img.width
|
||||||
distance_clo = result['img_shape'][1]
|
distance_clo = result['img_shape'][1]
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ class Split(object):
|
|||||||
|
|
||||||
def __call__(self, result):
|
def __call__(self, result):
|
||||||
try:
|
try:
|
||||||
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'accessories'):
|
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
|
||||||
ori_front_mask = result['front_mask'].copy()
|
ori_front_mask = result['front_mask'].copy()
|
||||||
ori_back_mask = result['back_mask'].copy()
|
ori_back_mask = result['back_mask'].copy()
|
||||||
|
|
||||||
@@ -32,14 +32,14 @@ class Split(object):
|
|||||||
new_width = int(width * result['resize_scale'][0])
|
new_width = int(width * result['resize_scale'][0])
|
||||||
new_height = int(height * result['resize_scale'][1])
|
new_height = int(height * result['resize_scale'][1])
|
||||||
|
|
||||||
front_mask = cv2.resize(result['front_mask'], (new_width, new_height))
|
front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||||
back_mask = cv2.resize(result['back_mask'], (new_width, new_height))
|
back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
|
||||||
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
|
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
|
||||||
rgba_image = cv2.resize(rgba_image, new_size)
|
rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
|
||||||
result_front_image = np.zeros_like(rgba_image)
|
result_front_image = np.zeros_like(rgba_image)
|
||||||
front_mask = cv2.resize(front_mask, new_size)
|
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||||
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
|
||||||
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
result_front_image_pil = Image.fromarray(cvtColor(result_front_image, COLOR_BGR2RGBA))
|
||||||
if 'transparent' in result.keys():
|
if 'transparent' in result.keys():
|
||||||
@@ -48,7 +48,7 @@ class Split(object):
|
|||||||
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
|
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
|
||||||
# 预处理用户自选区mask
|
# 预处理用户自选区mask
|
||||||
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
|
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
|
||||||
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_NEAREST)
|
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||||
# 转换颜色空间为 RGB(OpenCV 默认是 BGR)
|
# 转换颜色空间为 RGB(OpenCV 默认是 BGR)
|
||||||
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
|
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
|
||||||
|
|
||||||
@@ -75,7 +75,7 @@ class Split(object):
|
|||||||
|
|
||||||
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'):
|
||||||
# result_back_image = np.zeros_like(rgba_image)
|
# result_back_image = np.zeros_like(rgba_image)
|
||||||
# back_mask = cv2.resize(back_mask, new_size)
|
# back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||||
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||||
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||||
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||||
@@ -104,7 +104,7 @@ class Split(object):
|
|||||||
# # result['back_mask_image'] = None
|
# # result['back_mask_image'] = None
|
||||||
|
|
||||||
result_back_image = np.zeros_like(rgba_image)
|
result_back_image = np.zeros_like(rgba_image)
|
||||||
back_mask = cv2.resize(back_mask, new_size)
|
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
|
||||||
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
|
||||||
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA))
|
||||||
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
|
||||||
@@ -121,10 +121,12 @@ class Split(object):
|
|||||||
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
|
||||||
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
result['mask_url'] = req.bucket_name + "/" + req.object_name
|
||||||
|
|
||||||
# 创建中间图层
|
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
|
||||||
result_pattern_image_rgba = rgb_to_rgba(result['no_seg_sketch'], ori_front_mask + ori_back_mask)
|
result_pattern_overall_image_pil = Image.fromarray(cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), COLOR_BGR2RGBA))
|
||||||
result_pattern_image_pil = Image.fromarray(cvtColor(result_pattern_image_rgba, COLOR_BGR2RGBA))
|
result['pattern_overall_image'], result['pattern_overall_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_overall_image_pil, f'{generate_uuid()}')
|
||||||
result['pattern_image'], result['pattern_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_image_pil, f'{generate_uuid()}')
|
|
||||||
|
result_pattern_print_image_pil = Image.fromarray(cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), COLOR_BGR2RGBA))
|
||||||
|
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ def organize_clothing(layer):
|
|||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_image_url=layer['pattern_image_url'],
|
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
||||||
|
pattern_print_image_url=layer['pattern_print_image_url'],
|
||||||
|
|
||||||
pattern_image=layer['pattern_image'],
|
pattern_image=layer['pattern_image'],
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
)
|
)
|
||||||
@@ -49,20 +51,21 @@ def organize_clothing(layer):
|
|||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_image_url=layer['pattern_image_url'],
|
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
||||||
|
pattern_print_image_url=layer['pattern_print_image_url'],
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
)
|
)
|
||||||
return front_layer, back_layer
|
return front_layer, back_layer
|
||||||
|
|
||||||
|
|
||||||
def organize_accessories(layer):
|
def organize_others(layer):
|
||||||
# 起始坐标
|
# 起始坐标
|
||||||
start_point = (0, 0)
|
start_point = (0, 0)
|
||||||
layer['clothes_keypoint'] = {
|
layer['clothes_keypoint'] = {
|
||||||
'accessories_left': [0, 0]
|
'others_left': [0, 0]
|
||||||
}
|
}
|
||||||
layer['body_point_test'] = {
|
layer['body_point_test'] = {
|
||||||
'accessories_left': [0, 0]
|
'others_left': [0, 0]
|
||||||
}
|
}
|
||||||
|
|
||||||
start_point = calculate_start_point(layer['keypoint'], layer['scale'], layer['clothes_keypoint'], layer['body_point_test'], layer["offset"], layer["resize_scale"])
|
start_point = calculate_start_point(layer['keypoint'], layer['scale'], layer['clothes_keypoint'], layer['body_point_test'], layer["offset"], layer["resize_scale"])
|
||||||
@@ -80,7 +83,8 @@ def organize_accessories(layer):
|
|||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_image_url=layer['pattern_image_url'],
|
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
||||||
|
pattern_print_image_url=layer['pattern_print_image_url'],
|
||||||
pattern_image=layer['pattern_image'],
|
pattern_image=layer['pattern_image'],
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
)
|
)
|
||||||
@@ -97,7 +101,8 @@ def organize_accessories(layer):
|
|||||||
resize_scale=layer["resize_scale"],
|
resize_scale=layer["resize_scale"],
|
||||||
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
mask=cv2.resize(layer['mask'], layer["front_image"].size),
|
||||||
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
|
||||||
pattern_image_url=layer['pattern_image_url'],
|
pattern_overall_image_url=layer['pattern_overall_image_url'],
|
||||||
|
pattern_print_image_url=layer['pattern_print_image_url'],
|
||||||
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
|
||||||
)
|
)
|
||||||
return front_layer, back_layer
|
return front_layer, back_layer
|
||||||
|
|||||||
@@ -60,6 +60,18 @@ def positioning(all_mask_shape, mask_shape, offset):
|
|||||||
|
|
||||||
# @RunTime
|
# @RunTime
|
||||||
def synthesis(data, size, basic_info):
|
def synthesis(data, size, basic_info):
|
||||||
|
"""
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data:
|
||||||
|
size:
|
||||||
|
basic_info:
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
|
||||||
|
"""
|
||||||
|
# out_of_bounds_control: 是否允许服装越界 True 允许 False 不允许 默认情况允许
|
||||||
|
out_of_bounds_control = basic_info.get('out_of_bounds_control', True)
|
||||||
# 创建底图
|
# 创建底图
|
||||||
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
|
||||||
try:
|
try:
|
||||||
@@ -79,15 +91,18 @@ def synthesis(data, size, basic_info):
|
|||||||
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
|
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
|
||||||
top_outer_mask = np.array(binary_body_mask)
|
top_outer_mask = np.array(binary_body_mask)
|
||||||
bottom_outer_mask = np.array(binary_body_mask)
|
bottom_outer_mask = np.array(binary_body_mask)
|
||||||
accessories_outer_mask = np.array(binary_body_mask)
|
others_outer_mask = np.array(binary_body_mask)
|
||||||
|
|
||||||
top = True
|
top = True
|
||||||
bottom = True
|
bottom = True
|
||||||
accessories = True
|
others = True
|
||||||
i = len(data)
|
i = len(data)
|
||||||
while i:
|
while i:
|
||||||
i -= 1
|
i -= 1
|
||||||
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
|
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
|
||||||
|
if out_of_bounds_control:
|
||||||
|
top = True
|
||||||
|
else:
|
||||||
top = False
|
top = False
|
||||||
mask_shape = data[i]['mask'].shape
|
mask_shape = data[i]['mask'].shape
|
||||||
y_offset, x_offset = data[i]['adaptive_position']
|
y_offset, x_offset = data[i]['adaptive_position']
|
||||||
@@ -111,7 +126,7 @@ def synthesis(data, size, basic_info):
|
|||||||
background = np.zeros_like(top_outer_mask)
|
background = np.zeros_like(top_outer_mask)
|
||||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||||
bottom_outer_mask = background + bottom_outer_mask
|
bottom_outer_mask = background + bottom_outer_mask
|
||||||
elif accessories and data[i]['name'] in ['accessories_front']:
|
elif others and data[i]['name'] in ['others_front']:
|
||||||
mask_shape = data[i]['mask'].shape
|
mask_shape = data[i]['mask'].shape
|
||||||
y_offset, x_offset = data[i]['adaptive_position']
|
y_offset, x_offset = data[i]['adaptive_position']
|
||||||
# 初始化叠加区域的起始和结束位置
|
# 初始化叠加区域的起始和结束位置
|
||||||
@@ -121,13 +136,13 @@ def synthesis(data, size, basic_info):
|
|||||||
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
|
||||||
background = np.zeros_like(top_outer_mask)
|
background = np.zeros_like(top_outer_mask)
|
||||||
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
|
||||||
accessories_outer_mask = background + accessories_outer_mask
|
others_outer_mask = background + others_outer_mask
|
||||||
pass
|
pass
|
||||||
elif bottom is False and top is False:
|
elif bottom is False and top is False:
|
||||||
break
|
break
|
||||||
|
|
||||||
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
|
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
|
||||||
all_mask = cv2.bitwise_or(all_mask, accessories_outer_mask)
|
all_mask = cv2.bitwise_or(all_mask, others_outer_mask)
|
||||||
|
|
||||||
for layer in data:
|
for layer in data:
|
||||||
if layer['image'] is not None:
|
if layer['image'] is not None:
|
||||||
@@ -207,7 +222,9 @@ def update_base_size_priority(layers, size):
|
|||||||
if info['name'] == 'mannequin':
|
if info['name'] == 'mannequin':
|
||||||
new_height = info['image'].height
|
new_height = info['image'].height
|
||||||
max_x = max(x_list)
|
max_x = max(x_list)
|
||||||
new_width = max_x - min_x * 2
|
|
||||||
|
# x坐标中最小偏移量的绝对值 + 最大偏移量
|
||||||
|
new_width = max_x + abs(min_x)
|
||||||
# 更新坐标
|
# 更新坐标
|
||||||
for info in layers:
|
for info in layers:
|
||||||
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
|
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
|
||||||
|
|||||||
@@ -1,240 +1,240 @@
|
|||||||
# 预加载资源
|
# # 预加载资源
|
||||||
import logging
|
# import logging
|
||||||
import time
|
# import time
|
||||||
from collections import defaultdict
|
# from collections import defaultdict
|
||||||
import os
|
# import os
|
||||||
import json
|
# import json
|
||||||
import numpy as np
|
# import numpy as np
|
||||||
|
#
|
||||||
from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
# from app.core.config import DB_CONFIG, RECOMMEND_PATH_PREFIX
|
||||||
|
#
|
||||||
logger = logging.getLogger()
|
# logger = logging.getLogger()
|
||||||
import pymysql
|
# import pymysql
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
# from concurrent.futures import ThreadPoolExecutor
|
||||||
|
#
|
||||||
HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
|
# HEAT_VECTOR_FILE = 'heat_vectors_data/heat_vectors.json' # 可动态加载或配置
|
||||||
|
#
|
||||||
matrix_data = {
|
# matrix_data = {
|
||||||
"interaction_matrix": None,
|
# "interaction_matrix": None,
|
||||||
"feature_matrix": None,
|
# "feature_matrix": None,
|
||||||
"user_index_interaction": None,
|
# "user_index_interaction": None,
|
||||||
"sketch_index_interaction": None,
|
# "sketch_index_interaction": None,
|
||||||
"user_index_feature": None,
|
# "user_index_feature": None,
|
||||||
"sketch_index_feature": None,
|
# "sketch_index_feature": None,
|
||||||
"iid_to_sketch": None,
|
# "iid_to_sketch": None,
|
||||||
"category_to_iids": None,
|
# "category_to_iids": None,
|
||||||
"cached_scores": {},
|
# "cached_scores": {},
|
||||||
"cached_valid_idxs": {},
|
# "cached_valid_idxs": {},
|
||||||
"category_sketch_idxs_inter": None,
|
# "category_sketch_idxs_inter": None,
|
||||||
"category_sketch_idxs_feature": None,
|
# "category_sketch_idxs_feature": None,
|
||||||
"user_inter_full": dict(),
|
# "user_inter_full": dict(),
|
||||||
"user_feat_full": dict(),
|
# "user_feat_full": dict(),
|
||||||
"brand_feature_matrix": None,
|
# "brand_feature_matrix": None,
|
||||||
"brand_index_map": None,
|
# "brand_index_map": None,
|
||||||
"heat_data": {},
|
# "heat_data": {},
|
||||||
}
|
# }
|
||||||
|
#
|
||||||
|
#
|
||||||
def load_resources():
|
# def load_resources():
|
||||||
"""加载所有矩阵和映射关系,并触发预缓存"""
|
# """加载所有矩阵和映射关系,并触发预缓存"""
|
||||||
try:
|
# try:
|
||||||
start_time = time.time()
|
# start_time = time.time()
|
||||||
|
#
|
||||||
# 清空缓存
|
# # 清空缓存
|
||||||
matrix_data["cached_scores"].clear()
|
# matrix_data["cached_scores"].clear()
|
||||||
matrix_data["cached_valid_idxs"].clear()
|
# matrix_data["cached_valid_idxs"].clear()
|
||||||
|
#
|
||||||
# 加载数据
|
# # 加载数据
|
||||||
sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
|
# sketch_to_iid = np.load(f'{RECOMMEND_PATH_PREFIX}sketch_to_iid.npy', allow_pickle=True).item()
|
||||||
matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
|
# matrix_data["iid_to_sketch"] = {v: k for k, v in sketch_to_iid.items()}
|
||||||
|
#
|
||||||
matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
|
# matrix_data["interaction_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}interaction_matrix.npy", allow_pickle=True)
|
||||||
matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
|
# matrix_data["user_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_interaction_matrix.npy", allow_pickle=True).item()
|
||||||
matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
|
# matrix_data["sketch_index_interaction"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_interaction_matrix.npy",
|
||||||
allow_pickle=True).item()
|
# allow_pickle=True).item()
|
||||||
|
#
|
||||||
matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
|
# matrix_data["feature_matrix"] = np.load(f"{RECOMMEND_PATH_PREFIX}feature_matrix.npy", allow_pickle=True)
|
||||||
|
#
|
||||||
brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
|
# brand_feature_path = f"{RECOMMEND_PATH_PREFIX}brand_feature_matrix.npy"
|
||||||
if os.path.exists(brand_feature_path):
|
# if os.path.exists(brand_feature_path):
|
||||||
matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
|
# matrix_data["brand_feature_matrix"] = np.load(brand_feature_path, allow_pickle=True)
|
||||||
else:
|
# else:
|
||||||
logger.warning("brand_feature_matrix 文件不存在,使用空数组")
|
# logger.warning("brand_feature_matrix 文件不存在,使用空数组")
|
||||||
matrix_data["brand_feature_matrix"] = np.array([])
|
# matrix_data["brand_feature_matrix"] = np.array([])
|
||||||
|
#
|
||||||
# brand_index_map
|
# # brand_index_map
|
||||||
brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
# brand_index_path = f"{RECOMMEND_PATH_PREFIX}brand_index_map.npy"
|
||||||
if os.path.exists(brand_index_path):
|
# if os.path.exists(brand_index_path):
|
||||||
matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
|
# matrix_data["brand_index_map"] = np.load(brand_index_path, allow_pickle=True).item()
|
||||||
else:
|
# else:
|
||||||
logger.warning("brand_index_map 文件不存在,使用空字典")
|
# logger.warning("brand_index_map 文件不存在,使用空字典")
|
||||||
matrix_data["brand_index_map"] = {}
|
# matrix_data["brand_index_map"] = {}
|
||||||
|
#
|
||||||
matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
|
# matrix_data["user_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}user_index_feature_matrix.npy", allow_pickle=True).item()
|
||||||
|
#
|
||||||
matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
|
# matrix_data["sketch_index_feature"] = np.load(f"{RECOMMEND_PATH_PREFIX}sketch_index_feature_matrix.npy", allow_pickle=True).item()
|
||||||
|
#
|
||||||
category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
|
# category_to_iid_map = np.load(f"{RECOMMEND_PATH_PREFIX}iid_to_category_interaction_matrix.npy", allow_pickle=True).item()
|
||||||
matrix_data["category_to_iids"] = defaultdict(list)
|
# matrix_data["category_to_iids"] = defaultdict(list)
|
||||||
for iid, cat in category_to_iid_map.items():
|
# for iid, cat in category_to_iid_map.items():
|
||||||
matrix_data["category_to_iids"][cat].append(iid)
|
# matrix_data["category_to_iids"][cat].append(iid)
|
||||||
|
#
|
||||||
logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
|
# logger.info(f"资源加载完成,耗时: {time.time() - start_time:.2f}秒")
|
||||||
|
#
|
||||||
# 触发预缓存
|
# # 触发预缓存
|
||||||
precache_user_category()
|
# precache_user_category()
|
||||||
|
#
|
||||||
if os.path.exists(HEAT_VECTOR_FILE):
|
# if os.path.exists(HEAT_VECTOR_FILE):
|
||||||
with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
|
# with open(HEAT_VECTOR_FILE, 'r', encoding='utf-8') as f:
|
||||||
heat_json = json.load(f)
|
# heat_json = json.load(f)
|
||||||
matrix_data["heat_data"] = heat_json.get("data", {})
|
# matrix_data["heat_data"] = heat_json.get("data", {})
|
||||||
logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
|
# logger.info(f"热度向量数据加载完成,共加载 {len(matrix_data['heat_data'])} 个类别")
|
||||||
else:
|
# else:
|
||||||
matrix_data["heat_data"] = {}
|
# matrix_data["heat_data"] = {}
|
||||||
|
#
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"资源加载失败: {str(e)}")
|
# logger.error(f"资源加载失败: {str(e)}")
|
||||||
raise RuntimeError("初始化失败")
|
# raise RuntimeError("初始化失败")
|
||||||
|
#
|
||||||
|
#
|
||||||
def precache_user_category():
|
# def precache_user_category():
|
||||||
"""优化后的用户分类预缓存(添加耗时统计)"""
|
# """优化后的用户分类预缓存(添加耗时统计)"""
|
||||||
if not all([
|
# if not all([
|
||||||
matrix_data["interaction_matrix"] is not None,
|
# matrix_data["interaction_matrix"] is not None,
|
||||||
matrix_data["feature_matrix"] is not None,
|
# matrix_data["feature_matrix"] is not None,
|
||||||
matrix_data["user_index_interaction"] is not None
|
# matrix_data["user_index_interaction"] is not None
|
||||||
]):
|
# ]):
|
||||||
logger.warning("资源未加载完成,跳过预缓存")
|
# logger.warning("资源未加载完成,跳过预缓存")
|
||||||
return
|
# return
|
||||||
|
#
|
||||||
start_time = time.perf_counter()
|
# start_time = time.perf_counter()
|
||||||
time_stats = {
|
# time_stats = {
|
||||||
"get_all_user_categories": 0,
|
# "get_all_user_categories": 0,
|
||||||
"process_user_category": 0,
|
# "process_user_category": 0,
|
||||||
"thread_execution": 0,
|
# "thread_execution": 0,
|
||||||
"cache_update": 0,
|
# "cache_update": 0,
|
||||||
"total": 0,
|
# "total": 0,
|
||||||
}
|
# }
|
||||||
|
#
|
||||||
# 统计用户类别获取时间
|
# # 统计用户类别获取时间
|
||||||
t1 = time.perf_counter()
|
# t1 = time.perf_counter()
|
||||||
user_categories = get_all_user_categories()
|
# user_categories = get_all_user_categories()
|
||||||
time_stats["get_all_user_categories"] = time.perf_counter() - t1
|
# time_stats["get_all_user_categories"] = time.perf_counter() - t1
|
||||||
|
#
|
||||||
precached_count = 0
|
# precached_count = 0
|
||||||
|
#
|
||||||
def process_user_category(user_id, categories):
|
# def process_user_category(user_id, categories):
|
||||||
"""单用户类别缓存计算(统计耗时)"""
|
# """单用户类别缓存计算(统计耗时)"""
|
||||||
local_cache = {}
|
# local_cache = {}
|
||||||
local_valid_idxs = {}
|
# local_valid_idxs = {}
|
||||||
t_start = time.perf_counter()
|
# t_start = time.perf_counter()
|
||||||
|
#
|
||||||
for category in categories:
|
# for category in categories:
|
||||||
cache_key = (user_id, category)
|
# cache_key = (user_id, category)
|
||||||
if cache_key in matrix_data["cached_scores"]:
|
# if cache_key in matrix_data["cached_scores"]:
|
||||||
continue
|
# continue
|
||||||
|
#
|
||||||
try:
|
# try:
|
||||||
user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
# user_idx_inter = matrix_data["user_index_interaction"].get(user_id)
|
||||||
user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
# user_idx_feature = matrix_data["user_index_feature"].get(user_id)
|
||||||
|
#
|
||||||
# 统计获取类别 IID 耗时
|
# # 统计获取类别 IID 耗时
|
||||||
t_iid = time.perf_counter()
|
# t_iid = time.perf_counter()
|
||||||
category_iids = matrix_data["category_to_iids"].get(category, [])
|
# category_iids = matrix_data["category_to_iids"].get(category, [])
|
||||||
valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
|
# valid_sketch_idxs_inter = [matrix_data["sketch_index_interaction"][iid]
|
||||||
for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
|
# for iid in category_iids if iid in matrix_data["sketch_index_interaction"]]
|
||||||
valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
|
# valid_sketch_idxs_feature = [matrix_data["sketch_index_feature"][iid]
|
||||||
for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
|
# for iid in category_iids if iid in matrix_data["sketch_index_feature"]]
|
||||||
time_stats["process_user_category"] += time.perf_counter() - t_iid
|
# time_stats["process_user_category"] += time.perf_counter() - t_iid
|
||||||
|
#
|
||||||
# 统计矩阵计算耗时
|
# # 统计矩阵计算耗时
|
||||||
t_matrix = time.perf_counter()
|
# t_matrix = time.perf_counter()
|
||||||
processed_inter = np.zeros(len(valid_sketch_idxs_inter))
|
# processed_inter = np.zeros(len(valid_sketch_idxs_inter))
|
||||||
if user_idx_inter is not None and valid_sketch_idxs_inter:
|
# if user_idx_inter is not None and valid_sketch_idxs_inter:
|
||||||
raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
# raw_inter_scores = matrix_data["interaction_matrix"][user_idx_inter, valid_sketch_idxs_inter]
|
||||||
processed_inter = raw_inter_scores * 0.7
|
# processed_inter = raw_inter_scores * 0.7
|
||||||
|
#
|
||||||
processed_feat = np.zeros(len(valid_sketch_idxs_feature))
|
# processed_feat = np.zeros(len(valid_sketch_idxs_feature))
|
||||||
if user_idx_feature is not None and valid_sketch_idxs_feature:
|
# if user_idx_feature is not None and valid_sketch_idxs_feature:
|
||||||
raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
# raw_feat_scores = matrix_data["feature_matrix"][user_idx_feature, valid_sketch_idxs_feature]
|
||||||
raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
# raw_feat_scores = (raw_feat_scores - np.min(raw_feat_scores)) / (
|
||||||
np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
# np.max(raw_feat_scores) - np.min(raw_feat_scores) + 1e-8)
|
||||||
processed_feat = raw_feat_scores * 0.3
|
# processed_feat = raw_feat_scores * 0.3
|
||||||
time_stats["process_user_category"] += time.perf_counter() - t_matrix
|
# time_stats["process_user_category"] += time.perf_counter() - t_matrix
|
||||||
|
#
|
||||||
if len(processed_inter) == len(processed_feat):
|
# if len(processed_inter) == len(processed_feat):
|
||||||
local_cache[cache_key] = (processed_inter, processed_feat)
|
# local_cache[cache_key] = (processed_inter, processed_feat)
|
||||||
local_valid_idxs[cache_key] = valid_sketch_idxs_inter
|
# local_valid_idxs[cache_key] = valid_sketch_idxs_inter
|
||||||
|
#
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
|
# logger.error(f"预缓存失败 (user={user_id}, category={category}): {str(e)}")
|
||||||
|
#
|
||||||
return local_cache, local_valid_idxs
|
# return local_cache, local_valid_idxs
|
||||||
|
#
|
||||||
# 统计线程执行时间
|
# # 统计线程执行时间
|
||||||
t2 = time.perf_counter()
|
# t2 = time.perf_counter()
|
||||||
with ThreadPoolExecutor(max_workers=8) as executor:
|
# with ThreadPoolExecutor(max_workers=8) as executor:
|
||||||
futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
|
# futures = {executor.submit(process_user_category, user_id, categories): user_id for user_id, categories in user_categories.items()}
|
||||||
for future in futures:
|
# for future in futures:
|
||||||
try:
|
# try:
|
||||||
t_cache = time.perf_counter()
|
# t_cache = time.perf_counter()
|
||||||
cache_part, valid_idxs_part = future.result()
|
# cache_part, valid_idxs_part = future.result()
|
||||||
matrix_data["cached_scores"].update(cache_part)
|
# matrix_data["cached_scores"].update(cache_part)
|
||||||
matrix_data["cached_valid_idxs"].update(valid_idxs_part)
|
# matrix_data["cached_valid_idxs"].update(valid_idxs_part)
|
||||||
time_stats["cache_update"] += time.perf_counter() - t_cache
|
# time_stats["cache_update"] += time.perf_counter() - t_cache
|
||||||
precached_count += len(cache_part)
|
# precached_count += len(cache_part)
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"线程执行错误: {str(e)}")
|
# logger.error(f"线程执行错误: {str(e)}")
|
||||||
time_stats["thread_execution"] = time.perf_counter() - t2
|
# time_stats["thread_execution"] = time.perf_counter() - t2
|
||||||
|
#
|
||||||
time_stats["total"] = time.perf_counter() - start_time
|
# time_stats["total"] = time.perf_counter() - start_time
|
||||||
|
#
|
||||||
# 输出统计信息
|
# # 输出统计信息
|
||||||
logger.info(f"""
|
# logger.info(f"""
|
||||||
预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
|
# 预缓存完成,共缓存 {precached_count} 组数据,耗时统计如下:
|
||||||
- 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
|
# - 获取用户类别数据: {time_stats["get_all_user_categories"]:.2f}s
|
||||||
- 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
|
# - 计算用户类别缓存: {time_stats["process_user_category"]:.2f}s
|
||||||
- 线程任务执行: {time_stats["thread_execution"]:.2f}s
|
# - 线程任务执行: {time_stats["thread_execution"]:.2f}s
|
||||||
- 更新缓存数据: {time_stats["cache_update"]:.2f}s
|
# - 更新缓存数据: {time_stats["cache_update"]:.2f}s
|
||||||
- 总耗时: {time_stats["total"]:.2f}s
|
# - 总耗时: {time_stats["total"]:.2f}s
|
||||||
""")
|
# """)
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_all_user_categories():
|
# def get_all_user_categories():
|
||||||
"""获取所有用户及其对应的分类"""
|
# """获取所有用户及其对应的分类"""
|
||||||
conn = None
|
# conn = None
|
||||||
try:
|
# try:
|
||||||
conn = pymysql.connect(**DB_CONFIG)
|
# conn = pymysql.connect(**DB_CONFIG)
|
||||||
cursor = conn.cursor()
|
# cursor = conn.cursor()
|
||||||
|
#
|
||||||
query = """
|
# query = """
|
||||||
SELECT DISTINCT account_id, path
|
# SELECT DISTINCT account_id, path
|
||||||
FROM user_preference_log_prediction
|
# FROM user_preference_log_prediction
|
||||||
"""
|
# """
|
||||||
cursor.execute(query)
|
# cursor.execute(query)
|
||||||
results = cursor.fetchall()
|
# results = cursor.fetchall()
|
||||||
|
#
|
||||||
user_categories = defaultdict(set)
|
# user_categories = defaultdict(set)
|
||||||
for account_id, path in results:
|
# for account_id, path in results:
|
||||||
category = get_category_from_path(path)
|
# category = get_category_from_path(path)
|
||||||
user_categories[account_id].add(category)
|
# user_categories[account_id].add(category)
|
||||||
|
#
|
||||||
return dict(user_categories)
|
# return dict(user_categories)
|
||||||
|
#
|
||||||
except Exception as e:
|
# except Exception as e:
|
||||||
logger.error(f"数据库查询失败: {str(e)}")
|
# logger.error(f"数据库查询失败: {str(e)}")
|
||||||
return {}
|
# return {}
|
||||||
finally:
|
# finally:
|
||||||
if conn:
|
# if conn:
|
||||||
conn.close()
|
# conn.close()
|
||||||
|
#
|
||||||
|
#
|
||||||
def get_category_from_path(path: str) -> str:
|
# def get_category_from_path(path: str) -> str:
|
||||||
"""从路径解析类别"""
|
# """从路径解析类别"""
|
||||||
try:
|
# try:
|
||||||
parts = path.split('/')
|
# parts = path.split('/')
|
||||||
if len(parts) >= 4:
|
# if len(parts) >= 4:
|
||||||
return f"{parts[2]}_{parts[3]}"
|
# return f"{parts[2]}_{parts[3]}"
|
||||||
return "unknown"
|
# return "unknown"
|
||||||
except:
|
# except:
|
||||||
return "unknown"
|
# return "unknown"
|
||||||
|
|||||||
1
app/service/recommendation_system/__init__.py
Normal file
1
app/service/recommendation_system/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
73
app/service/recommendation_system/config.py
Normal file
73
app/service/recommendation_system/config.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
"""
|
||||||
|
推荐系统配置
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
from app.core.config import (
|
||||||
|
DB_CONFIG, DB_HOST, DB_PORT, DB_USERNAME, DB_PASSWORD, DB_NAME,
|
||||||
|
REDIS_HOST, REDIS_PORT, REDIS_DB,
|
||||||
|
MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS,
|
||||||
|
MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||||
|
)
|
||||||
|
|
||||||
|
# Milvus 集合名称
|
||||||
|
MILVUS_COLLECTION_SKETCH_VECTORS = "sketch_vectors_norm"
|
||||||
|
|
||||||
|
# Redis key 前缀
|
||||||
|
REDIS_KEY_USER_PREF_PREFIX = "user_pref"
|
||||||
|
|
||||||
|
# 推荐系统配置参数
|
||||||
|
RECOMMENDATION_CONFIG = {
|
||||||
|
# 时间衰减半衰期(用于计算时间衰减权重)
|
||||||
|
# 值越小,最近的行为权重越大
|
||||||
|
"K_half": 20,
|
||||||
|
|
||||||
|
# 探索与利用的比例 (0.0-1.0)
|
||||||
|
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
|
||||||
|
# - 值越小,使用利用分支(基于用户偏好)的几率越大,结果更精准
|
||||||
|
# - 建议范围: 0.3-0.7,要增加随机性可提高到 0.6-0.8
|
||||||
|
"explore_ratio": 0.5,
|
||||||
|
|
||||||
|
# 向量检索返回的候选数量
|
||||||
|
# 值越大,候选池越大,但计算成本也越高
|
||||||
|
# 建议范围: 100-1000
|
||||||
|
"topk": 1000,
|
||||||
|
|
||||||
|
# Style 加分系数(同 style 的候选进行加分)
|
||||||
|
# 值越大,匹配 style 的候选被选中的概率越大
|
||||||
|
# 要降低某个结果的重复率,可以降低此值(如 0.1 或 0.05)
|
||||||
|
"style_bonus": 0.2,
|
||||||
|
|
||||||
|
# Softmax 抽样的温度参数
|
||||||
|
# - 温度越高(>1.0),概率分布越均匀,结果更随机,重复率更低
|
||||||
|
# - 温度越低(<1.0),高分项概率越大,结果更集中,重复率更高
|
||||||
|
# - 温度=1.0 为标准 Softmax
|
||||||
|
# - 建议范围: 1.0-3.0,要增加随机性可提高到 2.0-3.0
|
||||||
|
"softmax_temperature": 0.07,
|
||||||
|
|
||||||
|
# 监听间隔(秒)
|
||||||
|
"listen_interval_sec": 30,
|
||||||
|
|
||||||
|
# 批量处理大小
|
||||||
|
"batch_size": 1000,
|
||||||
|
|
||||||
|
# Redis 过期时间(秒,30天)
|
||||||
|
"redis_expire_seconds": 2592000,
|
||||||
|
|
||||||
|
# 向量维度
|
||||||
|
"vector_dim": 2048,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 数据库表名
|
||||||
|
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test"
|
||||||
|
TABLE_SYS_FILE = "t_sys_file"
|
||||||
|
|
||||||
|
# MySQL 连接配置(用于推荐系统)
|
||||||
|
MYSQL_CONFIG = {
|
||||||
|
"host": DB_HOST,
|
||||||
|
"port": DB_PORT,
|
||||||
|
"user": DB_USERNAME,
|
||||||
|
"password": DB_PASSWORD,
|
||||||
|
"database": DB_NAME,
|
||||||
|
"charset": "utf8mb4"
|
||||||
|
}
|
||||||
|
|
||||||
331
app/service/recommendation_system/import_sys_sketch_to_milvus.py
Normal file
331
app/service/recommendation_system/import_sys_sketch_to_milvus.py
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
"""
|
||||||
|
独立脚本:从 t_sys_file 导入系统图向量到 Milvus
|
||||||
|
可以单独运行,不依赖整个项目启动
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
python -m app.service.recommendation_system.import_sys_sketch_to_milvus
|
||||||
|
或
|
||||||
|
python app/service/recommendation_system/import_sys_sketch_to_milvus.py
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到 Python 路径
|
||||||
|
project_root = Path(__file__).parent.parent.parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pymysql
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import (
|
||||||
|
MYSQL_CONFIG, TABLE_SYS_FILE,
|
||||||
|
RECOMMENDATION_CONFIG, MILVUS_COLLECTION_SKETCH_VECTORS
|
||||||
|
)
|
||||||
|
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector
|
||||||
|
from app.service.recommendation_system.milvus_client import create_collection, insert_vectors
|
||||||
|
|
||||||
|
# 配置日志
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[
|
||||||
|
logging.StreamHandler(),
|
||||||
|
logging.FileHandler('import_sys_sketch.log', encoding='utf-8')
|
||||||
|
]
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sys_file_records(conn, limit=None, offset=0):
|
||||||
|
"""
|
||||||
|
从 t_sys_file 表获取系统图记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
conn: 数据库连接
|
||||||
|
limit: 限制数量(None 表示不限制)
|
||||||
|
offset: 偏移量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
记录列表,每个元素为 (id, url, style, level3_type, level2_type, deprecated)
|
||||||
|
"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
query = f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE level1_type = 'Images'
|
||||||
|
AND style IS NOT NULL
|
||||||
|
AND style != ''
|
||||||
|
AND deprecated != 1
|
||||||
|
ORDER BY id
|
||||||
|
"""
|
||||||
|
|
||||||
|
if limit:
|
||||||
|
query += f" LIMIT {limit} OFFSET {offset}"
|
||||||
|
|
||||||
|
cursor.execute(query)
|
||||||
|
records = cursor.fetchall()
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
return records
|
||||||
|
|
||||||
|
|
||||||
|
def get_total_count(conn):
|
||||||
|
"""获取总记录数"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE level1_type = 'Images'
|
||||||
|
AND style IS NOT NULL
|
||||||
|
AND style != ''
|
||||||
|
AND deprecated != 1
|
||||||
|
""")
|
||||||
|
count = cursor.fetchone()[0]
|
||||||
|
cursor.close()
|
||||||
|
return count
|
||||||
|
|
||||||
|
|
||||||
|
def process_and_insert_batch(records, batch_size=1000, retry_times=3):
|
||||||
|
"""
|
||||||
|
处理并批量插入向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
records: 记录列表
|
||||||
|
batch_size: 批量大小
|
||||||
|
retry_times: 失败重试次数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
(成功数量, 失败数量)
|
||||||
|
"""
|
||||||
|
success_count = 0
|
||||||
|
failed_count = 0
|
||||||
|
failed_records = []
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
# 使用 tqdm 显示进度
|
||||||
|
with tqdm(total=len(records), desc="处理记录", unit="条") as pbar:
|
||||||
|
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records):
|
||||||
|
try:
|
||||||
|
# 计算 category
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
# 提取特征向量
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
# 归一化,便于 IP≈cosine 度量
|
||||||
|
feature_vector = normalize_vector(feature_vector)
|
||||||
|
|
||||||
|
# 检查向量是否有效
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
logger.warning(f"向量提取失败,跳过: {url} (id={sys_file_id})")
|
||||||
|
failed_count += 1
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_data.append(data_item)
|
||||||
|
|
||||||
|
# 批量写入
|
||||||
|
if len(batch_data) >= batch_size:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
success_count += len(batch_data)
|
||||||
|
batch_data = []
|
||||||
|
logger.info(f"已成功插入 {success_count} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量写入失败: {e}")
|
||||||
|
failed_count += len(batch_data)
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理记录失败 [id={sys_file_id}, url={url}]: {e}")
|
||||||
|
failed_count += 1
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# 写入剩余数据
|
||||||
|
if batch_data:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
success_count += len(batch_data)
|
||||||
|
logger.info(f"写入剩余 {len(batch_data)} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"写入剩余数据失败: {e}")
|
||||||
|
failed_count += len(batch_data)
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
|
||||||
|
# 重试失败记录
|
||||||
|
if failed_records and retry_times > 0:
|
||||||
|
logger.info(f"开始重试 {len(failed_records)} 条失败记录,最多重试 {retry_times} 次...")
|
||||||
|
|
||||||
|
for retry in range(retry_times):
|
||||||
|
if not failed_records:
|
||||||
|
break
|
||||||
|
|
||||||
|
retry_failed = []
|
||||||
|
with tqdm(total=len(failed_records), desc=f"重试第 {retry + 1} 次", unit="条") as pbar:
|
||||||
|
for sys_file_id, url in failed_records:
|
||||||
|
try:
|
||||||
|
# 重新查询记录信息
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE id = %s
|
||||||
|
""", (sys_file_id,))
|
||||||
|
record = cursor.fetchone()
|
||||||
|
cursor.close()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
if not record:
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
sys_file_id, url, style, level3_type, level2_type, deprecated = record
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
feature_vector = normalize_vector(feature_vector)
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
insert_vectors([data_item])
|
||||||
|
success_count += 1
|
||||||
|
failed_count -= 1
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重试失败 [id={sys_file_id}, url={url}]: {e}")
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
failed_records = retry_failed
|
||||||
|
if failed_records:
|
||||||
|
logger.warning(f"第 {retry + 1} 次重试后仍有 {len(failed_records)} 条记录失败")
|
||||||
|
|
||||||
|
return success_count, failed_count, failed_records
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数"""
|
||||||
|
parser = argparse.ArgumentParser(description='从 t_sys_file 导入系统图向量到 Milvus')
|
||||||
|
parser.add_argument('--batch-size', type=int, default=1000, help='批量处理大小(默认:1000)')
|
||||||
|
parser.add_argument('--retry-times', type=int, default=3, help='失败重试次数(默认:3)')
|
||||||
|
parser.add_argument('--limit', type=int, default=None, help='限制处理数量(用于测试,默认:不限制)')
|
||||||
|
parser.add_argument('--offset', type=int, default=0, help='起始偏移量(默认:0)')
|
||||||
|
parser.add_argument('--skip-create-collection', action='store_true', help='跳过创建集合(如果集合已存在)')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("开始从 t_sys_file 导入系统图向量到 Milvus")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(f"配置参数:")
|
||||||
|
logger.info(f" - 批量大小: {args.batch_size}")
|
||||||
|
logger.info(f" - 重试次数: {args.retry_times}")
|
||||||
|
logger.info(f" - 限制数量: {args.limit if args.limit else '不限制'}")
|
||||||
|
logger.info(f" - 起始偏移: {args.offset}")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
# 1. 创建 Milvus 集合
|
||||||
|
if not args.skip_create_collection:
|
||||||
|
logger.info("创建 Milvus 集合...")
|
||||||
|
try:
|
||||||
|
create_collection()
|
||||||
|
logger.info("Milvus 集合创建成功(或已存在)")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建 Milvus 集合失败: {e}")
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
logger.info("跳过创建集合")
|
||||||
|
|
||||||
|
# 2. 连接数据库
|
||||||
|
logger.info("连接数据库...")
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
logger.info("数据库连接成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库连接失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 3. 获取总记录数
|
||||||
|
total_count = get_total_count(conn)
|
||||||
|
logger.info(f"找到 {total_count} 条系统图记录")
|
||||||
|
|
||||||
|
if total_count == 0:
|
||||||
|
logger.warning("没有找到系统图数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 4. 获取记录
|
||||||
|
logger.info("获取记录...")
|
||||||
|
records = get_sys_file_records(conn, limit=args.limit, offset=args.offset)
|
||||||
|
logger.info(f"获取到 {len(records)} 条记录")
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
logger.warning("没有获取到记录")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 5. 处理并插入
|
||||||
|
logger.info("开始处理记录...")
|
||||||
|
success_count, failed_count, failed_records = process_and_insert_batch(
|
||||||
|
records,
|
||||||
|
batch_size=args.batch_size,
|
||||||
|
retry_times=args.retry_times
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. 输出结果
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info("导入完成!")
|
||||||
|
logger.info(f" - 成功: {success_count} 条")
|
||||||
|
logger.info(f" - 失败: {failed_count} 条")
|
||||||
|
if failed_records:
|
||||||
|
logger.warning(f" - 失败记录列表(前10条):")
|
||||||
|
for sys_file_id, url in failed_records[:10]:
|
||||||
|
logger.warning(f" ID={sys_file_id}, URL={url}")
|
||||||
|
if len(failed_records) > 10:
|
||||||
|
logger.warning(f" ... 还有 {len(failed_records) - 10} 条失败记录")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理过程中发生错误: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
logger.info("数据库连接已关闭")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
||||||
343
app/service/recommendation_system/incremental_listener.py
Normal file
343
app/service/recommendation_system/incremental_listener.py
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
"""
|
||||||
|
增量监听模块
|
||||||
|
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import pymysql
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Set, Tuple, Optional
|
||||||
|
from datetime import datetime
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from apscheduler.schedulers.background import BackgroundScheduler
|
||||||
|
from apscheduler.schedulers.blocking import BlockingScheduler
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import (
|
||||||
|
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
|
||||||
|
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||||
|
)
|
||||||
|
from app.service.recommendation_system.vector_utils import extract_feature_vector, compute_weighted_average, normalize_vector
|
||||||
|
from app.service.recommendation_system.milvus_client import query_vectors_by_paths, insert_vectors
|
||||||
|
from app.service.utils.redis_utils import Redis
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class IncrementalListener:
|
||||||
|
"""增量监听器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.last_process_time = None
|
||||||
|
self.processed_combinations: Set[Tuple[int, str]] = set() # 已处理的 (account_id, category) 组合
|
||||||
|
self.listen_interval = RECOMMENDATION_CONFIG["listen_interval_sec"]
|
||||||
|
|
||||||
|
def get_new_like_records(self) -> List[Tuple]:
|
||||||
|
"""
|
||||||
|
获取新增点赞记录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
记录列表,每个元素为 (id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id)
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
if self.last_process_time is None:
|
||||||
|
# 第一次运行,查询最近30分钟的数据
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
|
||||||
|
ORDER BY data_time
|
||||||
|
""")
|
||||||
|
else:
|
||||||
|
# 基于上次处理时间查询
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE data_time > %s
|
||||||
|
ORDER BY data_time
|
||||||
|
""", (self.last_process_time,))
|
||||||
|
|
||||||
|
records = cursor.fetchall()
|
||||||
|
return records
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取新增点赞记录失败: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def process_new_records(self, records: List[Tuple]):
|
||||||
|
"""
|
||||||
|
处理新增记录
|
||||||
|
|
||||||
|
Args:
|
||||||
|
records: 记录列表
|
||||||
|
"""
|
||||||
|
if not records:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 按用户+类别分组
|
||||||
|
user_category_records = defaultdict(list)
|
||||||
|
for record in records:
|
||||||
|
account_id = record[1]
|
||||||
|
category = record[3]
|
||||||
|
if category: # 只处理有类别的记录
|
||||||
|
user_category_records[(account_id, category)].append(record)
|
||||||
|
|
||||||
|
# 去重:只处理一次每个 (account_id, category) 组合
|
||||||
|
to_process = []
|
||||||
|
for (account_id, category), recs in user_category_records.items():
|
||||||
|
if (account_id, category) not in self.processed_combinations:
|
||||||
|
to_process.append((account_id, category, recs))
|
||||||
|
self.processed_combinations.add((account_id, category))
|
||||||
|
|
||||||
|
logger.info(f"需要处理 {len(to_process)} 个用户-类别组合")
|
||||||
|
|
||||||
|
# 处理每个组合
|
||||||
|
for account_id, category, recs in to_process:
|
||||||
|
try:
|
||||||
|
self.update_user_preference_vector(account_id, category)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
|
||||||
|
# 更新最后处理时间
|
||||||
|
if records:
|
||||||
|
self.last_process_time = records[-1][5] # data_time
|
||||||
|
# 重置去重集合,确保下次周期不会跳过同一用户-类别
|
||||||
|
self.processed_combinations.clear()
|
||||||
|
|
||||||
|
def update_user_preference_vector(self, account_id: int, category: str):
|
||||||
|
"""
|
||||||
|
更新用户偏好向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
account_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 获取该用户该类别的所有点赞记录
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s
|
||||||
|
ORDER BY data_time DESC
|
||||||
|
""", (account_id, category))
|
||||||
|
|
||||||
|
like_records = cursor.fetchall()
|
||||||
|
|
||||||
|
if not like_records:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 批量查询点赞次数
|
||||||
|
paths = [r[0] for r in like_records]
|
||||||
|
placeholders = ','.join(['%s'] * len(paths))
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, COUNT(*) as like_count
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||||||
|
GROUP BY path
|
||||||
|
""", (account_id, category) + tuple(paths))
|
||||||
|
|
||||||
|
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||||
|
|
||||||
|
# 3. 批量获取向量
|
||||||
|
vectors_dict = query_vectors_by_paths(paths)
|
||||||
|
|
||||||
|
# 处理查询不到的 path(新用户图或异常情况)
|
||||||
|
missing_paths = [p for p in paths if p not in vectors_dict]
|
||||||
|
if missing_paths:
|
||||||
|
logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量")
|
||||||
|
self._compute_and_insert_missing_vectors(missing_paths, conn)
|
||||||
|
# 重新查询
|
||||||
|
vectors_dict = query_vectors_by_paths(paths)
|
||||||
|
|
||||||
|
# 4. 计算权重并加权平均
|
||||||
|
vectors = []
|
||||||
|
weights = []
|
||||||
|
K_half = RECOMMENDATION_CONFIG["K_half"]
|
||||||
|
|
||||||
|
for k, (path, data_time) in enumerate(like_records, 1):
|
||||||
|
if path not in vectors_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
vector_data = vectors_dict[path]
|
||||||
|
feature_vector = np.array(vector_data["feature_vector"])
|
||||||
|
|
||||||
|
# 时间衰减权重
|
||||||
|
d_k = 0.5 ** (k / K_half)
|
||||||
|
|
||||||
|
# 点赞次数权重
|
||||||
|
like_count = like_counts.get(path, 1)
|
||||||
|
p_i = 1 + math.log(1 + like_count)
|
||||||
|
|
||||||
|
# 综合权重
|
||||||
|
w_i = d_k * p_i
|
||||||
|
|
||||||
|
vectors.append(feature_vector)
|
||||||
|
weights.append(w_i)
|
||||||
|
|
||||||
|
if not vectors:
|
||||||
|
logger.warning(f"用户 {account_id} 类别 {category} 没有有效向量")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 5. 计算加权平均并做 L2 归一化,IP≈cosine
|
||||||
|
preference_vector = compute_weighted_average(vectors, weights)
|
||||||
|
preference_vector = normalize_vector(preference_vector)
|
||||||
|
|
||||||
|
# 6. 写入 Redis
|
||||||
|
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
|
||||||
|
vector_json = json.dumps(preference_vector.tolist())
|
||||||
|
Redis.write(
|
||||||
|
key=key,
|
||||||
|
value=vector_json,
|
||||||
|
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"用户偏好向量更新成功 [user={account_id}, category={category}]")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"更新用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _compute_and_insert_missing_vectors(self, paths: List[str], conn: pymysql.connections.Connection):
|
||||||
|
"""
|
||||||
|
计算并插入缺失的向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: 缺失的 path 列表
|
||||||
|
conn: 数据库连接
|
||||||
|
"""
|
||||||
|
cursor = conn.cursor()
|
||||||
|
data_to_insert = []
|
||||||
|
|
||||||
|
for path in paths:
|
||||||
|
try:
|
||||||
|
# 判断数据来源(查询 t_sys_file 表)
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE url = %s
|
||||||
|
LIMIT 1
|
||||||
|
""", (path,))
|
||||||
|
|
||||||
|
sys_file = cursor.fetchone()
|
||||||
|
|
||||||
|
# 提取特征向量
|
||||||
|
feature_vector = extract_feature_vector(path)
|
||||||
|
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
logger.warning(f"向量提取失败,跳过: {path}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if sys_file:
|
||||||
|
# 系统图
|
||||||
|
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
data_item = {
|
||||||
|
"path": path,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
# 用户图
|
||||||
|
# 从 user_preference_log_test 获取 category(如果有)
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT category
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE path = %s AND category IS NOT NULL
|
||||||
|
LIMIT 1
|
||||||
|
""", (path,))
|
||||||
|
|
||||||
|
category_result = cursor.fetchone()
|
||||||
|
category = category_result[0] if category_result else None
|
||||||
|
|
||||||
|
data_item = {
|
||||||
|
"path": path,
|
||||||
|
"sys_file_id": None,
|
||||||
|
"style": None,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 0,
|
||||||
|
"deprecated": 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
data_to_insert.append(data_item)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理缺失向量失败 [{path}]: {e}")
|
||||||
|
|
||||||
|
# 批量插入
|
||||||
|
if data_to_insert:
|
||||||
|
try:
|
||||||
|
insert_vectors(data_to_insert)
|
||||||
|
logger.info(f"成功插入 {len(data_to_insert)} 个缺失向量")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插入缺失向量失败: {e}")
|
||||||
|
|
||||||
|
def process_once(self):
|
||||||
|
"""单次轮询任务,供调度器调用"""
|
||||||
|
try:
|
||||||
|
records = self.get_new_like_records()
|
||||||
|
|
||||||
|
if records:
|
||||||
|
logger.info(f"发现 {len(records)} 条新增记录")
|
||||||
|
self.process_new_records(records)
|
||||||
|
else:
|
||||||
|
logger.debug("没有新增记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"监听轮询异常: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
|
def start_background_listener(scheduler: BackgroundScheduler):
|
||||||
|
"""将增量监听任务注册到后台调度器"""
|
||||||
|
listener = IncrementalListener()
|
||||||
|
scheduler.add_job(
|
||||||
|
listener.process_once,
|
||||||
|
"interval",
|
||||||
|
seconds=listener.listen_interval,
|
||||||
|
max_instances=1,
|
||||||
|
coalesce=True,
|
||||||
|
id="recommendation_incremental_listener",
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
logger.info("增量监听任务已注册到调度器")
|
||||||
|
|
||||||
|
|
||||||
|
def start_blocking_listener():
|
||||||
|
"""以阻塞方式启动调度器(用于独立脚本运行)"""
|
||||||
|
listener = IncrementalListener()
|
||||||
|
scheduler = BlockingScheduler()
|
||||||
|
scheduler.add_job(
|
||||||
|
listener.process_once,
|
||||||
|
"interval",
|
||||||
|
seconds=listener.listen_interval,
|
||||||
|
max_instances=1,
|
||||||
|
coalesce=True,
|
||||||
|
id="recommendation_incremental_listener",
|
||||||
|
replace_existing=True,
|
||||||
|
)
|
||||||
|
logger.info("增量监听调度器已启动(BlockingScheduler)")
|
||||||
|
scheduler.start()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
start_blocking_listener()
|
||||||
|
|
||||||
295
app/service/recommendation_system/milvus_client.py
Normal file
295
app/service/recommendation_system/milvus_client.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""
|
||||||
|
Milvus 客户端封装
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
import numpy as np
|
||||||
|
from pymilvus import MilvusClient, FieldSchema, CollectionSchema, DataType, connections, Collection
|
||||||
|
|
||||||
|
from app.core.config import MILVUS_URL, MILVUS_TOKEN, MILVUS_ALIAS
|
||||||
|
from app.service.recommendation_system.config import MILVUS_COLLECTION_SKETCH_VECTORS, RECOMMENDATION_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Milvus 客户端(单例)
|
||||||
|
_milvus_client = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_milvus_client() -> MilvusClient:
|
||||||
|
"""获取 Milvus 客户端(单例模式)"""
|
||||||
|
global _milvus_client
|
||||||
|
if _milvus_client is None:
|
||||||
|
try:
|
||||||
|
_milvus_client = MilvusClient(
|
||||||
|
uri=MILVUS_URL,
|
||||||
|
token=MILVUS_TOKEN,
|
||||||
|
db_name=MILVUS_ALIAS
|
||||||
|
)
|
||||||
|
logger.info("Milvus 客户端连接成功")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Milvus 客户端连接失败: {e}")
|
||||||
|
raise
|
||||||
|
return _milvus_client
|
||||||
|
|
||||||
|
|
||||||
|
def create_collection():
|
||||||
|
"""
|
||||||
|
创建 Milvus 集合 sketch_vectors
|
||||||
|
|
||||||
|
集合结构:
|
||||||
|
- path (PK, varchar(512)) - 主键,MinIO 逻辑 URL
|
||||||
|
- sys_file_id (int64, 可为NULL) - 系统文件ID
|
||||||
|
- style (varchar(50), 可为NULL) - 风格样式
|
||||||
|
- category (varchar(100), 可为NULL) - 类别
|
||||||
|
- is_system_sketch (int8, 默认 1) - 标记字段:1-系统图,0-用户图
|
||||||
|
- deprecated (int8, 默认 0) - 是否废弃
|
||||||
|
- feature_vector (FloatVector(2048)) - 2048维特征向量
|
||||||
|
"""
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
# 检查集合是否已存在
|
||||||
|
collections = client.list_collections()
|
||||||
|
if MILVUS_COLLECTION_SKETCH_VECTORS in collections:
|
||||||
|
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 已存在")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 解析 Milvus URL
|
||||||
|
# 处理 http://host.docker.internal:19530 格式
|
||||||
|
url_clean = MILVUS_URL.replace("http://", "").replace("https://", "")
|
||||||
|
if ":" in url_clean:
|
||||||
|
host, port_str = url_clean.split(":", 1)
|
||||||
|
port = int(port_str)
|
||||||
|
else:
|
||||||
|
host = url_clean
|
||||||
|
port = 19530
|
||||||
|
|
||||||
|
# 使用传统 API 创建集合(更可靠)
|
||||||
|
# 连接到 Milvus(如果未连接)
|
||||||
|
try:
|
||||||
|
connections.connect(
|
||||||
|
alias=MILVUS_ALIAS,
|
||||||
|
host=host,
|
||||||
|
port=port,
|
||||||
|
token=MILVUS_TOKEN if MILVUS_TOKEN else None
|
||||||
|
)
|
||||||
|
logger.info(f"已连接到 Milvus: {host}:{port}")
|
||||||
|
except Exception as conn_e:
|
||||||
|
# 如果连接已存在,忽略错误
|
||||||
|
if "already exists" in str(conn_e).lower() or "Connection already exists" in str(conn_e):
|
||||||
|
logger.info("Milvus 连接已存在")
|
||||||
|
else:
|
||||||
|
logger.warning(f"连接 Milvus 时出现警告: {conn_e}")
|
||||||
|
|
||||||
|
# 定义字段
|
||||||
|
fields = [
|
||||||
|
FieldSchema(name="path", dtype=DataType.VARCHAR, is_primary=True, max_length=512),
|
||||||
|
FieldSchema(name="sys_file_id", dtype=DataType.INT64),
|
||||||
|
FieldSchema(name="style", dtype=DataType.VARCHAR, max_length=50),
|
||||||
|
FieldSchema(name="category", dtype=DataType.VARCHAR, max_length=50),
|
||||||
|
FieldSchema(name="is_system_sketch", dtype=DataType.INT8),
|
||||||
|
FieldSchema(name="deprecated", dtype=DataType.INT8),
|
||||||
|
FieldSchema(
|
||||||
|
name="feature_vector",
|
||||||
|
dtype=DataType.FLOAT_VECTOR,
|
||||||
|
dim=RECOMMENDATION_CONFIG["vector_dim"]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# 创建 schema
|
||||||
|
schema = CollectionSchema(
|
||||||
|
fields=fields,
|
||||||
|
description="Sketch vectors collection for recommendation system"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建集合
|
||||||
|
collection = Collection(
|
||||||
|
name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
schema=schema,
|
||||||
|
using=MILVUS_ALIAS
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建索引
|
||||||
|
# 注意:使用 IP(内积)作为度量类型,与搜索时保持一致
|
||||||
|
# 如果向量已归一化,IP 等价于 COSINE
|
||||||
|
index_params = {
|
||||||
|
"metric_type": "IP", # 内积(Inner Product)
|
||||||
|
"index_type": "IVF_FLAT",
|
||||||
|
"params": {"nlist": 1024}
|
||||||
|
}
|
||||||
|
|
||||||
|
collection.create_index(
|
||||||
|
field_name="feature_vector",
|
||||||
|
index_params=index_params
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(f"集合 {MILVUS_COLLECTION_SKETCH_VECTORS} 创建成功")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"创建集合失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def insert_vectors(data: List[Dict[str, Any]]):
|
||||||
|
"""
|
||||||
|
批量插入向量到 Milvus
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: 数据列表,每个元素包含:
|
||||||
|
- path: str
|
||||||
|
- sys_file_id: int (可选)
|
||||||
|
- style: str (可选)
|
||||||
|
- category: str (可选)
|
||||||
|
- is_system_sketch: int (默认 1)
|
||||||
|
- deprecated: int (默认 0)
|
||||||
|
- feature_vector: List[float] (2048维)
|
||||||
|
"""
|
||||||
|
if not data:
|
||||||
|
return
|
||||||
|
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
client.insert(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=data
|
||||||
|
)
|
||||||
|
logger.info(f"成功插入 {len(data)} 条向量数据")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"插入向量失败: {e}", exc_info=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def query_vectors_by_paths(paths: List[str]) -> Dict[str, Dict]:
|
||||||
|
"""
|
||||||
|
根据 path 列表批量查询向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
paths: path 列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
{path: {feature_vector: [...], ...}} 字典
|
||||||
|
"""
|
||||||
|
if not paths:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建查询表达式
|
||||||
|
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
||||||
|
# 对于字符串列表,使用单引号包裹每个值
|
||||||
|
path_list = ", ".join([f"'{p}'" for p in paths])
|
||||||
|
filter_expr = f"path in [{path_list}]"
|
||||||
|
|
||||||
|
results = client.query(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["path", "feature_vector", "style", "category", "sys_file_id", "is_system_sketch", "deprecated"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 转换为字典
|
||||||
|
result_dict = {}
|
||||||
|
for r in results:
|
||||||
|
result_dict[r["path"]] = r
|
||||||
|
|
||||||
|
return result_dict
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"查询向量失败: {e}", exc_info=True)
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def search_similar_vectors(
|
||||||
|
query_vector: np.ndarray,
|
||||||
|
category: str,
|
||||||
|
topk: int = 500,
|
||||||
|
style: Optional[str] = None
|
||||||
|
) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
向量相似度检索
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query_vector: 查询向量(2048维)
|
||||||
|
category: 类别过滤
|
||||||
|
topk: 返回数量
|
||||||
|
style: 风格过滤(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
检索结果列表,每个元素包含 path, score, style, category 等字段
|
||||||
|
"""
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建过滤表达式
|
||||||
|
# 使用 filter 参数而不是 expr(根据 pymilvus MilvusClient API)
|
||||||
|
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||||
|
if style:
|
||||||
|
filter_expr += f" && style == '{style}'"
|
||||||
|
|
||||||
|
# 搜索
|
||||||
|
results = client.search(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
data=[query_vector.tolist()],
|
||||||
|
anns_field="feature_vector",
|
||||||
|
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
|
||||||
|
limit=topk,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["path", "style", "category", "sys_file_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 格式化结果
|
||||||
|
formatted_results = []
|
||||||
|
if results and len(results) > 0:
|
||||||
|
for hit in results[0]:
|
||||||
|
formatted_results.append({
|
||||||
|
"path": hit.get("entity", {}).get("path", ""),
|
||||||
|
"score": hit.get("distance", 0.0),
|
||||||
|
"style": hit.get("entity", {}).get("style", ""),
|
||||||
|
"category": hit.get("entity", {}).get("category", ""),
|
||||||
|
"sys_file_id": hit.get("entity", {}).get("sys_file_id")
|
||||||
|
})
|
||||||
|
|
||||||
|
return formatted_results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"向量检索失败: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def query_random_candidates(category: str, style: Optional[str] = None, limit: int = 10) -> List[Dict]:
|
||||||
|
"""
|
||||||
|
随机查询候选(用于探索分支)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: 类别
|
||||||
|
style: 风格(可选)
|
||||||
|
limit: 返回数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
候选列表
|
||||||
|
"""
|
||||||
|
client = get_milvus_client()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 构建过滤表达式
|
||||||
|
filter_expr = f"category == '{category}' && deprecated == 0"
|
||||||
|
if style:
|
||||||
|
filter_expr += f" && style == '{style}'"
|
||||||
|
|
||||||
|
# 查询所有符合条件的记录
|
||||||
|
results = client.query(
|
||||||
|
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
|
||||||
|
filter=filter_expr,
|
||||||
|
output_fields=["path", "style", "category"],
|
||||||
|
limit=10000 # 先查询大量数据,然后随机选择
|
||||||
|
)
|
||||||
|
|
||||||
|
# 随机选择
|
||||||
|
if len(results) > limit:
|
||||||
|
import random
|
||||||
|
results = random.sample(results, limit)
|
||||||
|
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"随机查询候选失败: {e}", exc_info=True)
|
||||||
|
return []
|
||||||
|
|
||||||
556
app/service/recommendation_system/precompute.py
Normal file
556
app/service/recommendation_system/precompute.py
Normal file
@@ -0,0 +1,556 @@
|
|||||||
|
"""
|
||||||
|
预计算模块
|
||||||
|
包含:数据库表结构优化、Milvus集合创建、系统图向量预计算、初始用户偏好向量生成
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import pymysql
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Tuple, Optional
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import (
|
||||||
|
MYSQL_CONFIG, TABLE_USER_PREFERENCE_LOG, TABLE_SYS_FILE,
|
||||||
|
RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||||
|
)
|
||||||
|
from app.service.recommendation_system.vector_utils import extract_feature_vector, normalize_vector, compute_weighted_average
|
||||||
|
from app.service.recommendation_system.milvus_client import (
|
||||||
|
create_collection, insert_vectors, query_vectors_by_paths
|
||||||
|
)
|
||||||
|
from app.service.utils.redis_utils import Redis
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def optimize_database_table():
|
||||||
|
"""
|
||||||
|
优化 user_preference_log_test 表结构
|
||||||
|
添加冗余字段和索引
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 添加冗余字段
|
||||||
|
logger.info("添加冗余字段...")
|
||||||
|
alter_sqls = [
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN category VARCHAR(100) COMMENT '类别:lower(level3_type + \"_\" + level2_type)'",
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN style VARCHAR(50) COMMENT '风格样式'",
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN is_system_sketch TINYINT(1) DEFAULT 1 COMMENT '是否为系统图(1-是,0-用户图)'",
|
||||||
|
f"ALTER TABLE {TABLE_USER_PREFERENCE_LOG} ADD COLUMN sys_file_id BIGINT NULL COMMENT '系统文件ID'",
|
||||||
|
]
|
||||||
|
|
||||||
|
for sql in alter_sqls:
|
||||||
|
try:
|
||||||
|
cursor.execute(sql)
|
||||||
|
logger.info(f"执行成功: {sql[:50]}...")
|
||||||
|
except Exception as e:
|
||||||
|
if "Duplicate column name" in str(e):
|
||||||
|
logger.info(f"字段已存在,跳过: {sql[:50]}...")
|
||||||
|
else:
|
||||||
|
logger.warning(f"执行失败: {sql[:50]}... 错误: {e}")
|
||||||
|
|
||||||
|
# 2. 创建索引(MySQL 不支持 IF NOT EXISTS,需要先检查)
|
||||||
|
logger.info("创建索引...")
|
||||||
|
index_definitions = [
|
||||||
|
("idx_account_category_time", ["account_id", "category", "data_time"]),
|
||||||
|
("idx_account_path", ["account_id", "path"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
for index_name, columns in index_definitions:
|
||||||
|
try:
|
||||||
|
# 检查索引是否已存在
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM information_schema.statistics
|
||||||
|
WHERE table_schema = DATABASE()
|
||||||
|
AND table_name = '{TABLE_USER_PREFERENCE_LOG}'
|
||||||
|
AND index_name = '{index_name}'
|
||||||
|
""")
|
||||||
|
exists = cursor.fetchone()[0] > 0
|
||||||
|
|
||||||
|
if exists:
|
||||||
|
logger.info(f"索引已存在,跳过: {index_name}")
|
||||||
|
else:
|
||||||
|
# 创建索引
|
||||||
|
columns_str = ', '.join(columns)
|
||||||
|
create_sql = f"CREATE INDEX {index_name} ON {TABLE_USER_PREFERENCE_LOG}({columns_str})"
|
||||||
|
cursor.execute(create_sql)
|
||||||
|
logger.info(f"索引创建成功: {index_name}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"索引创建失败: {index_name} 错误: {e}")
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
logger.info("数据库表结构优化完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库表结构优化失败: {e}", exc_info=True)
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def migrate_historical_data(batch_size: int = 1000):
|
||||||
|
"""
|
||||||
|
历史数据迁移:批量更新冗余字段
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: 每批处理数量
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 查询需要更新的记录数
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT COUNT(*)
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG} u
|
||||||
|
WHERE u.category IS NULL
|
||||||
|
""")
|
||||||
|
total_count = cursor.fetchone()[0]
|
||||||
|
logger.info(f"需要迁移的记录数: {total_count}")
|
||||||
|
|
||||||
|
if total_count == 0:
|
||||||
|
logger.info("无需迁移数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 分批处理
|
||||||
|
offset = 0
|
||||||
|
processed = 0
|
||||||
|
|
||||||
|
while offset < total_count:
|
||||||
|
# 查询一批记录
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT u.id, u.path
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG} u
|
||||||
|
WHERE u.category IS NULL
|
||||||
|
LIMIT {batch_size} OFFSET {offset}
|
||||||
|
""")
|
||||||
|
records = cursor.fetchall()
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 批量更新
|
||||||
|
for record_id, path in records:
|
||||||
|
# 查询 t_sys_file 表
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE url = %s
|
||||||
|
LIMIT 1
|
||||||
|
""", (path,))
|
||||||
|
|
||||||
|
sys_file = cursor.fetchone()
|
||||||
|
|
||||||
|
if sys_file:
|
||||||
|
# 系统图
|
||||||
|
sys_file_id, url, style, level3_type, level2_type, deprecated = sys_file
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
cursor.execute(f"""
|
||||||
|
UPDATE {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
SET category = %s,
|
||||||
|
style = %s,
|
||||||
|
is_system_sketch = 1,
|
||||||
|
sys_file_id = %s
|
||||||
|
WHERE id = %s
|
||||||
|
""", (category, style, sys_file_id, record_id))
|
||||||
|
else:
|
||||||
|
# 用户图
|
||||||
|
cursor.execute(f"""
|
||||||
|
UPDATE {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
SET is_system_sketch = 0,
|
||||||
|
category = NULL,
|
||||||
|
style = NULL,
|
||||||
|
sys_file_id = NULL
|
||||||
|
WHERE id = %s
|
||||||
|
""", (record_id,))
|
||||||
|
|
||||||
|
conn.commit()
|
||||||
|
processed += len(records)
|
||||||
|
offset += batch_size
|
||||||
|
logger.info(f"已迁移 {processed}/{total_count} 条记录")
|
||||||
|
|
||||||
|
logger.info("历史数据迁移完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"历史数据迁移失败: {e}", exc_info=True)
|
||||||
|
if conn:
|
||||||
|
conn.rollback()
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int = 3):
|
||||||
|
"""
|
||||||
|
系统图向量预计算与导入
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: 每批处理数量
|
||||||
|
retry_times: 失败重试次数
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 数据筛选
|
||||||
|
logger.info("查询系统图数据...")
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT id, url, style, level3_type, level2_type, deprecated
|
||||||
|
FROM {TABLE_SYS_FILE}
|
||||||
|
WHERE level1_type = 'Images'
|
||||||
|
AND style IS NOT NULL
|
||||||
|
AND style != ''
|
||||||
|
AND deprecated != 1
|
||||||
|
""")
|
||||||
|
records = cursor.fetchall()
|
||||||
|
logger.info(f"找到 {len(records)} 条系统图记录")
|
||||||
|
|
||||||
|
if not records:
|
||||||
|
logger.warning("没有找到系统图数据")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 批量处理
|
||||||
|
failed_records = []
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
for idx, (sys_file_id, url, style, level3_type, level2_type, deprecated) in enumerate(records, 1):
|
||||||
|
try:
|
||||||
|
# 计算 category
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
|
||||||
|
# 提取特征向量
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
|
||||||
|
# 检查向量是否有效
|
||||||
|
if np.all(feature_vector == 0):
|
||||||
|
logger.warning(f"向量提取失败,跳过: {url}")
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 准备数据
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": deprecated if deprecated else 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_data.append(data_item)
|
||||||
|
|
||||||
|
# 批量写入
|
||||||
|
if len(batch_data) >= batch_size:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
batch_data = []
|
||||||
|
logger.info(f"已处理 {idx}/{len(records)} 条记录")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量写入失败: {e}")
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
batch_data = []
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理记录失败 [{url}]: {e}")
|
||||||
|
failed_records.append((sys_file_id, url))
|
||||||
|
|
||||||
|
# 写入剩余数据
|
||||||
|
if batch_data:
|
||||||
|
try:
|
||||||
|
insert_vectors(batch_data)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"写入剩余数据失败: {e}")
|
||||||
|
failed_records.extend([(item["sys_file_id"], item["path"]) for item in batch_data])
|
||||||
|
|
||||||
|
# 3. 重试失败记录
|
||||||
|
if failed_records and retry_times > 0:
|
||||||
|
logger.info(f"重试 {len(failed_records)} 条失败记录...")
|
||||||
|
for retry in range(retry_times):
|
||||||
|
retry_failed = []
|
||||||
|
for sys_file_id, url in failed_records:
|
||||||
|
try:
|
||||||
|
category = f"{level3_type.lower()}_{level2_type.lower()}"
|
||||||
|
feature_vector = extract_feature_vector(url)
|
||||||
|
if not np.all(feature_vector == 0):
|
||||||
|
data_item = {
|
||||||
|
"path": url,
|
||||||
|
"sys_file_id": sys_file_id,
|
||||||
|
"style": style,
|
||||||
|
"category": category,
|
||||||
|
"is_system_sketch": 1,
|
||||||
|
"deprecated": 0,
|
||||||
|
"feature_vector": feature_vector.tolist()
|
||||||
|
}
|
||||||
|
insert_vectors([data_item])
|
||||||
|
else:
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"重试失败 [{url}]: {e}")
|
||||||
|
retry_failed.append((sys_file_id, url))
|
||||||
|
|
||||||
|
failed_records = retry_failed
|
||||||
|
if not failed_records:
|
||||||
|
break
|
||||||
|
|
||||||
|
if failed_records:
|
||||||
|
logger.warning(f"仍有 {len(failed_records)} 条记录处理失败")
|
||||||
|
|
||||||
|
logger.info("系统图向量预计算完成")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"系统图向量预计算失败: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def compute_user_preference_vector(
|
||||||
|
account_id: int,
|
||||||
|
category: str,
|
||||||
|
conn: Optional[pymysql.connections.Connection] = None
|
||||||
|
# max_date: Optional[datetime] = None
|
||||||
|
) -> Optional[np.ndarray]:
|
||||||
|
"""
|
||||||
|
计算用户偏好向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
account_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
conn: 数据库连接(可选)
|
||||||
|
max_date: 最大日期(可选,用于评估时只使用训练集数据)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户偏好向量(2048维),失败返回 None
|
||||||
|
"""
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
should_close = False
|
||||||
|
if conn is None:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
should_close = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 获取点赞记录(如果指定了max_date,只查询该日期之前的数据)
|
||||||
|
if max_date:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND style is not null
|
||||||
|
AND data_time < %s
|
||||||
|
ORDER BY data_time DESC
|
||||||
|
""", (account_id, category, max_date))
|
||||||
|
else:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, data_time
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND style is not null
|
||||||
|
ORDER BY data_time DESC
|
||||||
|
""", (account_id, category))
|
||||||
|
|
||||||
|
like_records = cursor.fetchall()
|
||||||
|
|
||||||
|
if not like_records:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 2. 批量查询点赞次数(如果指定了max_date,只统计该日期之前的点赞)
|
||||||
|
paths = [r[0] for r in like_records]
|
||||||
|
if not paths:
|
||||||
|
return None
|
||||||
|
|
||||||
|
placeholders = ','.join(['%s'] * len(paths))
|
||||||
|
if max_date:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, COUNT(*) as like_count
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||||||
|
AND data_time < %s
|
||||||
|
GROUP BY path
|
||||||
|
""", (account_id, category) + tuple(paths) + (max_date,))
|
||||||
|
else:
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT path, COUNT(*) as like_count
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE account_id = %s AND category = %s AND path IN ({placeholders})
|
||||||
|
GROUP BY path
|
||||||
|
""", (account_id, category) + tuple(paths))
|
||||||
|
|
||||||
|
like_counts = {row[0]: row[1] for row in cursor.fetchall()}
|
||||||
|
|
||||||
|
# 3. 批量获取向量
|
||||||
|
vectors_dict = query_vectors_by_paths(paths)
|
||||||
|
|
||||||
|
# 处理查询不到的 path(用户图或异常情况)
|
||||||
|
missing_paths = [p for p in paths if p not in vectors_dict]
|
||||||
|
if missing_paths:
|
||||||
|
logger.info(f"用户 {account_id} 类别 {category} 有 {len(missing_paths)} 个 path 需要实时计算向量")
|
||||||
|
# 目前未有非系统图向量,跳过
|
||||||
|
# 这里可以实时计算并写入 Milvus,但为了简化,先跳过
|
||||||
|
# 实际实现中应该调用 vector_utils.extract_feature_vector 并写入 Milvus
|
||||||
|
|
||||||
|
# 4. 计算权重并加权平均
|
||||||
|
vectors = []
|
||||||
|
weights = []
|
||||||
|
K_half = RECOMMENDATION_CONFIG["K_half"]
|
||||||
|
|
||||||
|
for k, (path, data_time) in enumerate(like_records, 1):
|
||||||
|
if path not in vectors_dict:
|
||||||
|
continue
|
||||||
|
|
||||||
|
vector_data = vectors_dict[path]
|
||||||
|
feature_vector = np.array(vector_data["feature_vector"])
|
||||||
|
|
||||||
|
# 时间衰减权重
|
||||||
|
d_k = 0.5 ** (k / K_half)
|
||||||
|
|
||||||
|
# 点赞次数权重
|
||||||
|
like_count = like_counts.get(path, 1)
|
||||||
|
p_i = 1 + math.log(1 + like_count)
|
||||||
|
|
||||||
|
# 综合权重
|
||||||
|
# w_i = d_k * p_i
|
||||||
|
w_i = p_i
|
||||||
|
|
||||||
|
vectors.append(feature_vector)
|
||||||
|
weights.append(w_i)
|
||||||
|
|
||||||
|
if not vectors:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 5. 计算加权平均并做 L2 归一化,IP≈cosine
|
||||||
|
preference_vector = compute_weighted_average(vectors, weights)
|
||||||
|
preference_vector = normalize_vector(preference_vector)
|
||||||
|
|
||||||
|
return preference_vector
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"计算用户偏好向量失败 [user={account_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
return None
|
||||||
|
finally:
|
||||||
|
if should_close and conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def generate_initial_user_preference_vectors(batch_size: int = 100):
|
||||||
|
"""
|
||||||
|
初始用户偏好向量生成
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch_size: 每批处理用户数
|
||||||
|
"""
|
||||||
|
conn = None
|
||||||
|
try:
|
||||||
|
conn = pymysql.connect(**MYSQL_CONFIG)
|
||||||
|
cursor = conn.cursor()
|
||||||
|
|
||||||
|
# 1. 扫描历史数据
|
||||||
|
logger.info("扫描用户和类别组合...")
|
||||||
|
cursor.execute(f"""
|
||||||
|
SELECT DISTINCT account_id, category
|
||||||
|
FROM {TABLE_USER_PREFERENCE_LOG}
|
||||||
|
WHERE category IS NOT NULL
|
||||||
|
AND style IS NOT NULL
|
||||||
|
""")
|
||||||
|
|
||||||
|
user_categories = cursor.fetchall()
|
||||||
|
logger.info(f"找到 {len(user_categories)} 个用户-类别组合")
|
||||||
|
|
||||||
|
if not user_categories:
|
||||||
|
logger.warning("没有找到用户-类别组合")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 2. 批量处理
|
||||||
|
processed = 0
|
||||||
|
failed = 0
|
||||||
|
|
||||||
|
for account_id, category in user_categories:
|
||||||
|
try:
|
||||||
|
# 计算偏好向量
|
||||||
|
preference_vector = compute_user_preference_vector(account_id, category, conn)
|
||||||
|
|
||||||
|
if preference_vector is not None:
|
||||||
|
# 写入 Redis
|
||||||
|
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{account_id}:{category}"
|
||||||
|
# 序列化向量(使用 JSON)
|
||||||
|
vector_json = json.dumps(preference_vector.tolist())
|
||||||
|
Redis.write(
|
||||||
|
key=key,
|
||||||
|
value=vector_json,
|
||||||
|
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
|
||||||
|
)
|
||||||
|
processed += 1
|
||||||
|
else:
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
if (processed + failed) % batch_size == 0:
|
||||||
|
logger.info(f"已处理 {processed + failed}/{len(user_categories)} 个组合,成功: {processed}, 失败: {failed}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理失败 [user={account_id}, category={category}]: {e}")
|
||||||
|
failed += 1
|
||||||
|
|
||||||
|
logger.info(f"初始用户偏好向量生成完成,成功: {processed}, 失败: {failed}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始用户偏好向量生成失败: {e}", exc_info=True)
|
||||||
|
finally:
|
||||||
|
if conn:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def run_precompute():
|
||||||
|
"""
|
||||||
|
运行所有预计算任务
|
||||||
|
"""
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("开始预计算任务")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
# 1. 优化数据库表结构
|
||||||
|
logger.info("\n[1/5] 优化数据库表结构...")
|
||||||
|
optimize_database_table()
|
||||||
|
|
||||||
|
# # 2. 创建 Milvus 集合
|
||||||
|
# logger.info("\n[2/5] 创建 Milvus 集合...")
|
||||||
|
# create_collection()
|
||||||
|
|
||||||
|
# 3. 历史数据迁移
|
||||||
|
logger.info("\n[3/5] 历史数据迁移...")
|
||||||
|
migrate_historical_data()
|
||||||
|
|
||||||
|
# # 4. 系统图向量预计算
|
||||||
|
# logger.info("\n[4/5] 系统图向量预计算...")
|
||||||
|
# precompute_system_sketch_vectors()
|
||||||
|
|
||||||
|
# 5. 初始用户偏好向量生成
|
||||||
|
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
||||||
|
generate_initial_user_preference_vectors()
|
||||||
|
|
||||||
|
logger.info("=" * 50)
|
||||||
|
logger.info("预计算任务完成")
|
||||||
|
logger.info("=" * 50)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 1. 优化数据库表结构
|
||||||
|
logger.info("\n[1/5] 优化数据库表结构...")
|
||||||
|
optimize_database_table()
|
||||||
|
|
||||||
|
# 3. 历史数据迁移
|
||||||
|
logger.info("\n[3/5] 历史数据迁移...")
|
||||||
|
migrate_historical_data()
|
||||||
|
|
||||||
|
# 5. 初始用户偏好向量生成
|
||||||
|
logger.info("\n[5/5] 初始用户偏好向量生成...")
|
||||||
|
generate_initial_user_preference_vectors()
|
||||||
214
app/service/recommendation_system/recommendation_api.py
Normal file
214
app/service/recommendation_system/recommendation_api.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""
|
||||||
|
推荐接口实现
|
||||||
|
实现探索/利用分支、向量检索、Softmax抽样等功能
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
import numpy as np
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
|
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG, REDIS_KEY_USER_PREF_PREFIX
|
||||||
|
from app.service.recommendation_system.milvus_client import search_similar_vectors, query_random_candidates
|
||||||
|
from app.service.recommendation_system.precompute import compute_user_preference_vector
|
||||||
|
from app.service.recommendation_system.vector_utils import normalize_vector
|
||||||
|
from app.service.utils.redis_utils import Redis
|
||||||
|
import json
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_preference_vector(user_id: int, category: str) -> Optional[np.ndarray]:
|
||||||
|
"""
|
||||||
|
获取用户偏好向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
用户偏好向量(2048维),失败返回 None
|
||||||
|
"""
|
||||||
|
# 1. 从 Redis 获取
|
||||||
|
key = f"{REDIS_KEY_USER_PREF_PREFIX}:{user_id}:{category}"
|
||||||
|
vector_json = Redis.read(key)
|
||||||
|
|
||||||
|
if vector_json:
|
||||||
|
try:
|
||||||
|
vector_list = json.loads(vector_json)
|
||||||
|
return np.array(vector_list, dtype=np.float32)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"解析 Redis 向量失败 [user={user_id}, category={category}]: {e}")
|
||||||
|
|
||||||
|
# 2. 如果不存在,实时计算
|
||||||
|
logger.info(f"Redis 中不存在用户偏好向量,实时计算 [user={user_id}, category={category}]")
|
||||||
|
preference_vector = compute_user_preference_vector(user_id, category)
|
||||||
|
|
||||||
|
if preference_vector is not None:
|
||||||
|
# 写入 Redis
|
||||||
|
vector_json = json.dumps(preference_vector.tolist())
|
||||||
|
Redis.write(
|
||||||
|
key=key,
|
||||||
|
value=vector_json,
|
||||||
|
expire=RECOMMENDATION_CONFIG["redis_expire_seconds"]
|
||||||
|
)
|
||||||
|
|
||||||
|
return preference_vector
|
||||||
|
|
||||||
|
|
||||||
|
def explore_branch(category: str, style: Optional[str] = None) -> List[str]:
|
||||||
|
"""
|
||||||
|
探索分支(随机推荐)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
category: 类别
|
||||||
|
style: 风格(可选)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐结果列表,每个元素包含 path, style, category 等字段
|
||||||
|
"""
|
||||||
|
# 查询候选(随机池)
|
||||||
|
pool_size = 10 # 固定查询10个,然后随机选择
|
||||||
|
|
||||||
|
candidates = query_random_candidates(category, style, limit=pool_size)
|
||||||
|
|
||||||
|
if not candidates:
|
||||||
|
logger.warning(f"探索分支:类别 {category} 没有候选数据")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 随机选择
|
||||||
|
if len(candidates) > 1:
|
||||||
|
import random
|
||||||
|
candidates = random.sample(candidates, 1)
|
||||||
|
|
||||||
|
# 格式化返回结果
|
||||||
|
return [candidate.get("path", "") for candidate in candidates[:1]]
|
||||||
|
|
||||||
|
|
||||||
|
def exploit_branch(
|
||||||
|
user_id: int,
|
||||||
|
category: str,
|
||||||
|
style: Optional[str] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
利用分支(基于向量相似度推荐)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
category: 类别
|
||||||
|
num_recommendations: 返回数量
|
||||||
|
style: 风格(可选,用于加分)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐结果列表,每个元素包含 path, style, category, similarity, sample_score 等字段
|
||||||
|
"""
|
||||||
|
# 1. 获取用户偏好向量
|
||||||
|
embedding = get_user_preference_vector(user_id, category)
|
||||||
|
|
||||||
|
if embedding is None:
|
||||||
|
logger.warning(f"利用分支:无法获取用户偏好向量,回退到探索分支 [user={user_id}, category={category}]")
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
|
# 2. Milvus 相似度检索(内积 IP)
|
||||||
|
topk = RECOMMENDATION_CONFIG["topk"]
|
||||||
|
results = search_similar_vectors(embedding, category, topk)
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
logger.warning(f"利用分支:向量检索无结果,回退到探索分支 [user={user_id}, category={category}]")
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
|
# 3. Style 加分(可选,需传入 style 参数)
|
||||||
|
style_bonus = RECOMMENDATION_CONFIG["style_bonus"]
|
||||||
|
if style:
|
||||||
|
for result in results:
|
||||||
|
similarity = result["score"]
|
||||||
|
if result.get("style") == style:
|
||||||
|
# 加分:相似度 * (1 + style_bonus)
|
||||||
|
similarity = similarity * (1 + style_bonus)
|
||||||
|
result["final_score"] = similarity
|
||||||
|
else:
|
||||||
|
for result in results:
|
||||||
|
result["final_score"] = result["score"]
|
||||||
|
|
||||||
|
# 4. Softmax 抽样
|
||||||
|
scores = [r["final_score"] for r in results]
|
||||||
|
probabilities = softmax_with_temperature(scores, RECOMMENDATION_CONFIG["softmax_temperature"])
|
||||||
|
|
||||||
|
# 根据概率抽样
|
||||||
|
if not results:
|
||||||
|
return []
|
||||||
|
|
||||||
|
selected_index = np.random.choice(len(results), size=1, p=probabilities, replace=False)
|
||||||
|
selected_results = [results[int(selected_index[0])]]
|
||||||
|
|
||||||
|
# 5. 返回结果
|
||||||
|
return [result.get("path", "") for result in selected_results]
|
||||||
|
|
||||||
|
|
||||||
|
def softmax_with_temperature(scores: List[float], temperature: float = 1.0) -> List[float]:
|
||||||
|
"""
|
||||||
|
Softmax 函数(带温度参数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scores: 分数列表
|
||||||
|
temperature: 温度参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
概率列表
|
||||||
|
"""
|
||||||
|
if not scores:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# 除以温度
|
||||||
|
scaled_scores = [s / temperature for s in scores]
|
||||||
|
|
||||||
|
# 减去最大值(数值稳定性)
|
||||||
|
max_score = max(scaled_scores)
|
||||||
|
exp_scores = [math.exp(s - max_score) for s in scaled_scores]
|
||||||
|
|
||||||
|
# 归一化
|
||||||
|
sum_exp = sum(exp_scores)
|
||||||
|
if sum_exp == 0:
|
||||||
|
# 如果所有分数都是负无穷或非常小,返回均匀分布
|
||||||
|
return [1.0 / len(scores)] * len(scores)
|
||||||
|
|
||||||
|
probabilities = [exp_s / sum_exp for exp_s in exp_scores]
|
||||||
|
return probabilities
|
||||||
|
|
||||||
|
|
||||||
|
def get_recommendations(
|
||||||
|
user_id: int,
|
||||||
|
category: str,
|
||||||
|
style: Optional[str] = None
|
||||||
|
) -> List[str]:
|
||||||
|
"""
|
||||||
|
获取推荐结果(主函数)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: 用户ID
|
||||||
|
category: 类别(如 female_skirt)
|
||||||
|
num_recommendations: 返回推荐数量(默认 1)
|
||||||
|
style: 风格(可选):若传入,则在利用分支对同 style 的候选进行加分
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
推荐结果列表,每个元素包含 path 等字段
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 1. 读取配置参数
|
||||||
|
explore_ratio = RECOMMENDATION_CONFIG["explore_ratio"]
|
||||||
|
|
||||||
|
# 2. 探索/利用决策
|
||||||
|
r = random.random() # 生成随机数 (0-1)
|
||||||
|
|
||||||
|
if r < explore_ratio:
|
||||||
|
logger.debug(f"探索分支 [user={user_id}, category={category}]")
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
|
logger.debug(f"利用分支 [user={user_id}, category={category}]")
|
||||||
|
return exploit_branch(user_id, category, style)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取推荐结果失败 [user={user_id}, category={category}]: {e}", exc_info=True)
|
||||||
|
# 容错:回退到探索分支
|
||||||
|
return explore_branch(category, style)
|
||||||
|
|
||||||
189
app/service/recommendation_system/vector_utils.py
Normal file
189
app/service/recommendation_system/vector_utils.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""
|
||||||
|
向量计算工具类
|
||||||
|
包含 ResNet50 特征提取、向量归一化等功能
|
||||||
|
"""
|
||||||
|
import io
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torchvision import models, transforms
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||||
|
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 图像预处理(与ResNet训练时的预处理一致)
|
||||||
|
transform = transforms.Compose([
|
||||||
|
transforms.Resize((224, 224)), # ResNet 要求 224x224 的输入
|
||||||
|
transforms.ToTensor(), # 转换为 Tensor
|
||||||
|
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # 标准化
|
||||||
|
])
|
||||||
|
|
||||||
|
# 加载预训练的 ResNet50 模型(去掉最后全连接层)
|
||||||
|
_resnet_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_resnet_model():
|
||||||
|
"""获取 ResNet50 模型(单例模式)"""
|
||||||
|
global _resnet_model
|
||||||
|
if _resnet_model is None:
|
||||||
|
logger.info("加载 ResNet50 模型...")
|
||||||
|
_resnet_model = models.resnet50(pretrained=True)
|
||||||
|
modules = list(_resnet_model.children())[:-1] # 移除最后的全连接层
|
||||||
|
_resnet_model = torch.nn.Sequential(*modules)
|
||||||
|
_resnet_model.eval() # 设置为评估模式
|
||||||
|
logger.info("ResNet50 模型加载完成")
|
||||||
|
return _resnet_model
|
||||||
|
|
||||||
|
|
||||||
|
# MinIO 客户端(单例)
|
||||||
|
_minio_client = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_minio_client():
|
||||||
|
"""获取 MinIO 客户端(单例模式)"""
|
||||||
|
global _minio_client
|
||||||
|
if _minio_client is None:
|
||||||
|
_minio_client = Minio(
|
||||||
|
MINIO_URL,
|
||||||
|
access_key=MINIO_ACCESS,
|
||||||
|
secret_key=MINIO_SECRET,
|
||||||
|
secure=MINIO_SECURE
|
||||||
|
)
|
||||||
|
return _minio_client
|
||||||
|
|
||||||
|
|
||||||
|
def get_image_from_minio(path: str) -> Image.Image:
|
||||||
|
"""
|
||||||
|
从 MinIO 获取图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: MinIO 逻辑 URL,格式如 "bucket_name/object_name"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
PIL Image 对象,失败返回 None
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 分割路径,获取桶名和文件路径
|
||||||
|
path_parts = path.split('/', 1)
|
||||||
|
if len(path_parts) != 2:
|
||||||
|
logger.error(f"路径格式错误: {path}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
bucket_name, file_name = path_parts
|
||||||
|
minio_client = get_minio_client()
|
||||||
|
|
||||||
|
# 获取文件
|
||||||
|
obj = minio_client.get_object(bucket_name, file_name)
|
||||||
|
img_data = obj.read() # 读取图像数据
|
||||||
|
img = Image.open(io.BytesIO(img_data)) # 将数据转为图像对象
|
||||||
|
|
||||||
|
return img
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"从 MinIO 获取图片失败 [{path}]: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def extract_feature_vector(path: str) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
使用 ResNet50 提取图片特征向量(2048维)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: MinIO 逻辑 URL
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
2048维特征向量(numpy array),失败返回零向量
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 从 MinIO 获取图像
|
||||||
|
img = get_image_from_minio(path)
|
||||||
|
if img is None:
|
||||||
|
logger.warning(f"无法获取图片,返回零向量: {path}")
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
# 预处理
|
||||||
|
# 部分 MinIO 图片可能是 RGBA/CMYK,转换成 RGB 以匹配 3 通道标准化参数
|
||||||
|
if img.mode != "RGB":
|
||||||
|
try:
|
||||||
|
img = img.convert("RGB")
|
||||||
|
except Exception:
|
||||||
|
logger.warning(f"无法转换图片为RGB,返回零向量: {path}")
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
img_tensor = transform(img).unsqueeze(0) # 扩展维度以适应批量处理
|
||||||
|
|
||||||
|
# 提取特征
|
||||||
|
resnet_model = get_resnet_model()
|
||||||
|
with torch.no_grad(): # 在不需要计算梯度的情况下进行推断
|
||||||
|
feature_vector = resnet_model(img_tensor) # 获取 ResNet 的输出
|
||||||
|
feature_vector = feature_vector.squeeze().cpu().numpy() # 转换为 NumPy 数组并去掉 batch 维度
|
||||||
|
|
||||||
|
# 确保是 2048 维
|
||||||
|
if feature_vector.ndim > 1:
|
||||||
|
feature_vector = feature_vector.flatten()
|
||||||
|
|
||||||
|
# 确保维度正确
|
||||||
|
if len(feature_vector) != RECOMMENDATION_CONFIG["vector_dim"]:
|
||||||
|
logger.warning(f"向量维度不正确: {len(feature_vector)}, 期望: {RECOMMENDATION_CONFIG['vector_dim']}")
|
||||||
|
# 如果维度不对,尝试调整
|
||||||
|
if len(feature_vector) > RECOMMENDATION_CONFIG["vector_dim"]:
|
||||||
|
feature_vector = feature_vector[:RECOMMENDATION_CONFIG["vector_dim"]]
|
||||||
|
else:
|
||||||
|
padded = np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
padded[:len(feature_vector)] = feature_vector
|
||||||
|
feature_vector = padded
|
||||||
|
|
||||||
|
return feature_vector.astype(np.float32)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取特征向量失败 [{path}]: {e}", exc_info=True)
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_vector(vector: np.ndarray) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
L2 归一化向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vector: 输入向量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
归一化后的向量
|
||||||
|
"""
|
||||||
|
norm = np.linalg.norm(vector)
|
||||||
|
if norm == 0:
|
||||||
|
return vector
|
||||||
|
return vector / norm
|
||||||
|
|
||||||
|
|
||||||
|
def compute_weighted_average(vectors: list, weights: list) -> np.ndarray:
|
||||||
|
"""
|
||||||
|
计算加权平均向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
vectors: 向量列表
|
||||||
|
weights: 权重列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加权平均向量(不做归一化,模长为加权平均后的尺度)
|
||||||
|
"""
|
||||||
|
if not vectors or not weights:
|
||||||
|
return np.zeros(RECOMMENDATION_CONFIG["vector_dim"], dtype=np.float32)
|
||||||
|
|
||||||
|
# 确保所有向量都是 numpy array
|
||||||
|
vectors = [np.array(v) for v in vectors]
|
||||||
|
weights = np.array(weights)
|
||||||
|
|
||||||
|
# 计算加权和
|
||||||
|
weighted_sum = np.zeros_like(vectors[0])
|
||||||
|
for v, w in zip(vectors, weights):
|
||||||
|
weighted_sum += v * w
|
||||||
|
|
||||||
|
# 返回加权平均(除以权重和,不做 L2 归一化,模长不会随条数线性暴涨)
|
||||||
|
weight_total = weights.sum()
|
||||||
|
if weight_total == 0:
|
||||||
|
return weighted_sum
|
||||||
|
return weighted_sum / weight_total
|
||||||
|
|
||||||
Reference in New Issue
Block a user