31 Commits

Author SHA1 Message Date
litianxiang
fb46a9521d Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 13:57:28 +08:00
litianxiang
b90688f835 更改增量更新日志级别 2026-01-13 13:57:15 +08:00
zcr
7e30779aec feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:43:30 +08:00
zcr
f7294f5966 feat: seg any thing 新增box模式
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-13 12:32:18 +08:00
zcr
0ac5a4e0a8 Merge remote-tracking branch 'origin/develop' into develop
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 16:18:15 +08:00
zcr
40b57b749c feat: 新增design模式 merge,前端CV python 合成 2026-01-12 16:18:04 +08:00
litianxiang
b8a538a8a1 fix:增量更新向量问题修改
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:59:06 +08:00
litianxiang
29b4f43a27 debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:34:56 +08:00
litianxiang
69dc20207d debug:推荐接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:03:58 +08:00
litianxiang
18979af604 debug:推荐接口返回redis值
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 13:01:26 +08:00
litianxiang
74406f9be4 推荐接口更新向量接口注册
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:59:01 +08:00
litianxiang
df99e3ac76 新增查看redis内容接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 11:51:37 +08:00
litianxiang
19346c2eb7 Merge remote-tracking branch 'origin/develop' into dev-ltx
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-12 09:51:52 +08:00
litianxiang
2af9cbfe78 fix:推荐接口 2026-01-12 09:49:07 +08:00
zcr
fe12b5697d fix: design 镜像默认值修改,旋转方向和前端保持一致
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:40:49 +08:00
zcr
c04d4877b0 fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:12:53 +08:00
zcr
91016e6cae fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:08:16 +08:00
zcr
0f4bb260ad fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 17:06:39 +08:00
zcr
c792106f02 fix: design 回参新增镜像旋转参数
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 15:42:42 +08:00
zcr
deac5a4cab fix: design item sketch旋转参数为none
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-09 12:31:34 +08:00
zcr
15682036b3 feat : 新增seg anything 接口 ,接口文档补充
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 17:39:27 +08:00
zcr
9ba3a0ca49 feat : 新增seg anything 接口
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 17:33:54 +08:00
zcr
f6963070fb feat : 支持上下左右同时镜像
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 13:47:44 +08:00
zcr
12f5ca3ca3 feat : design 示例说明
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 10:44:02 +08:00
zcr
19110f51bf feat : design 示例说明
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-08 10:29:31 +08:00
zcr
e04636ce21 feat : design overall print 新增平铺间距和旋转角度
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-07 17:03:02 +08:00
zcr
2a50e7040e feat : design overall print 新增平铺间距和旋转角度
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-07 16:22:19 +08:00
zcr
a6f3bda9f7 feat : design 单品新增 镜像旋转功能
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-06 12:21:10 +08:00
zcr
c18f45e549 feat : design 单品新增 镜像旋转功能
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2026-01-06 12:00:58 +08:00
zcr
4951fab71a 代码整理
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:49:22 +08:00
zcr
aa57478852 新推荐接口first commit
All checks were successful
git commit AiDA python develop 分支构建部署 / scheduled_deploy (push) Has been skipped
2025-12-30 17:35:32 +08:00
24 changed files with 1062 additions and 682 deletions

View File

@@ -1,9 +1,10 @@
import json import json
import logging import logging
import requests
from fastapi import APIRouter, HTTPException, BackgroundTasks from fastapi import APIRouter, HTTPException, BackgroundTasks
from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel from app.schemas.design import DesignModel, ModelProgressModel, DesignStreamModel, SAMRequestModel
from app.schemas.response_template import ResponseModel from app.schemas.response_template import ResponseModel
from app.service.design_fast.design_generate import design_generate, design_generate_v2 from app.service.design_fast.design_generate import design_generate, design_generate_v2
from app.service.design_fast.model_process_service import model_transpose from app.service.design_fast.model_process_service import model_transpose
@@ -15,177 +16,141 @@ logger = logging.getLogger()
@router.post("/design") @router.post("/design")
def design(request_data: DesignModel): def design(request_data: DesignModel):
""" """
objects.items.transparent: - **objects.items.transparent**:
"transparent":{ ```json
"mask_url":"test/transparent_test/transparent_mask.png", "transparent":{
"scale":0.1 "mask_url":"test/transparent_test/transparent_mask.png",
}, "scale":0.1
mask_url 为空"" -> 单件衣服透明 },
mask_url 非空"mask_url" -> 区域透明 ```
- **mask_url** 为空"" -> 单件衣服透明
- **mask_url** 非空"mask_url" -> 区域透明
- **transpose** 镜像模式 ,:"top_bottom""left_right"
- **rotate** 45,
创建一个具有以下参数的请求体: - ** design 参数变更:
示例参数: design detail 请求参数中 basic -> preview_submit 替换为design_type 可选参数 default ,merge (移除preview和submit)
{ design_type 参数说明:
"objects": [ defuault模式下 请求参数不变
{ merge模式下 items -> 每个item需要新增 merge_image_path , merge_image_path为前端处理 print color等操作后的单件结果图
"basic": {
"body_point_test": { **
"waistband_right": [
203, - 创建一个具有以下参数的请求体:
249 示例参数:
], ```json
"hand_point_right": [ {
229, "objects": [
343 {
], "basic": {
"waistband_left": [ "body_point_test": {
119, "waistband_right": [
248 203,
], 249
"hand_point_left": [ ],
97, "hand_point_right": [
343 229,
], 343
"shoulder_left": [ ],
108, "waistband_left": [
107 119,
], 248
"shoulder_right": [ ],
212, "hand_point_left": [
107 97,
] 343
}, ],
"layer_order": true, "shoulder_left": [
"preview_submit": "submit", 108,
"scale_bag": 0.7, 107
"scale_earrings": 0.16, ],
"self_template": true, "shoulder_right": [
"single_overall": "overall", 212,
"switch_category": "" 107
]
}, },
"items": [ "layer_order": true,
{ "design_type": "preview",
"businessId": 2377945, "scale_bag": 0.7,
"color": "209 196 171", "scale_earrings": 0.16,
"image_id": 189410, "self_template": true,
"offset": [ "single_overall": "overall",
0, "switch_category": ""
0 },
], "items": [
"path": "aida-collection-element/89/Sketchboard/53d38bd5-f77b-4034-ada2-45f1e2ebe00c.png", {
"print": { "businessId": 2115382,
"element": { "color": "",
"element_angle_list": [], "image_id": 61686,
"element_path_list": [], "offset": [
"element_scale_list": [], 0,
"location": [] 0
}, ],
"overall": { "path": "aida-sys-image/images/female/dress/0628000564.jpg",
"location": [], "transpose": [
"print_angle_list": [], 1,
"print_path_list": [], 1
"print_scale_list": [] ],
}, "rotate": 45,
"single": { "print": {
"location": [], "element": {
"print_angle_list": [], "element_angle_list": [],
"print_path_list": [], "element_path_list": [],
"print_scale_list": [] "element_scale_list": [],
} "location": []
}, },
"priority": 12, "overall": {
"resize_scale": [ "location": [
1.0, [
1.0 53.0,
], 118.5
"seg_mask_url": "aida-clothing/mask/mask_8e96ddb0-e466-11f0-8de2-0242ac130002.png", ]
"type": "Outwear" ],
}, "print_angle_list": [
{ 0.0
"businessId": 2377946, ],
"color": "122 152 139", "print_path_list": [
"image_id": 81868, "aida-users/89/print/02d57aa8-f342-4e1d-b02c-b278f94dcfe6-3-89.png"
"offset": [ ],
0, "print_scale_list": [
0 [
], 0.5,
"path": "aida-sys-image/images/female/blouse/0825001443.jpg", 0.5
"print": { ]
"element": { ],
"element_angle_list": [], "gap": [
"element_path_list": [], [
"element_scale_list": [], 10,
"location": [] 10
}, ]
"overall": { ]
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
}, },
"priority": 11, "single": {
"resize_scale": [ "location": [],
1.0, "print_angle_list": [],
1.0 "print_path_list": [],
], "print_scale_list": []
"seg_mask_url": "aida-clothing/mask/mask_8f0fab78-e466-11f0-8de2-0242ac130002.png", }
"type": "Blouse"
}, },
{ "priority": 10,
"businessId": 2377947, "resize_scale": [
"color": "111 78 63", 1.0,
"gradient": "aida-gradient/517c3a4d-aed7-4423-aa99-7b60d3577df1.png", 1.0
"image_id": 116494, ],
"offset": [ "seg_mask_url": "aida-clothing/mask/mask_9698b428-eb93-11f0-9327-0242c0a80003.png",
0, "type": "Dress"
0 },
], {
"path": "aida-sys-image/images/female/skirt/0825000219.jpg", "body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"print": { "image_id": 67277,
"element": { "type": "Body"
"element_angle_list": [], }
"element_path_list": [], ]
"element_scale_list": [], }
"location": [] ],
}, "process_id": "89"
"overall": { }
"location": [], ```
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
},
"single": {
"location": [],
"print_angle_list": [],
"print_path_list": [],
"print_scale_list": []
}
},
"priority": 10,
"resize_scale": [
1.0,
1.0
],
"seg_mask_url": "aida-clothing/mask/mask_8f6191fe-e466-11f0-8de2-0242ac130002.png",
"type": "Skirt"
},
{
"body_path": "aida-sys-image/models/female/2e4815b9-1191-419d-94ed-5771239ca4a5.png",
"image_id": 67277,
"type": "Body"
}
]
}
],
"process_id": "89"
}
""" """
# logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}") # logger.info(f"design request item is : @@@@@@:{json.dumps(request_data.dict(),indent=4)}")
# data = generate(request_data=request_data) # data = generate(request_data=request_data)
@@ -421,6 +386,52 @@ async def design_v2(request_data: DesignStreamModel, background_tasks: Backgroun
return ResponseModel() return ResponseModel()
@router.post("/seg_anything")
async def seg_anything(request_data: SAMRequestModel):
"""
**Segment Anything 交互式分割接口**
通过传入图片路径和点击的点坐标,返回分割后的掩码数据。
### 参数说明:
- **user_id**:用户id 用于存储分割图
- **image_path**: 图片在服务器或云端的相对路径。
- **type**: 推理类型
- **box**: 框选矩形点位信息
- **points**: 交互点的坐标列表。每个点为 [x, y] 像素格式。
- **labels**: 坐标点的属性标签,必须与 points 长度一致:
- 1: **前景点** (代表想要分割出的区域)
- 0: **背景点** (代表想要排除的区域)
### 请求体示例:
```json
point
{
"user_id": 1,
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"point",
"points": [[310, 403], [493, 375], [261, 266], [404, 484]],
"labels": [1, 1, 0, 1]
}
box
{
"user_id": 1,
"image_path": "aida-users/89/sketch/4e8fe37d-7068-400a-ac94-c01647fa5f6f.png",
"type":"box",
"box": [350, 286, 544, 520]
}
```
"""
try:
logger.info(f"seg_anything request item is : @@@@@@:{json.dumps(request_data.dict(), indent=4)}")
data = requests.post("http://10.1.1.240:10075/predict", json=request_data.dict())
logger.info(f"seg_anything response @@@@@@:{json.dumps(json.loads(data.content), indent=4)}")
return ResponseModel(data=json.loads(data.content))
except Exception as e:
logger.warning(f"seg_anything Run Exception @@@@@@:{e}")
# @router.post('/get_progress') # @router.post('/get_progress')
# def get_progress(request_data: DesignProgressModel): # def get_progress(request_data: DesignProgressModel):
# """ # """

View File

@@ -137,10 +137,13 @@ router = APIRouter()
# logger.error(f"推荐失败: {str(e)}", exc_info=True) # logger.error(f"推荐失败: {str(e)}", exc_info=True)
# raise HTTPException(status_code=500, detail=str(e)) # raise HTTPException(status_code=500, detail=str(e))
# @router.on_event("startup") @router.on_event("startup")
async def startup_event(): async def startup_event():
"""启动时初始化增量监听任务""" """启动时初始化增量监听任务"""
try: try:
# 屏蔽 apscheduler 的 INFO 日志
logging.getLogger("apscheduler").setLevel(logging.WARNING)
# 确保 Milvus 集合已创建(若已存在则直接返回) # 确保 Milvus 集合已创建(若已存在则直接返回)
try: try:
create_collection() create_collection()
@@ -172,4 +175,32 @@ async def recommend(
return [path] return [path]
except Exception as e: except Exception as e:
logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True) logger.error("新版推荐接口失败 [user=%s, category=%s]: %s", user_id, category, e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/redis/user_pref")
async def get_all_user_preferences():
"""
获取所有以 user_pref 为前缀的 Redis key 信息
"""
try:
from app.service.utils.redis_utils import Redis
from app.service.recommendation_system.config import REDIS_KEY_USER_PREF_PREFIX
# 扫描所有匹配 user_pref:* 的 key
pattern = f"{REDIS_KEY_USER_PREF_PREFIX}:*"
keys = Redis.scan_keys(pattern)
# 直接返回所有 key 和原始 value
result = {}
for key in keys:
# 读取对应的值
value = Redis.read(key)
if value:
result[key] = value
return result
except Exception as e:
logger.error("获取用户偏好数据失败: %s", e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))

View File

@@ -7,6 +7,7 @@ from app.api import api_design_pre_processing
from app.api import api_generate_image from app.api import api_generate_image
from app.api import api_mannequins_edit from app.api import api_mannequins_edit
from app.api import api_pose_transform from app.api import api_pose_transform
from app.api import api_precompute
from app.api import api_prompt_generation from app.api import api_prompt_generation
from app.api import api_recommendation from app.api import api_recommendation
from app.api import api_test from app.api import api_test
@@ -21,6 +22,7 @@ router.include_router(api_prompt_generation.router, tags=['prompt_generation'],
router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api") router.include_router(api_design_pre_processing.router, tags=['design_pre_processing'], prefix="/api")
router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api") router.include_router(api_brand_dna.router, tags=['api_brand_dna'], prefix="/api")
router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api") router.include_router(api_recommendation.router, tags=['api_recommendation'], prefix="/api")
router.include_router(api_precompute.router, tags=['api_precompute'], prefix="/api")
router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api") router.include_router(api_mannequins_edit.router, tags=['api_mannequins_edit'], prefix="/api")
router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api") router.include_router(api_pose_transform.router, tags=['api_pose_transform'], prefix="/api")
router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api") router.include_router(api_clothing_seg.router, tags=['api_clothing_seg'], prefix="/api")

View File

@@ -36,7 +36,7 @@ class Settings(BaseSettings):
# --- mysql 配置信息 --- # --- mysql 配置信息 ---
MYSQL_HOST: str = Field(default='', description="") MYSQL_HOST: str = Field(default='', description="")
MYSQL_PORT: str = Field(default='', description="") MYSQL_PORT: int = Field(default='', description="")
MYSQL_USER: str = Field(default='', description="") MYSQL_USER: str = Field(default='', description="")
MYSQL_PASSWORD: str = Field(default='', description="") MYSQL_PASSWORD: str = Field(default='', description="")
MYSQL_DB: str = Field(default='', description="") MYSQL_DB: str = Field(default='', description="")

View File

@@ -1,4 +1,15 @@
from pydantic import BaseModel from typing import List, Optional
from pydantic import BaseModel, Field
class SAMRequestModel(BaseModel):
user_id: int = Field(..., description="用户id, 必填字段")
image_path: str = Field(..., description="图片路径,必填字段")
type: str = Field(..., description="推理类型,必填字段")
points: Optional[List[List[float]]] = None
labels: Optional[List[int]] = None
box: Optional[List[int]] = None
class DesignModel(BaseModel): class DesignModel(BaseModel):

View File

@@ -6,10 +6,10 @@ import requests
from minio import Minio from minio import Minio
from app.core.config import settings from app.core.config import settings
from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem from app.service.design_fast.item import BodyItem, TopItem, BottomItem, OthersItem, TopMergeItem, BottomMergeItem, OthersMergeItem
from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others from app.service.design_fast.utils.organize import organize_body, organize_clothing, organize_others
from app.service.design_fast.utils.progress import final_progress, update_progress from app.service.design_fast.utils.progress import final_progress, update_progress
from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority from app.service.design_fast.utils.synthesis_item import synthesis, synthesis_single, update_base_size_priority, merge
from app.service.utils.decorator import RunTime from app.service.utils.decorator import RunTime
id_lock = threading.Lock() id_lock = threading.Lock()
@@ -19,22 +19,46 @@ logger = logging.getLogger()
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE) minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
def process_item(item, basic): def process_item(item, basic, design_type):
# 处理project中单个item # 1. 定义映射配置
if item['type'] == "Body": # key 为 item_type 的小写value 为对应的处理类
body_server = BodyItem(data=item, basic=basic, minio_client=minio_client) DESIGN_MAP = {
item_data = body_server.process() 'body': BodyItem,
elif item['type'].lower() in ['blouse', 'outwear', 'dress', 'tops']: 'blouse': TopItem, 'outwear': TopItem,
top_server = TopItem(data=item, basic=basic, minio_client=minio_client) 'dress': TopItem, 'tops': TopItem,
item_data = top_server.process() 'skirt': BottomItem, 'trousers': BottomItem,
elif item['type'].lower() in ['skirt', 'trousers', 'bottoms']: 'bottoms': BottomItem,
bottom_server = BottomItem(data=item, basic=basic, minio_client=minio_client) 'others': OthersItem
item_data = bottom_server.process() }
elif item['type'].lower() in ['others']:
bottom_server = OthersItem(data=item, basic=basic, minio_client=minio_client) MERGE_MAP = {
item_data = bottom_server.process() 'body_merge': BodyItem,
'blouse_merge': TopMergeItem, 'outwear_merge': TopMergeItem,
'dress_merge': TopMergeItem, 'tops_merge': TopMergeItem,
'skirt_merge': BottomMergeItem, 'trousers_merge': BottomMergeItem,
'bottoms_merge': BottomMergeItem,
'others_merge': OthersMergeItem
}
# 2. 根据 design_type 选择映射表
mapping = MERGE_MAP if design_type == 'merge' else DESIGN_MAP
if design_type == 'merge':
item_type_key = f"{item['type'].lower()}_merge"
elif design_type == 'default':
item_type_key = item['type'].lower()
else: else:
raise NotImplementedError(f"Item type {item['type']} not implemented") item_type_key = item['type'].lower()
handler_class = mapping.get(item_type_key)
if not handler_class:
raise NotImplementedError(f"Item type {item['type']} not implemented for design_type={design_type}")
# 4. 统一实例化并执行
# 注意:这里假设所有 Item 类构造函数签名一致
server = handler_class(data=item, basic=basic, minio_client=minio_client)
item_data = server.process()
return item_data return item_data
@@ -44,7 +68,7 @@ def process_layer(item, layers):
body_layer = organize_body(item) body_layer = organize_body(item)
layers.append(body_layer) layers.append(body_layer)
return item['body_image'].size return item['body_image'].size
elif item['name'] == 'others': elif item['name'] in ['others', 'others_merge']:
front_layer, back_layer = organize_others(item) front_layer, back_layer = organize_others(item)
layers.append(front_layer) layers.append(front_layer)
layers.append(back_layer) layers.append(back_layer)
@@ -70,10 +94,11 @@ def design_generate(request_data):
nonlocal active_threads nonlocal active_threads
basic = object['basic'] basic = object['basic']
items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""} items_response = {'layers': [], 'objectSign': object['objectSign'] if 'objectSign' in object.keys() else ""}
design_type = basic.get('design_type', "default")
if basic['single_overall'] == "overall": if basic['single_overall'] == "overall":
item_results = [] item_results = []
for item in object['items']: for item in object['items']:
item_results.append(process_item(item, basic)) item_results.append(process_item(item, basic, design_type))
layers = [] layers = []
for item in item_results: for item in item_results:
process_layer(item, layers) process_layer(item, layers)
@@ -93,10 +118,17 @@ def design_generate(request_data):
'image_url': lay['image_url'] if 'image_url' in lay.keys() else None, 'image_url': lay['image_url'] if 'image_url' in lay.keys() else None,
'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None, 'pattern_overall_image_url': lay['pattern_overall_image_url'] if 'pattern_overall_image_url' in lay.keys() else None,
'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None, 'pattern_print_image_url': lay['pattern_print_image_url'] if 'pattern_print_image_url' in lay.keys() else None,
'transpose': lay.get('transpose', None),
'rotate': lay.get('rotate', None),
# 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None, # 'back_perspective_url': lay['back_perspective_url'] if 'back_perspective_url' in lay.keys() else None,
}) })
items_response['synthesis_url'] = synthesis(layers, new_size, basic) if basic.get('design_type') == 'default':
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
elif basic.get('design_type') == 'merge':
items_response['synthesis_url'] = merge(layers, new_size, basic)
else:
items_response['synthesis_url'] = synthesis(layers, new_size, basic)
else: else:
item_result = process_item(object['items'][0], basic) item_result = process_item(object['items'][0], basic)
items_response['layers'].append({ items_response['layers'].append({

View File

@@ -7,6 +7,7 @@ class BaseItem:
self.result['name'] = data['type'].lower() self.result['name'] = data['type'].lower()
self.result.pop("type") self.result.pop("type")
self.result.update(basic) self.result.update(basic)
self.result['design_type'] = basic.get('design_type', None)
class OthersItem(BaseItem): class OthersItem(BaseItem):
@@ -14,13 +15,7 @@ class OthersItem(BaseItem):
super().__init__(data, basic) super().__init__(data, basic)
self.Others_pipeline = [ self.Others_pipeline = [
LoadImage(minio_client), LoadImage(minio_client),
# KeyPoint(),
# ContourDetection(),
Segmentation(minio_client), Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
NoSegPrintPainting(minio_client),
PrintPainting(minio_client),
Scaling(), Scaling(),
Split(minio_client) Split(minio_client)
] ]
@@ -74,6 +69,65 @@ class BottomItem(BaseItem):
return self.result return self.result
"""merge"""
class OthersMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.Others_pipeline = [
LoadImage(minio_client),
# KeyPoint(),
# ContourDetection(),
Segmentation(minio_client),
# BackPerspective(minio_client),
Color(minio_client),
NoSegPrintPainting(minio_client),
PrintPainting(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.Others_pipeline:
self.result = item(self.result)
return self.result
class TopMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.top_pipeline = [
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.top_pipeline:
self.result = item(self.result)
return self.result
class BottomMergeItem(BaseItem):
def __init__(self, data, basic, minio_client):
super().__init__(data, basic)
self.bottom_pipeline = [
LoadImage(minio_client),
KeyPoint(),
Segmentation(minio_client),
Scaling(),
Split(minio_client)
]
def process(self):
for item in self.bottom_pipeline:
self.result = item(self.result)
return self.result
class BodyItem(BaseItem): class BodyItem(BaseItem):
def __init__(self, data, basic, minio_client): def __init__(self, data, basic, minio_client):
super().__init__(data, basic) super().__init__(data, basic)

View File

@@ -1,7 +1,7 @@
import logging import logging
import numpy as np import numpy as np
from pymilvus import MilvusClient # from pymilvus import MilvusClient
from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings from app.core.config import KEYPOINT_RESULT_TABLE_FIELD_SET, MILVUS_TABLE_KEYPOINT, settings
from app.service.design_fast.utils.design_ensemble import get_keypoint_result from app.service.design_fast.utils.design_ensemble import get_keypoint_result
@@ -54,63 +54,64 @@ class KeyPoint:
"keypoint_vector": result.tolist() "keypoint_vector": result.tolist()
} }
] ]
try: return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
client.close()
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
except Exception as e:
logger.info(f"save keypoint cache milvus error : {e}")
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@staticmethod # try:
def update_keypoint_cache(keypoint_id, infer_result, search_result, site): # client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
if site == "up": # client.upsert(collection_name=MILVUS_TABLE_KEYPOINT, data=data)
# 需要的是up 即推理出来的是up 那么查询的就是down # client.close()
result = np.concatenate([infer_result.flatten(), search_result[-4:]]) # except Exception as e:
else: # logger.info(f"save keypoint cache milvus error : {e}")
# 需要的是down 即推理出来的是down 那么查询的就是up # return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
result = np.concatenate([search_result[:20], infer_result.flatten()])
data = [
{"keypoint_id": keypoint_id,
"keypoint_site": "all",
"keypoint_vector": result.tolist()
}
]
try: # @staticmethod
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) # def update_keypoint_cache(keypoint_id, infer_result, search_result, site):
client.upsert( # if site == "up":
collection_name=MILVUS_TABLE_KEYPOINT, # # 需要的是up 即推理出来的是up 那么查询的就是down
data=data # result = np.concatenate([infer_result.flatten(), search_result[-4:]])
) # else:
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) # # 需要的是down 即推理出来的是down 那么查询的就是up
except Exception as e: # result = np.concatenate([search_result[:20], infer_result.flatten()])
logger.info(f"save keypoint cache milvus error : {e}") # data = [
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist())) # {"keypoint_id": keypoint_id,
# "keypoint_site": "all",
# "keypoint_vector": result.tolist()
# }
# ]
#
# try:
# client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
# client.upsert(
# collection_name=MILVUS_TABLE_KEYPOINT,
# data=data
# )
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
# except Exception as e:
# logger.info(f"save keypoint cache milvus error : {e}")
# return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, result.reshape(12, 2).astype(int).tolist()))
@RunTime # @RunTime
def keypoint_cache(self, result, site): # def keypoint_cache(self, result, site):
try: # try:
client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS) # client = MilvusClient(uri=settings.MILVUS_URL, token=settings.MILVUS_TOKEN, db_name=settings.MILVUS_ALIAS)
keypoint_id = result['image_id'] # keypoint_id = result['image_id']
res = client.query( # res = client.query(
collection_name=MILVUS_TABLE_KEYPOINT, # collection_name=MILVUS_TABLE_KEYPOINT,
# ids=[keypoint_id], # # ids=[keypoint_id],
filter=f"keypoint_id == {keypoint_id}", # filter=f"keypoint_id == {keypoint_id}",
output_fields=['keypoint_vector', 'keypoint_site'] # output_fields=['keypoint_vector', 'keypoint_site']
) # )
if len(res) == 0: # if len(res) == 0:
# 没有结果 直接推理拿结果 并保存 # # 没有结果 直接推理拿结果 并保存
keypoint_infer_result, site = self.infer_keypoint_result(result) # keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site) # return self.save_keypoint_cache(result['image_id'], keypoint_infer_result, site)
elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site: # elif res[0]["keypoint_site"] == "all" or res[0]["keypoint_site"] == site:
# 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果 # # 需要的类型和查询的类型一致或者查询的类型为all 则直接返回查询的结果
return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist())) # return dict(zip(KEYPOINT_RESULT_TABLE_FIELD_SET, np.array(res[0]['keypoint_vector']).astype(int).reshape(12, 2).tolist()))
elif res[0]["keypoint_site"] != site: # elif res[0]["keypoint_site"] != site:
# 需要的类型和查询到的不一致则更新类型为all # # 需要的类型和查询到的不一致则更新类型为all
keypoint_infer_result, site = self.infer_keypoint_result(result) # keypoint_infer_result, site = self.infer_keypoint_result(result)
return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site) # return self.update_keypoint_cache(result["image_id"], keypoint_infer_result, res[0]['keypoint_vector'], site)
except Exception as e: # except Exception as e:
logger.info(f"search keypoint cache milvus error {e}") # logger.info(f"search keypoint cache milvus error {e}")
return False # return False

View File

@@ -35,15 +35,9 @@ class LoadImage:
return cls.name return cls.name
def __call__(self, result): def __call__(self, result):
if result.get("merge_image_path"):
result['merge_image'], _ = self.read_image(result['merge_image_path'])
result['image'], result['pre_mask'] = self.read_image(result['path']) result['image'], result['pre_mask'] = self.read_image(result['path'])
# if 'extract_lines' in result.keys():
# if result['extract_lines']:
# result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY), result['path'])
# else:
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
# else:
# result['gray'] = cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)
result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY)) result['gray'] = self.get_lines(cv2.cvtColor(result['image'], cv2.COLOR_BGR2GRAY))
result['keypoint'] = self.get_keypoint(result['name']) result['keypoint'] = self.get_keypoint(result['name'])
result['img_shape'] = result['image'].shape result['img_shape'] = result['image'].shape
@@ -61,21 +55,6 @@ class LoadImage:
mask = skeleton mask = skeleton
result = np.ones_like(img) * 255 result = np.ones_like(img) * 255
result[mask] = img[mask] result[mask] = img[mask]
# 步骤2细化边缘可选让线条更干净
# kernel = np.ones((1, 1), np.uint8)
# clean = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel)
# thinned = cv2.ximgproc.thinning(binary, thinningType=cv2.ximgproc.THINNING_ZHANGSUEN) # thinning算法细化线条
# mask = thinned > 0
# result = np.ones_like(img) * 255
# result[mask] = img[mask]
# 步骤3反转回 白底黑线
# lines = cv2.bitwise_not(thinned)
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Original_{path.replace('/', '-')}.png"), img)
# cv2.imwrite(os.path.join('/home/user/PycharmProjects/trinity_client_aida/test/lines_original_result_5', f"Line_{path.replace('/', '-')}.png"), result)
return result return result
def read_image(self, image_path): def read_image(self, image_path):
@@ -96,19 +75,19 @@ class LoadImage:
@staticmethod @staticmethod
def get_keypoint(name): def get_keypoint(name):
if name == 'blouse' or name == 'outwear' or name == 'dress' or name == 'tops': if name in ['blouse', 'outwear', 'dress', 'tops', 'blouse_merge', 'outwear_merge', 'dress_merge', 'tops_merge']:
keypoint = 'shoulder' keypoint = 'shoulder'
elif name == 'trousers' or name == 'skirt' or name == 'bottoms': elif name in ['trousers', 'skirt', 'bottoms', 'trousers_merge', 'skirt_merge', 'bottoms_merge']:
keypoint = 'waistband' keypoint = 'waistband'
elif name == 'bag': elif name in ['bag', 'bag_merge']:
keypoint = 'hand_point' keypoint = 'hand_point'
elif name == 'shoes': elif name in ['shoes', 'shoes_merge']:
keypoint = 'toe' keypoint = 'toe'
elif name == 'hairstyle': elif name in ['hairstyle', 'hairstyle_merge']:
keypoint = 'head_point' keypoint = 'head_point'
elif name == 'earring': elif name in ['earring', 'earring_merge']:
keypoint = 'ear_point' keypoint = 'ear_point'
elif name == 'others': elif name in ['others', 'others_merge']:
keypoint = "others" keypoint = "others"
else: else:
raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, " raise KeyError(f"{name} does not belong to item category list: blouse, outwear, dress, trousers, skirt, "

View File

@@ -9,7 +9,6 @@ from app.service.utils.new_oss_client import oss_get_image
class NoSegPrintPainting: class NoSegPrintPainting:
def __init__(self, minio_client): def __init__(self, minio_client):
self.random_seed = random.randint(0, 1000)
self.minio_client = minio_client self.minio_client = minio_client
def __call__(self, result): def __call__(self, result):
@@ -21,16 +20,8 @@ class NoSegPrintPainting:
if overall_print['print_path_list']: if overall_print['print_path_list']:
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0: # 获取平铺 + 旋转 的overall print
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True) painting_dict = self.painting_collection(painting_dict, overall_print)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize 到sketch大小
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = self.printpaint(result, painting_dict, print_=True) result['no_seg_sketch_overall'] = result['no_seg_sketch_print'] = self.printpaint(result, painting_dict, print_=True)
result['pattern_image'] = result['no_seg_sketch_overall'] result['pattern_image'] = result['no_seg_sketch_overall']
@@ -151,7 +142,6 @@ class NoSegPrintPainting:
temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2) temp_fg = np.expand_dims(result['mask'], axis=2).repeat(3, axis=2)
tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8) tmp2 = (result['final_image'] * (temp_fg / 255)).astype(np.uint8)
result['no_seg_sketch_print'] = cv2.add(tmp1, tmp2) result['no_seg_sketch_print'] = cv2.add(tmp1, tmp2)
return result return result
@staticmethod @staticmethod
@@ -166,26 +156,21 @@ class NoSegPrintPainting:
print_background = img1_bg + img2_fg print_background = img1_bg + img2_fg
return print_background return print_background
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False): def painting_collection(self, painting_dict, print_dict):
if print_trigger: print_ = self.get_print(print_dict)
print_ = self.get_print(print_dict) painting_dict['location'] = print_['location']
painting_dict['Trigger'] = not is_single dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
painting_dict['location'] = print_['location'] dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
single_mask_inv_print = self.get_mask_inv(print_['image']) gap = print_dict.get('gap', [[0, 0]])[0]
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w']) painting_dict['tile_print'] = tile_image(pattern=print_['image'],
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) dim=dim_pattern,
if not is_single: gap_x=gap[0],
# 如果print 模式为overall 且 有角度的话 组合的print为正方形方便裁剪 gap_y=gap[1],
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0: canvas_h=painting_dict['dim_image_h'],
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) canvas_w=painting_dict['dim_image_w'],
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) location=painting_dict['location'],
else: angle=45)
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True) painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
return painting_dict return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False): def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
@@ -219,33 +204,32 @@ class NoSegPrintPainting:
@staticmethod @staticmethod
def printpaint(result, painting_dict, print_=False): def printpaint(result, painting_dict, print_=False):
if print_:
if print_ and painting_dict['Trigger']:
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print'])) print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask) img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
else: else:
print_mask = result['mask'] print_mask = result['mask']
img_fg = result['final_image'] img_fg = result['final_image']
if print_ and not painting_dict['Trigger']: # if print_ and not painting_dict['Trigger']:
index_ = None # index_ = None
try: # try:
index_ = len(painting_dict['location']) # index_ = len(painting_dict['location'])
except: # except:
assert f'there must be parameter of location if choose IfSingle' # assert f'there must be parameter of location if choose IfSingle'
#
for i in range(index_): # for i in range(index_):
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0]) # start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
#
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0]) # length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1]) # length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
#
change_region = img_fg[start_h: length_h, start_w: length_w, :] # change_region = img_fg[start_h: length_h, start_w: length_w, :]
# problem in change_mask # # problem in change_mask
change_mask = print_mask[start_h: length_h, start_w: length_w] # change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask # # get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY) # _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
cv2.bitwise_not(painting_dict['mask_inv_print']) # cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region # img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask) clothes_mask_print = cv2.bitwise_not(print_mask)
@@ -277,8 +261,6 @@ class NoSegPrintPainting:
print_w = print_shape[1] print_w = print_shape[1]
print_h = print_shape[0] print_h = print_shape[0]
random.seed(self.random_seed)
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量 # 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可 # 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2 x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
@@ -420,3 +402,96 @@ class NoSegPrintPainting:
cropped_img = resized_img[start_y:start_y + target_height, :] cropped_img = resized_img[start_y:start_y + target_height, :]
return cropped_img return cropped_img
def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
"""
按照指定的 X/Y 间距平铺印花,并支持旋转
:param angle: 旋转角度 (度数, 逆时针)
"""
# 1. 确保输入是 RGBA
if pattern.shape[2] == 3:
pattern = cv2.cvtColor(pattern, cv2.COLOR_BGR2BGRA)
# 2. 缩放与旋转印花
resized_p = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
rotated_p = rotate_image(resized_p, angle)
p_h, p_w = rotated_p.shape[:2]
# 3. 创建透明单元格
cell_h, cell_w = p_h + gap_y, p_w + gap_x
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
unit_cell[:p_h, :p_w, :] = rotated_p
# 4. 执行平铺
tiles_y = (canvas_h // cell_h) + 2
tiles_x = (canvas_w // cell_w) + 2
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
# 5. 裁剪平铺层
offset_x = int(location[0][1] % cell_w)
offset_y = int(location[0][0] % cell_h)
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
# 6. 创建纯白色背景并合成
# 创建一个纯白色的 BGR 画布
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
# 分离平铺层的颜色通道和 Alpha 通道
tiled_bgr = tiled_layer[:, :, :3]
alpha_mask = tiled_layer[:, :, 3] / 255.0 # 归一化到 0-1
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask]) # 扩展到 3 通道
# 执行 Alpha 混合:结果 = 平铺层 * alpha + 背景 * (1 - alpha)
result = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
return result
def rotate_image(image, angle):
"""
旋转图片并保持完整内容(自动扩大画布)
"""
if angle == 0:
return image
(h, w) = image.shape[:2]
(cX, cY) = (w // 2, h // 2)
# 获取旋转矩阵
M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
# 计算旋转后新边界的 sine 和 cosine
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
# 计算新的画布尺寸
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
# 调整旋转矩阵以考虑平移
M[0, 2] += (nW / 2) - cX
M[1, 2] += (nH / 2) - cY
# 执行旋转
return cv2.warpAffine(image, M, (nW, nH))
def crop_image(image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image

View File

@@ -9,7 +9,6 @@ from app.service.utils.new_oss_client import oss_get_image
class PrintPainting: class PrintPainting:
def __init__(self, minio_client): def __init__(self, minio_client):
self.random_seed = None
self.minio_client = minio_client self.minio_client = minio_client
def __call__(self, result): def __call__(self, result):
@@ -39,23 +38,14 @@ class PrintPainting:
overall_print['location'][0] = [x * y for x, y in zip(overall_print['location'][0], result['resize_scale'])] overall_print['location'][0] = [x * y for x, y in zip(overall_print['location'][0], result['resize_scale'])]
painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]} painting_dict = {'dim_image_h': result['pattern_image'].shape[0], 'dim_image_w': result['pattern_image'].shape[1]}
result['print_image'] = result['pattern_image'] result['print_image'] = result['pattern_image']
if "print_angle_list" in overall_print.keys() and overall_print['print_angle_list'][0] != 0: # 获取平铺 + 旋转 的overall print
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True) painting_dict = self.painting_collection(painting_dict, overall_print)
painting_dict['tile_print'] = self.rotate_crop_image(img=painting_dict['tile_print'], angle=-overall_print['print_angle_list'][0], crop=True)
painting_dict['mask_inv_print'] = self.rotate_crop_image(img=painting_dict['mask_inv_print'], angle=-overall_print['print_angle_list'][0], crop=True)
# resize 到sketch大小
painting_dict['tile_print'] = self.resize_and_crop(img=painting_dict['tile_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
painting_dict['mask_inv_print'] = self.resize_and_crop(img=painting_dict['mask_inv_print'], target_width=painting_dict['dim_image_w'], target_height=painting_dict['dim_image_h'])
else:
painting_dict = self.painting_collection(painting_dict, overall_print, print_trigger=True, is_single=False)
result['print_image'] = self.printpaint(result, painting_dict, print_=True) result['print_image'] = self.printpaint(result, painting_dict, print_=True)
result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image'] result['single_image'] = result['final_image'] = result['pattern_image'] = result['print_image']
if single_print['print_path_list']: if single_print['print_path_list']:
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整 # 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
sketch_resize_scale = result['resize_scale'] sketch_resize_scale = result['resize_scale']
print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) print_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8) mask_background = np.zeros((result['pattern_image'].shape[0], result['pattern_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(single_print['print_path_list'])): for i in range(len(single_print['print_path_list'])):
@@ -78,75 +68,6 @@ class PrintPainting:
print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR) print_background = cv2.cvtColor(np.array(source_image_pil), cv2.COLOR_RGBA2BGR)
mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR) mask_background = cv2.cvtColor(np.array(source_image_pil_mask), cv2.COLOR_RGBA2BGR)
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY) ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
# else:
# mask = self.get_mask_inv(image)
# mask = np.expand_dims(mask, axis=2)
# mask = cv2.cvtColor(mask, cv2.COLOR_GRAY2BGR)
# mask = cv2.bitwise_not(mask)
#
# mask = cv2.resize(mask, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# image = cv2.resize(image, (int(result['final_image'].shape[1] * single_print['print_scale_list'][i][0]), int(result['final_image'].shape[0] * single_print['print_scale_list'][i][1])))
# # 旋转后的坐标需要重新算
# rotate_mask, _ = self.img_rotate(mask, single_print['print_angle_list'][i])
# rotate_image, rotated_new_size = self.img_rotate(image, single_print['print_angle_list'][i])
# # x, y = int(result['print']['location'][i][0] - rotated_new_size[0] - (rotate_mask.shape[0] - image.shape[0]) / 2), int(result['print']['location'][i][1] - rotated_new_size[1] - (rotate_mask.shape[1] - image.shape[1]) / 2)
# x, y = int(single_print['location'][i][0] - rotated_new_size[0]), int(single_print['location'][i][1] - rotated_new_size[1])
#
# image_x = print_background.shape[1] # 底图宽
# image_y = print_background.shape[0] # 底图高
# print_x = rotate_image.shape[1] #印花宽
# print_y = rotate_image.shape[0] #印花高
#
# # 有bug
# # if x + print_x > image_x:
# # rotate_image = rotate_image[:, :x + print_x - image_x]
# # rotate_mask = rotate_mask[:, :x + print_x - image_x]
# # #
# # if y + print_y > image_y:
# # rotate_image = rotate_image[:y + print_y - image_y]
# # rotate_mask = rotate_mask[:y + print_y - image_y]
#
# # 不能是并行
# # 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# # 先挪 再判断 最后裁剪
#
# # 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
# if x <= 0: # 如果X轴偏移量小于0说明印花需要被裁剪至合适大小 或当X轴偏移量大于印花宽度时裁剪后的印花宽度为0
# rotate_image = rotate_image[:, abs(x):]
# rotate_mask = rotate_mask[:, abs(x):]
# start_x = x = 0
# else:
# start_x = x
#
# if y <= 0: # 如果X轴偏移量大于0说明印花需要被裁剪至合适大小 或当Y轴偏移量大于印花宽度时裁剪后的印花宽度为0
# rotate_image = rotate_image[abs(y):, :]
# rotate_mask = rotate_mask[abs(y):, :]
# start_y = y = 0
# else:
# start_y = y
#
# # ------------------
# # 如果print-size大于image-size 则需要裁剪print
#
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :image_x - x]
# rotate_mask = rotate_mask[:, :image_x - x]
#
# if y + print_y > image_y:
# rotate_image = rotate_image[:image_y - y, :]
# rotate_mask = rotate_mask[:image_y - y, :]
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
#
# # mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# # print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
# mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
# print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
# gray_image = cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)
# print_background = cv2.bitwise_and(print_background, print_background, mask=gray_image)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask)) img_bg = cv2.bitwise_and(result['pattern_image'], result['pattern_image'], mask=cv2.bitwise_not(print_mask))
@@ -166,7 +87,6 @@ class PrintPainting:
if element_print['element_path_list']: if element_print['element_path_list']:
# 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整 # 2025-9-19 印花调整 印花坐标按照sketch的缩放比调整
sketch_resize_scale = result['resize_scale'] sketch_resize_scale = result['resize_scale']
print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) print_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8) mask_background = np.zeros((result['final_image'].shape[0], result['final_image'].shape[1], 3), dtype=np.uint8)
for i in range(len(element_print['element_path_list'])): for i in range(len(element_print['element_path_list'])):
@@ -207,20 +127,6 @@ class PrintPainting:
print_x = rotate_image.shape[1] print_x = rotate_image.shape[1]
print_y = rotate_image.shape[0] print_y = rotate_image.shape[0]
# 有bug
# if x + print_x > image_x:
# rotate_image = rotate_image[:, :x + print_x - image_x]
# rotate_mask = rotate_mask[:, :x + print_x - image_x]
# #
# if y + print_y > image_y:
# rotate_image = rotate_image[:y + print_y - image_y]
# rotate_mask = rotate_mask[:y + print_y - image_y]
# 不能是并行
# 当前第一轮的if 108以及115是判断有没有过下界和右界。第二轮的是判断左上有没有超出。 如果这个样子的话先裁了右边再左移region就会有问题
# 先挪 再判断 最后裁剪
# 如果print旋转了 或者 print贴边了 则需要判断 判断左界和上界是否小于0
if x <= 0: if x <= 0:
rotate_image = rotate_image[:, -x:] rotate_image = rotate_image[:, -x:]
rotate_mask = rotate_mask[:, -x:] rotate_mask = rotate_mask[:, -x:]
@@ -235,9 +141,6 @@ class PrintPainting:
else: else:
start_y = y start_y = y
# ------------------
# 如果print-size大于image-size 则需要裁剪print
if x + print_x > image_x: if x + print_x > image_x:
rotate_image = rotate_image[:, :image_x - x] rotate_image = rotate_image[:, :image_x - x]
rotate_mask = rotate_mask[:, :image_x - x] rotate_mask = rotate_mask[:, :image_x - x]
@@ -246,11 +149,6 @@ class PrintPainting:
rotate_image = rotate_image[:image_y - y, :] rotate_image = rotate_image[:image_y - y, :]
rotate_mask = rotate_mask[:image_y - y, :] rotate_mask = rotate_mask[:image_y - y, :]
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = cv2.bitwise_xor(mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]], rotate_mask)
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = cv2.add(print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]], rotate_image)
# mask_background[start_y:y + rotate_mask.shape[0], start_x:x + rotate_mask.shape[1]] = rotate_mask
# print_background[start_y:y + rotate_image.shape[0], start_x:x + rotate_image.shape[1]] = rotate_image
mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x) mask_background = self.stack_prin(mask_background, result['pattern_image'], rotate_mask, start_y, y, start_x, x)
print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x) print_background = self.stack_prin(print_background, result['pattern_image'], rotate_image, start_y, y, start_x, x)
@@ -298,12 +196,8 @@ class PrintPainting:
ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY) ret, mask_background = cv2.threshold(mask_background, 124, 255, cv2.THRESH_BINARY)
print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY)) print_mask = cv2.bitwise_and(result['mask'], cv2.cvtColor(mask_background, cv2.COLOR_BGR2GRAY))
img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask) img_fg = cv2.bitwise_or(print_background, print_background, mask=print_mask)
# TODO element 丢失信息
three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)]) three_channel_image = cv2.merge([cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask), cv2.bitwise_not(print_mask)])
img_bg = cv2.bitwise_and(result['final_image'], three_channel_image) img_bg = cv2.bitwise_and(result['final_image'], three_channel_image)
# mask_mo = np.expand_dims(print_mask, axis=2).repeat(3, axis=2)
# gray_mo = np.expand_dims(result['gray'], axis=2).repeat(3, axis=2)
# img_fg = (img_fg * (mask_mo / 255) * (gray_mo / 255)).astype(np.uint8)
result['final_image'] = cv2.add(img_bg, img_fg) result['final_image'] = cv2.add(img_bg, img_fg)
canvas = np.full_like(result['final_image'], 255) canvas = np.full_like(result['final_image'], 255)
temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2) temp_bg = np.expand_dims(cv2.bitwise_not(result['mask']), axis=2).repeat(3, axis=2)
@@ -325,27 +219,21 @@ class PrintPainting:
print_background = img1_bg + img2_fg print_background = img1_bg + img2_fg
return print_background return print_background
def painting_collection(self, painting_dict, print_dict, print_trigger=False, is_single=False): def painting_collection(self, painting_dict, print_dict):
if print_trigger: print_ = self.get_print(print_dict)
print_ = self.get_print(print_dict) painting_dict['location'] = print_['location']
painting_dict['Trigger'] = not is_single dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w'])
painting_dict['location'] = print_['location'] dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5))
single_mask_inv_print = self.get_mask_inv(print_['image']) gap = print_dict.get('gap', [[0, 0]])[0]
dim_max = max(painting_dict['dim_image_h'], painting_dict['dim_image_w']) painting_dict['tile_print'] = tile_image(pattern=print_['image'],
dim_pattern = (int(dim_max * print_['scale'] / 5), int(dim_max * print_['scale'] / 5)) dim=dim_pattern,
if not is_single: gap_x=gap[0],
self.random_seed = random.randint(0, 1000) gap_y=gap[1],
# 如果print 模式为overall 且 有角度的话 组合的print为正方形方便裁剪 canvas_h=painting_dict['dim_image_h'],
if "print_angle_list" in print_dict.keys() and print_dict['print_angle_list'][0] != 0: canvas_w=painting_dict['dim_image_w'],
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) location=painting_dict['location'],
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], dim_max, dim_max, painting_dict['location'], trigger=True) angle=45)
else: painting_dict['mask_inv_print'] = np.zeros(painting_dict['tile_print'].shape[:2], dtype=np.uint8)
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'], trigger=True)
else:
painting_dict['mask_inv_print'] = self.tile_image(single_mask_inv_print, dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['tile_print'] = self.tile_image(print_['image'], dim_pattern, print_['scale'], painting_dict['dim_image_h'], painting_dict['dim_image_w'], painting_dict['location'])
painting_dict['dim_print_h'], painting_dict['dim_print_w'] = dim_pattern
return painting_dict return painting_dict
def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False): def tile_image(self, pattern, dim, scale, dim_image_h, dim_image_w, location, trigger=False):
@@ -374,51 +262,37 @@ class PrintPainting:
mask_inv = cv2.inRange(print_tile, lower, upper) mask_inv = cv2.inRange(print_tile, lower, upper)
return mask_inv return mask_inv
else: else:
# bg_color = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)[0][0]
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# bg_l, bg_a, bg_b = bg_color[0], bg_color[1], bg_color[2]
# bg_L_high, bg_L_low = self.get_low_high_lab(bg_l, L=True)
# bg_a_high, bg_a_low = self.get_low_high_lab(bg_a)
# bg_b_high, bg_b_low = self.get_low_high_lab(bg_b)
# lower = np.array([bg_L_low, bg_a_low, bg_b_low])
# upper = np.array([bg_L_high, bg_a_high, bg_b_high])
# print_tile = cv2.cvtColor(print_, cv2.COLOR_BGR2LAB)
# mask_inv = cv2.cvtColor(print_tile, cv2.COLOR_BGR2GRAY)
# mask_inv = cv2.cvtColor(print_, cv2.COLOR_BGR2GRAY)
mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8) mask_inv = np.zeros(print_.shape[:2], dtype=np.uint8)
return mask_inv return mask_inv
@staticmethod @staticmethod
def printpaint(result, painting_dict, print_=False): def printpaint(result, painting_dict, print_=False):
if print_:
if print_ and painting_dict['Trigger']:
print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print'])) print_mask = cv2.bitwise_and(result['mask'], cv2.bitwise_not(painting_dict['mask_inv_print']))
img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask) img_fg = cv2.bitwise_and(painting_dict['tile_print'], painting_dict['tile_print'], mask=print_mask)
else: else:
print_mask = result['mask'] print_mask = result['mask']
img_fg = result['final_image'] img_fg = result['final_image']
if print_ and not painting_dict['Trigger']: # if print_ and not painting_dict['Trigger']:
index_ = None # index_ = None
try: # try:
index_ = len(painting_dict['location']) # index_ = len(painting_dict['location'])
except: # except:
assert f'there must be parameter of location if choose IfSingle' # assert f'there must be parameter of location if choose IfSingle'
#
for i in range(index_): # for i in range(index_):
start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0]) # start_h, start_w = int(painting_dict['location'][i][1]), int(painting_dict['location'][i][0])
#
length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0]) # length_h = min(start_h + painting_dict['dim_print_h'], img_fg.shape[0])
length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1]) # length_w = min(start_w + painting_dict['dim_print_w'], img_fg.shape[1])
#
change_region = img_fg[start_h: length_h, start_w: length_w, :] # change_region = img_fg[start_h: length_h, start_w: length_w, :]
# problem in change_mask # # problem in change_mask
change_mask = print_mask[start_h: length_h, start_w: length_w] # change_mask = print_mask[start_h: length_h, start_w: length_w]
# get real part into change mask # # get real part into change mask
_, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY) # _, change_mask = cv2.threshold(change_mask, 220, 255, cv2.THRESH_BINARY)
cv2.bitwise_not(painting_dict['mask_inv_print']) # cv2.bitwise_not(painting_dict['mask_inv_print'])
img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region # img_fg[start_h:start_h + painting_dict['dim_print_h'], start_w:start_w + painting_dict['dim_print_w'], :] = change_region
clothes_mask_print = cv2.bitwise_not(print_mask) clothes_mask_print = cv2.bitwise_not(print_mask)
@@ -450,11 +324,6 @@ class PrintPainting:
print_w = print_shape[1] print_w = print_shape[1]
print_h = print_shape[0] print_h = print_shape[0]
random.seed(self.random_seed)
# logging.info(f'overall print location : {location}')
# x_offset = random.randint(0, image.shape[0] - image_size_h)
# y_offset = random.randint(0, image.shape[1] - image_size_w)
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量 # 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可 # 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2 x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
@@ -596,3 +465,96 @@ class PrintPainting:
cropped_img = resized_img[start_y:start_y + target_height, :] cropped_img = resized_img[start_y:start_y + target_height, :]
return cropped_img return cropped_img
def tile_image(pattern, dim, gap_x, gap_y, canvas_h, canvas_w, location, angle=0):
"""
按照指定的 X/Y 间距平铺印花,并支持旋转
:param angle: 旋转角度 (度数, 逆时针)
"""
# 1. 确保输入是 RGBA
if pattern.shape[2] == 3:
pattern = cv2.cvtColor(pattern, cv2.COLOR_BGR2BGRA)
# 2. 缩放与旋转印花
resized_p = cv2.resize(pattern, dim, interpolation=cv2.INTER_AREA)
rotated_p = rotate_image(resized_p, angle)
p_h, p_w = rotated_p.shape[:2]
# 3. 创建透明单元格
cell_h, cell_w = p_h + gap_y, p_w + gap_x
unit_cell = np.zeros((cell_h, cell_w, 4), dtype=np.uint8)
unit_cell[:p_h, :p_w, :] = rotated_p
# 4. 执行平铺
tiles_y = (canvas_h // cell_h) + 2
tiles_x = (canvas_w // cell_w) + 2
full_tiled = np.tile(unit_cell, (tiles_y, tiles_x, 1))
# 5. 裁剪平铺层
offset_x = int(location[0][1] % cell_w)
offset_y = int(location[0][0] % cell_h)
tiled_layer = full_tiled[offset_y: offset_y + canvas_h,
offset_x: offset_x + canvas_w]
# 6. 创建纯白色背景并合成
# 创建一个纯白色的 BGR 画布
white_background = np.full((canvas_h, canvas_w, 3), 255, dtype=np.uint8)
# 分离平铺层的颜色通道和 Alpha 通道
tiled_bgr = tiled_layer[:, :, :3]
alpha_mask = tiled_layer[:, :, 3] / 255.0 # 归一化到 0-1
alpha_mask = cv2.merge([alpha_mask, alpha_mask, alpha_mask]) # 扩展到 3 通道
# 执行 Alpha 混合:结果 = 平铺层 * alpha + 背景 * (1 - alpha)
result = (tiled_bgr * alpha_mask + white_background * (1 - alpha_mask)).astype(np.uint8)
return result
def rotate_image(image, angle):
"""
旋转图片并保持完整内容(自动扩大画布)
"""
if angle == 0:
return image
(h, w) = image.shape[:2]
(cX, cY) = (w // 2, h // 2)
# 获取旋转矩阵
M = cv2.getRotationMatrix2D((cX, cY), angle, 1.0)
# 计算旋转后新边界的 sine 和 cosine
cos = np.abs(M[0, 0])
sin = np.abs(M[0, 1])
# 计算新的画布尺寸
nW = int((h * sin) + (w * cos))
nH = int((h * cos) + (w * sin))
# 调整旋转矩阵以考虑平移
M[0, 2] += (nW / 2) - cX
M[1, 2] += (nH / 2) - cY
# 执行旋转
return cv2.warpAffine(image, M, (nW, nH))
def crop_image(image, image_size_h, image_size_w, location, print_shape):
print_w = print_shape[1]
print_h = print_shape[0]
# 1.拿到偏移量后和resize后的print宽高取余 得到真正偏移量
# 偏移量增加2分之print.w 使坐标位于图中间 如果要位于左上角删除+ print_w // 2 即可
x_offset = print_w - int(location[0][1] % print_w) + print_w // 2
y_offset = print_h - int(location[0][0] % print_h) + print_h // 2
# y_offset = int(location[0][0])
# x_offset = int(location[0][1])
if len(image.shape) == 2:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w]
elif len(image.shape) == 3:
image = image[x_offset: x_offset + image_size_h, y_offset: y_offset + image_size_w, :]
return image

View File

@@ -34,15 +34,15 @@ class Segmentation:
result['mask'] = result['front_mask'] + result['back_mask'] result['mask'] = result['front_mask'] + result['back_mask']
else: else:
# preview 过模型 不缓存 # preview 过模型 不缓存
if "preview_submit" in result.keys() and result['preview_submit'] == "preview": if result.get("design_type", None) == "merge":
# 推理获得seg 结果
seg_result = get_seg_result(result['image']) seg_result = get_seg_result(result['image'])
# submit 过模型 缓存 # 默认design 模式 - 过模型 缓存
elif "preview_submit" in result.keys() and result['preview_submit'] == "submit": # elif result.get("design_type", None) == "submit":
# 推理获得seg 结果 # 推理获得seg 结果
seg_result = get_seg_result(result['image']) # seg_result = get_seg_result(result['image'])
self.save_seg_result(seg_result, result['image_id']) # self.save_seg_result(seg_result, result['image_id'])
# null 正常流程 加载本地缓存 无缓存则过模型
# 默认模式- 加载模型,找不到则过模型推理,推理后保存到本地
else: else:
# 本地查询seg 缓存是否存在 # 本地查询seg 缓存是否存在
_, seg_result = self.load_seg_result(result["image_id"]) _, seg_result = self.load_seg_result(result["image_id"])

View File

@@ -4,6 +4,7 @@ import logging
import cv2 import cv2
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from celery.bin.result import result
from app.service.design_fast.utils.conversion_image import rgb_to_rgba from app.service.design_fast.utils.conversion_image import rgb_to_rgba
from app.service.design_fast.utils.transparent import sketch_to_transparent from app.service.design_fast.utils.transparent import sketch_to_transparent
@@ -19,105 +20,106 @@ class Split(object):
def __call__(self, result): def __call__(self, result):
try: try:
if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'): if result['name'] in ('outwear', 'dress', 'blouse', 'skirt', 'trousers', 'tops', 'bottoms', 'others'):
ori_front_mask = result['front_mask'].copy() if result.get('design_type', None) == 'merge':
ori_back_mask = result['back_mask'].copy() # merge 不需要返回mask (红绿图)
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0: front_mask = result['front_mask']
front_mask = result['front_mask'] back_mask = result['back_mask']
back_mask = result['back_mask']
else:
height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
result_front_image = np.zeros_like(rgba_image)
front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# 预处理用户自选区mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_AREA)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# 创建红色和绿色掩码
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
else: else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"]) height, width = result['front_mask'].shape[:2]
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None) new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
# 前片部分 (红图部分) front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
# height, width = front_mask.shape back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
# mask_image = np.zeros((height, width, 3)) result['merge_image'] = cv2.resize(result['merge_image'], (new_width, new_height), interpolation=cv2.INTER_AREA)
# mask_image[front_mask != 0] = [0, 0, 255]
# 切换为原始图片尺寸------------------------------- rgba_image = rgb_to_rgba(result['merge_image'], front_mask + back_mask)
height, width = ori_front_mask.shape new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
mask_image = np.zeros((height, width, 3)) rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
mask_image[ori_front_mask != 0] = [0, 0, 255] result_front_image = np.zeros_like(rgba_image)
# ----------------------------------------------- front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
# if result["name"] in ('blouse', 'dress', 'outwear', 'tops'): result_back_image = np.zeros_like(rgba_image)
# result_back_image = np.zeros_like(rgba_image) back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
# back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA) result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
# result_back_image[back_mask != 0] = rgba_image[back_mask != 0] result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
# result_back_image_pil = Image.fromarray(cvtColor(result_back_image, COLOR_BGR2RGBA)) result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None) return result
# mask_image[back_mask != 0] = [0, 255, 0] else:
# ori_front_mask = result['front_mask'].copy()
# rbga_mask = rgb_to_rgba(mask_image, front_mask + back_mask) ori_back_mask = result['back_mask'].copy()
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# else:
# rbga_mask = rgb_to_rgba(mask_image, front_mask)
# mask_pil = Image.fromarray(cvtColor(rbga_mask.astype(np.uint8), COLOR_BGR2RGBA))
# image_data = io.BytesIO()
# mask_pil.save(image_data, format='PNG')
# image_data.seek(0)
# image_bytes = image_data.read()
# req = oss_upload_image(oss_client=self.minio_client, bucket=AIDA_CLOTHING, object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
# result['mask_url'] = req.bucket_name + "/" + req.object_name
# result['back_image'] = None
# result["back_image_url"] = None
# # result["back_mask_url"] = None
# # result['back_mask_image'] = None
result_back_image = np.zeros_like(rgba_image) if result['resize_scale'][0] == 1.0 and result['resize_scale'][1] == 1.0:
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA) front_mask = result['front_mask']
result_back_image[back_mask != 0] = rgba_image[back_mask != 0] back_mask = result['back_mask']
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA)) else:
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None) height, width = result['front_mask'].shape[:2]
new_width = int(width * result['resize_scale'][0])
new_height = int(height * result['resize_scale'][1])
# mask_image[back_mask != 0] = [0, 255, 0] front_mask = cv2.resize(result['front_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
mask_image[ori_back_mask != 0] = [0, 255, 0] back_mask = cv2.resize(result['back_mask'], (new_width, new_height), interpolation=cv2.INTER_AREA)
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask) rgba_image = rgb_to_rgba(result['final_image'], front_mask + back_mask)
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA)) new_size = (int(rgba_image.shape[1] * result["scale"]), int(rgba_image.shape[0] * result["scale"]))
image_data = io.BytesIO() rgba_image = cv2.resize(rgba_image, new_size, interpolation=cv2.INTER_AREA)
mask_pil.save(image_data, format='PNG') result_front_image = np.zeros_like(rgba_image)
image_data.seek(0) front_mask = cv2.resize(front_mask, new_size, interpolation=cv2.INTER_AREA)
image_bytes = image_data.read() result_front_image[front_mask != 0] = rgba_image[front_mask != 0]
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes) result_front_image_pil = Image.fromarray(cv2.cvtColor(result_front_image, cv2.COLOR_BGR2RGBA))
result['mask_url'] = req.bucket_name + "/" + req.object_name if 'transparent' in result.keys():
# 用户自选区域transparent
transparent = result['transparent']
if transparent['mask_url'] is not None and transparent['mask_url'] != "":
# 预处理用户自选区mask
seg_mask = oss_get_image(oss_client=self.minio_client, bucket=transparent['mask_url'].split('/')[0], object_name=transparent['mask_url'][transparent['mask_url'].find('/') + 1:], data_type="cv2")
seg_mask = cv2.resize(seg_mask, new_size, interpolation=cv2.INTER_AREA)
# 转换颜色空间为 RGBOpenCV 默认是 BGR
image_rgb = cv2.cvtColor(seg_mask, cv2.COLOR_BGR2RGB)
r, g, b = cv2.split(image_rgb)
blue_mask = b > r
# 创建红色和绿色掩码
transparent_mask = np.array(blue_mask, dtype=np.uint8) * 255
result_front_image_pil = sketch_to_transparent(result_front_image_pil, transparent_mask, transparent["scale"])
else:
result_front_image_pil = sketch_to_transparent(result_front_image_pil, front_mask, transparent["scale"])
result['front_image'], result["front_image_url"], _ = upload_png_mask(self.minio_client, result_front_image_pil, f'{generate_uuid()}', mask=None)
height, width = ori_front_mask.shape
mask_image = np.zeros((height, width, 3))
mask_image[ori_front_mask != 0] = [0, 0, 255]
result_back_image = np.zeros_like(rgba_image)
back_mask = cv2.resize(back_mask, new_size, interpolation=cv2.INTER_AREA)
result_back_image[back_mask != 0] = rgba_image[back_mask != 0]
result_back_image_pil = Image.fromarray(cv2.cvtColor(result_back_image, cv2.COLOR_BGR2RGBA))
result['back_image'], result["back_image_url"], _ = upload_png_mask(self.minio_client, result_back_image_pil, f'{generate_uuid()}', mask=None)
# mask_image[back_mask != 0] = [0, 255, 0]
mask_image[ori_back_mask != 0] = [0, 255, 0]
rbga_mask = rgb_to_rgba(mask_image, ori_front_mask + ori_back_mask)
mask_pil = Image.fromarray(cv2.cvtColor(rbga_mask.astype(np.uint8), cv2.COLOR_BGR2RGBA))
image_data = io.BytesIO()
mask_pil.save(image_data, format='PNG')
image_data.seek(0)
image_bytes = image_data.read()
req = oss_upload_image(oss_client=self.minio_client, bucket="aida-clothing", object_name=f"mask/mask_{generate_uuid()}.png", image_bytes=image_bytes)
result['mask_url'] = req.bucket_name + "/" + req.object_name
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
result_pattern_overall_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_overall'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_overall_image'], result['pattern_overall_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_overall_image_pil, f'{generate_uuid()}')
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
return result
else: else:
ori_front_mask, ori_back_mask = None, None ori_front_mask, ori_back_mask = None, None
# 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print # 创建中间图层(未分割图层) 1.color + overall_print 2.color + overall_print + print
@@ -127,5 +129,6 @@ class Split(object):
result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA)) result_pattern_print_image_pil = Image.fromarray(cv2.cvtColor(rgb_to_rgba(result['no_seg_sketch_print'], ori_front_mask + ori_back_mask), cv2.COLOR_BGR2RGBA))
result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}') result['pattern_print_image'], result['pattern_print_image_url'], _ = upload_png_mask(self.minio_client, result_pattern_print_image_pil, f'{generate_uuid()}')
return result return result
except Exception as e: except Exception as e:
logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}") logging.warning(f"split runtime exception : {e} image_id : {result['image_id']}")

View File

@@ -23,20 +23,23 @@ def organize_clothing(layer):
front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None), front_layer = dict(priority=layer['priority'] if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_front', None),
name=f'{layer["name"].lower()}_front', name=f'{layer["name"].lower()}_front',
image=layer["front_image"], image=layer["front_image"],
merge_image=layer["front_image"],
# mask_image=layer['front_mask_image'], # mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'], image_url=layer['front_image_url'],
mask_url=layer['mask_url'], mask_url=layer.get("mask_url", None),
sacle=layer['scale'], sacle=layer['scale'],
clothes_keypoint=layer['clothes_keypoint'], clothes_keypoint=layer['clothes_keypoint'],
position=start_point, position=start_point,
resize_scale=layer["resize_scale"], resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size), mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'], pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer['pattern_print_image_url'], pattern_print_image_url=layer.get('pattern_print_image_url', None),
pattern_image=layer['pattern_image'], pattern_image=layer.get('pattern_image', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
) )
# 后片数据 # 后片数据
back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None), back_layer = dict(priority=-layer.get("priority", 0) if layer.get("layer_order", False) else PRIORITY_DICT.get(f'{layer["name"].lower()}_back', None),
@@ -44,16 +47,18 @@ def organize_clothing(layer):
image=layer["back_image"], image=layer["back_image"],
# mask_image=layer['back_mask_image'], # mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'], image_url=layer['back_image_url'],
mask_url=layer['mask_url'], mask_url=layer.get('mask_url', None),
sacle=layer['scale'], sacle=layer['scale'],
clothes_keypoint=layer['clothes_keypoint'], clothes_keypoint=layer['clothes_keypoint'],
position=start_point, position=start_point,
resize_scale=layer["resize_scale"], resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size), mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'], pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer['pattern_print_image_url'], pattern_print_image_url=layer.get('pattern_print_image_url', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
transpose=layer.get("transpose", [1, 1]), # 默认为1, 1代表不镜像
rotate=layer.get('rotate', 0),
) )
return front_layer, back_layer return front_layer, back_layer
@@ -76,16 +81,16 @@ def organize_others(layer):
image=layer["front_image"], image=layer["front_image"],
# mask_image=layer['front_mask_image'], # mask_image=layer['front_mask_image'],
image_url=layer['front_image_url'], image_url=layer['front_image_url'],
mask_url=layer['mask_url'], mask_url=layer.get('mask_url', None),
sacle=layer['scale'], sacle=layer['scale'],
clothes_keypoint=(0, 0), clothes_keypoint=(0, 0),
position=start_point, position=start_point,
resize_scale=layer["resize_scale"], resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size), mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'], pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer['pattern_print_image_url'], pattern_print_image_url=layer.get('pattern_print_image_url', None),
pattern_image=layer['pattern_image'], pattern_image=layer.get('pattern_image', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
) )
# 后片数据 # 后片数据
@@ -94,15 +99,15 @@ def organize_others(layer):
image=layer["back_image"], image=layer["back_image"],
# mask_image=layer['back_mask_image'], # mask_image=layer['back_mask_image'],
image_url=layer['back_image_url'], image_url=layer['back_image_url'],
mask_url=layer['mask_url'], mask_url=layer.get('mask_url', None),
sacle=layer['scale'], sacle=layer['scale'],
clothes_keypoint=(0, 0), clothes_keypoint=(0, 0),
position=start_point, position=start_point,
resize_scale=layer["resize_scale"], resize_scale=layer["resize_scale"],
mask=cv2.resize(layer['mask'], layer["front_image"].size), mask=cv2.resize(layer['mask'], layer["front_image"].size),
gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "", gradient_string=layer['gradient_string'] if 'gradient_string' in layer.keys() else "",
pattern_overall_image_url=layer['pattern_overall_image_url'], pattern_overall_image_url=layer.get('pattern_overall_image_url', None),
pattern_print_image_url=layer['pattern_print_image_url'], pattern_print_image_url=layer.get('pattern_print_image_url', None),
# back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else "" # back_perspective_url=layer['back_perspective_url'] if 'back_perspective_url' in layer.keys() else ""
) )
return front_layer, back_layer return front_layer, back_layer

View File

@@ -151,9 +151,11 @@ def synthesis(data, size, basic_info):
if layer['image'] is not None: if layer['image'] is not None:
if layer['name'] != "body": if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0)) test_image = Image.new('RGBA', size, (0, 0, 0, 0))
test_image.paste(layer['image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['image']) paste_img, position = transpose_rotate(layer, layer['image'])
test_image.paste(paste_img, position, paste_img)
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8) mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data) mask_alpha = Image.fromarray(mask_data)
mask_alpha.paste(paste_img.getchannel('A'), position, paste_img.getchannel('A'))
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha) cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00 base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
else: else:
@@ -185,6 +187,111 @@ def synthesis(data, size, basic_info):
logging.warning(f"synthesis runtime exception : {e}") logging.warning(f"synthesis runtime exception : {e}")
def merge(data, size, basic_info):
# out_of_bounds_control: 是否允许服装越界 True 允许 False 不允许 默认情况允许
out_of_bounds_control = basic_info.get('out_of_bounds_control', True)
# 创建底图
base_image = Image.new('RGBA', size, (0, 0, 0, 0))
try:
all_mask_shape = (size[1], size[0])
body_mask = None
for d in data:
if d['name'] == 'body' or d['name'] == 'mannequin':
# 创建一个新的宽高透明图像, 把模特贴上去获取mask
transparent_image = Image.new("RGBA", size, (0, 0, 0, 0))
transparent_image.paste(d['image'], (d['adaptive_position'][1], d['adaptive_position'][0]), d['image']) # 此处可变数组会被paste篡改值所以使用下标获取position
body_mask = np.array(transparent_image.split()[3])
# 根据新的坐标获取新的肩点
left_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_left'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
right_shoulder = [x + y for x, y in zip(basic_info['body_point_test']['shoulder_right'], [d['adaptive_position'][1], d['adaptive_position'][0]])]
body_mask[:min(left_shoulder[1], right_shoulder[1]), left_shoulder[0]:right_shoulder[0]] = 255
_, binary_body_mask = cv2.threshold(body_mask, 127, 255, cv2.THRESH_BINARY)
top_outer_mask = np.array(binary_body_mask)
bottom_outer_mask = np.array(binary_body_mask)
others_outer_mask = np.array(binary_body_mask)
top = True
bottom = True
others = True
i = len(data)
while i:
i -= 1
if top and data[i]['name'] in ["blouse_front", "outwear_front", "dress_front", "tops_front"]:
if out_of_bounds_control:
top = True
else:
top = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
top_outer_mask = background + top_outer_mask
elif bottom and data[i]['name'] in ["trousers_front", "skirt_front", "bottoms_front", "dress_front"]:
# bottom = False
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
bottom_outer_mask = background + bottom_outer_mask
elif others and data[i]['name'] in ['others_front']:
mask_shape = data[i]['mask'].shape
y_offset, x_offset = data[i]['adaptive_position']
# 初始化叠加区域的起始和结束位置
all_y_start, all_y_end, mask_y_start, mask_y_end = positioning(all_mask_shape=all_mask_shape[0], mask_shape=mask_shape[0], offset=y_offset)
all_x_start, all_x_end, mask_x_start, mask_x_end = positioning(all_mask_shape=all_mask_shape[1], mask_shape=mask_shape[1], offset=x_offset)
# 将叠加区域赋值为相应的像素值
_, sketch_mask = cv2.threshold(data[i]['mask'], 127, 255, cv2.THRESH_BINARY)
background = np.zeros_like(top_outer_mask)
background[all_y_start:all_y_end, all_x_start:all_x_end] = sketch_mask[mask_y_start:mask_y_end, mask_x_start:mask_x_end]
others_outer_mask = background + others_outer_mask
pass
elif bottom is False and top is False:
break
all_mask = cv2.bitwise_or(top_outer_mask, bottom_outer_mask)
all_mask = cv2.bitwise_or(all_mask, others_outer_mask)
for layer in data:
if layer['image'] is not None:
if layer['name'] != "body":
test_image = Image.new('RGBA', size, (0, 0, 0, 0))
paste_img, position = transpose_rotate(layer, layer['image'])
test_image.paste(paste_img, position, paste_img)
mask_data = np.where(all_mask > 0, 255, 0).astype(np.uint8)
mask_alpha = Image.fromarray(mask_data)
mask_alpha.paste(paste_img.getchannel('A'), position, paste_img.getchannel('A'))
cropped_image = Image.composite(test_image, Image.new("RGBA", test_image.size, (255, 255, 255, 0)), mask_alpha)
base_image.paste(test_image, (0, 0), cropped_image) # test_image 已经按照坐标贴到最大宽值的图片上 坐着这里坐标为00
else:
base_image.paste(layer['merge_image'], (layer['adaptive_position'][1], layer['adaptive_position'][0]), layer['merge_image'])
result_image = base_image
image_data = io.BytesIO()
result_image.save(image_data, format='PNG')
image_data.seek(0)
# oss upload
image_bytes = image_data.read()
bucket_name = "aida-results"
object_name = f'result_{generate_uuid()}.png'
oss_upload_image(oss_client=minio_client, bucket=bucket_name, object_name=object_name, image_bytes=image_bytes)
return f"{bucket_name}/{object_name}"
except Exception as e:
logging.warning(f"synthesis runtime exception : {e}")
def synthesis_single(front_image, back_image): def synthesis_single(front_image, back_image):
result_image = None result_image = None
if front_image: if front_image:
@@ -232,3 +339,35 @@ def update_base_size_priority(layers):
for info in layers: for info in layers:
info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x) info['adaptive_position'] = (info['position'][0], info['position'][1] - min_x)
return layers, (new_width, new_height) return layers, (new_width, new_height)
def transpose_rotate(layer, image):
# transpose[0]是左右 transpose[1]是上下
transpose = layer.get('transpose', [1, 1]) # 默认为1, 1代表不镜像
rotate = layer.get('rotate', 0)
paste_x, paste_y = layer['adaptive_position'][1], layer['adaptive_position'][0]
# transpose左右是1 上下是-1
if transpose[0] != 1:
# 左右
image = image.transpose(0)
if transpose[1] != 1:
# 上下
image = image.transpose(1)
if rotate:
image = image.rotate(-rotate, expand=True)
# 4. 计算粘贴位置以保持视觉中心一致
# 原本 (15, 36) 是 288*288 的左上角,我们计算其中心点
target_center_x = 15 + 288 // 2
target_center_y = 36 + 288 // 2
# 获取旋转后图像的新尺寸
new_w, new_h = image.size
# 计算新的左上角坐标,使得旋转后的图像中心依然在原定的中心位置
paste_x = target_center_x - new_w // 2
paste_y = target_center_y - new_h // 2
return image, (paste_x, paste_y)

View File

@@ -14,7 +14,7 @@ REDIS_KEY_USER_PREF_PREFIX = "user_pref"
RECOMMENDATION_CONFIG = { RECOMMENDATION_CONFIG = {
# 时间衰减半衰期(用于计算时间衰减权重) # 时间衰减半衰期(用于计算时间衰减权重)
# 值越小,最近的行为权重越大 # 值越小,最近的行为权重越大
"K_half": 20, "K_half": 10,
# 探索与利用的比例 (0.0-1.0) # 探索与利用的比例 (0.0-1.0)
# - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机 # - 值越大,使用探索分支(随机推荐)的几率越大,结果更随机
@@ -25,7 +25,7 @@ RECOMMENDATION_CONFIG = {
# 向量检索返回的候选数量 # 向量检索返回的候选数量
# 值越大,候选池越大,但计算成本也越高 # 值越大,候选池越大,但计算成本也越高
# 建议范围: 100-1000 # 建议范围: 100-1000
"topk": 1000, "topk": 200,
# Style 加分系数(同 style 的候选进行加分) # Style 加分系数(同 style 的候选进行加分)
# 值越大,匹配 style 的候选被选中的概率越大 # 值越大,匹配 style 的候选被选中的概率越大
@@ -53,7 +53,7 @@ RECOMMENDATION_CONFIG = {
} }
# 数据库表名 # 数据库表名
TABLE_USER_PREFERENCE_LOG = "user_preference_log_test" TABLE_USER_PREFERENCE_LOG = "user_preference"
TABLE_SYS_FILE = "t_sys_file" TABLE_SYS_FILE = "t_sys_file"
# MySQL 连接配置(用于推荐系统) # MySQL 连接配置(用于推荐系统)

View File

@@ -1,6 +1,6 @@
""" """
增量监听模块 增量监听模块
实时监听 user_preference_log_test 表的新增记录,更新用户偏好向量 实时监听 user_preference 表的新增记录,更新用户偏好向量
""" """
import logging import logging
import math import math
@@ -48,7 +48,7 @@ class IncrementalListener:
if self.last_process_time is None: if self.last_process_time is None:
# 第一次运行查询最近30分钟的数据 # 第一次运行查询最近30分钟的数据
cursor.execute(f""" cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG} FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE) WHERE data_time > DATE_SUB(NOW(), INTERVAL 30 MINUTE)
ORDER BY data_time ORDER BY data_time
@@ -56,7 +56,7 @@ class IncrementalListener:
else: else:
# 基于上次处理时间查询 # 基于上次处理时间查询
cursor.execute(f""" cursor.execute(f"""
SELECT id, account_id, path, category, style, data_time, is_system_sketch, sys_file_id SELECT id, account_id, path, category, style, data_time
FROM {TABLE_USER_PREFERENCE_LOG} FROM {TABLE_USER_PREFERENCE_LOG}
WHERE data_time > %s WHERE data_time > %s
ORDER BY data_time ORDER BY data_time
@@ -258,7 +258,7 @@ class IncrementalListener:
} }
else: else:
# 用户图 # 用户图
# 从 user_preference_log_test 获取 category如果有 # 从 user_preference 获取 category如果有
cursor.execute(f""" cursor.execute(f"""
SELECT category SELECT category
FROM {TABLE_USER_PREFERENCE_LOG} FROM {TABLE_USER_PREFERENCE_LOG}
@@ -308,6 +308,10 @@ class IncrementalListener:
def start_background_listener(scheduler: BackgroundScheduler): def start_background_listener(scheduler: BackgroundScheduler):
"""将增量监听任务注册到后台调度器""" """将增量监听任务注册到后台调度器"""
# 降低 apscheduler 的日志级别,避免大量刷屏
logging.getLogger('apscheduler.executors.default').setLevel(logging.WARNING)
logging.getLogger('apscheduler.scheduler').setLevel(logging.WARNING)
listener = IncrementalListener() listener = IncrementalListener()
scheduler.add_job( scheduler.add_job(
listener.process_once, listener.process_once,

View File

@@ -23,7 +23,7 @@ def get_milvus_client() -> MilvusClient:
_milvus_client = MilvusClient( _milvus_client = MilvusClient(
uri=settings.MILVUS_URL, uri=settings.MILVUS_URL,
token=settings.MILVUS_TOKEN, token=settings.MILVUS_TOKEN,
db_name=settings.MILVUS_DB, db_name="",
) )
logger.info("Milvus 客户端连接成功") logger.info("Milvus 客户端连接成功")
except Exception as e: except Exception as e:
@@ -203,39 +203,74 @@ def search_similar_vectors(
query_vector: np.ndarray, query_vector: np.ndarray,
category: str, category: str,
topk: int = 500, topk: int = 500,
style: Optional[str] = None style: Optional[str] = None,
style_boost_ratio: float = 0.2
) -> List[Dict]: ) -> List[Dict]:
""" """
向量相似度检索 向量相似度检索
Args: Args:
query_vector: 查询向量2048维 query_vector: 查询向量2048维
category: 类别过滤 category: 类别过滤
topk: 返回数量 topk: 返回数量
style: 风格过滤(可选) style: 风格过滤(可选)- 当提供时会给对应style的结果加分
style_boost_ratio: 风格加分比例默认0.1即10%
Returns: Returns:
检索结果列表,每个元素包含 path, score, style, category 等字段 检索结果列表,每个元素包含 path, score, style, category 等字段
""" """
client = get_milvus_client() client = get_milvus_client()
try: try:
# 构建过滤表达式 # 如果没有指定style使用原始逻辑
# 使用 filter 参数而不是 expr根据 pymilvus MilvusClient API if not style:
filter_expr = f"category == '{category}' && deprecated == 0" filter_expr = f"category == '{category}' && deprecated == 0"
if style: results = client.search(
filter_expr += f" && style == '{style}'" collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr,
output_fields=["path", "style", "category", "sys_file_id"]
)
else:
# 有style参数时使用两阶段搜索策略
# 搜索 # 第一阶段搜索匹配style的向量使用boosted query vector
results = client.search( filter_expr_style = f"category == '{category}' && deprecated == 0 && style == '{style}'"
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, boosted_query = query_vector * (1 + style_boost_ratio)
data=[query_vector.tolist()], results_style = client.search(
anns_field="feature_vector", collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
search_params={"metric_type": "IP", "params": {"nprobe": 10}}, data=[boosted_query.tolist()],
limit=topk, anns_field="feature_vector",
filter=filter_expr, search_params={"metric_type": "IP", "params": {"nprobe": 10}},
output_fields=["path", "style", "category", "sys_file_id"] limit=topk,
) filter=filter_expr_style,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 第二阶段搜索其他style的向量
filter_expr_others = f"category == '{category}' && deprecated == 0 && style != '{style}'"
results_others = client.search(
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
data=[query_vector.tolist()],
anns_field="feature_vector",
search_params={"metric_type": "IP", "params": {"nprobe": 10}},
limit=topk,
filter=filter_expr_others,
output_fields=["path", "style", "category", "sys_file_id"]
)
# 合并结果
results = []
if results_style and len(results_style) > 0:
results.extend(results_style[0])
if results_others and len(results_others) > 0:
results.extend(results_others[0])
# 转换为单个结果列表格式
results = [results] if results else []
# 格式化结果 # 格式化结果
formatted_results = [] formatted_results = []
@@ -249,7 +284,10 @@ def search_similar_vectors(
"sys_file_id": hit.get("entity", {}).get("sys_file_id") "sys_file_id": hit.get("entity", {}).get("sys_file_id")
}) })
return formatted_results # 按分数排序并返回topk
formatted_results.sort(key=lambda x: x["score"], reverse=True)
return formatted_results[:topk]
except Exception as e: except Exception as e:
logger.error(f"向量检索失败: {e}", exc_info=True) logger.error(f"向量检索失败: {e}", exc_info=True)
return [] return []
@@ -280,7 +318,7 @@ def query_random_candidates(category: str, style: Optional[str] = None, limit: i
collection_name=MILVUS_COLLECTION_SKETCH_VECTORS, collection_name=MILVUS_COLLECTION_SKETCH_VECTORS,
filter=filter_expr, filter=filter_expr,
output_fields=["path", "style", "category"], output_fields=["path", "style", "category"],
limit=10000 # 先查询大量数据,然后随机选择 limit=10000
) )
# 随机选择 # 随机选择

View File

@@ -6,6 +6,7 @@ import logging
import math import math
import pymysql import pymysql
import numpy as np import numpy as np
from datetime import datetime
from typing import List, Dict, Tuple, Optional from typing import List, Dict, Tuple, Optional
from collections import defaultdict from collections import defaultdict
@@ -25,7 +26,7 @@ logger = logging.getLogger(__name__)
def optimize_database_table(): def optimize_database_table():
""" """
优化 user_preference_log_test 表结构 优化 user_preference 表结构
添加冗余字段和索引 添加冗余字段和索引
""" """
conn = None conn = None
@@ -317,8 +318,8 @@ def precompute_system_sketch_vectors(batch_size: int = 1000, retry_times: int =
def compute_user_preference_vector( def compute_user_preference_vector(
account_id: int, account_id: int,
category: str, category: str,
conn: Optional[pymysql.connections.Connection] = None conn: Optional[pymysql.connections.Connection] = None,
# max_date: Optional[datetime] = None max_date: Optional[datetime] = None
) -> Optional[np.ndarray]: ) -> Optional[np.ndarray]:
""" """
计算用户偏好向量 计算用户偏好向量
@@ -419,8 +420,8 @@ def compute_user_preference_vector(
p_i = 1 + math.log(1 + like_count) p_i = 1 + math.log(1 + like_count)
# 综合权重 # 综合权重
# w_i = d_k * p_i w_i = d_k * p_i
w_i = p_i # w_i = p_i
vectors.append(feature_vector) vectors.append(feature_vector)
weights.append(w_i) weights.append(w_i)
@@ -518,16 +519,16 @@ def run_precompute():
logger.info("=" * 50) logger.info("=" * 50)
# 1. 优化数据库表结构 # 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...") # logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table() # optimize_database_table()
# # 2. 创建 Milvus 集合 # # 2. 创建 Milvus 集合
# logger.info("\n[2/5] 创建 Milvus 集合...") # logger.info("\n[2/5] 创建 Milvus 集合...")
# create_collection() # create_collection()
# 3. 历史数据迁移 # 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...") # logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data() # migrate_historical_data()
# # 4. 系统图向量预计算 # # 4. 系统图向量预计算
# logger.info("\n[4/5] 系统图向量预计算...") # logger.info("\n[4/5] 系统图向量预计算...")
@@ -543,13 +544,13 @@ def run_precompute():
if __name__ == "__main__": if __name__ == "__main__":
# 1. 优化数据库表结构 # # 1. 优化数据库表结构
logger.info("\n[1/5] 优化数据库表结构...") # logger.info("\n[1/5] 优化数据库表结构...")
optimize_database_table() # optimize_database_table()
#
# 3. 历史数据迁移 # # 3. 历史数据迁移
logger.info("\n[3/5] 历史数据迁移...") # logger.info("\n[3/5] 历史数据迁移...")
migrate_historical_data() # migrate_historical_data()
# 5. 初始用户偏好向量生成 # 5. 初始用户偏好向量生成
logger.info("\n[5/5] 初始用户偏好向量生成...") logger.info("\n[5/5] 初始用户偏好向量生成...")

View File

@@ -10,7 +10,7 @@ from torchvision import models, transforms
from PIL import Image from PIL import Image
from minio import Minio from minio import Minio
from app.core.config import MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE from app.core.config import settings
from app.service.recommendation_system.config import RECOMMENDATION_CONFIG from app.service.recommendation_system.config import RECOMMENDATION_CONFIG
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -48,10 +48,10 @@ def get_minio_client():
global _minio_client global _minio_client
if _minio_client is None: if _minio_client is None:
_minio_client = Minio( _minio_client = Minio(
MINIO_URL, settings.MINIO_URL,
access_key=MINIO_ACCESS, access_key=settings.MINIO_ACCESS,
secret_key=MINIO_SECRET, secret_key=settings.MINIO_SECRET,
secure=MINIO_SECURE secure=settings.MINIO_SECURE
) )
return _minio_client return _minio_client

View File

@@ -81,7 +81,7 @@ if __name__ == '__main__':
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg" # url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png" # url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
# url = "aida-users/89/single_logo/123-89.png" # url = "aida-users/89/single_logo/123-89.png"
url = "lanecarford/lc_stylist_agent_outfit_items/141/ee25ec85-d504-4b42-9a18-db6682fe9e3b-6.jpg" url = "aida-results/result_a7adcbd8-ef8d-11f0-8c92-0966ede33ab5.png"
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png" # url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
read_type = "2" read_type = "2"

View File

@@ -91,6 +91,21 @@ class Redis(object):
r = cls._get_r() r = cls._get_r()
r.expire(name, expire_in_seconds) r.expire(name, expire_in_seconds)
@classmethod
def scan_keys(cls, pattern="*"):
"""
扫描匹配模式的key
"""
r = cls._get_r()
keys = []
cursor = 0
while True:
cursor, partial_keys = r.scan(cursor, match=pattern, count=1000)
keys.extend(partial_keys)
if cursor == 0:
break
return [key.decode('utf-8') if isinstance(key, bytes) else key for key in keys]
if __name__ == '__main__': if __name__ == '__main__':
redis_client = Redis() redis_client = Redis()

View File

@@ -10,4 +10,16 @@ services:
- /etc/localtime:/etc/localtime:ro - /etc/localtime:/etc/localtime:ro
- ./seg_cache:/seg_cache - ./seg_cache:/seg_cache
ports: ports:
- "10200:80" - "10200:80"
depends_on:
- redis
redis:
image: redis
container_name: aida_redis
restart: always
ports:
- "6400:6379"
volumes:
- ./redis/data:/data
- ./redis/conf/redis.conf:/etc/redis/redis.conf
command: redis-server /etc/redis/redis.conf --appendonly yes

View File

@@ -1,10 +1,15 @@
import os
from app.core.config import settings from app.core.config import settings
LOGGER_CONFIG_DICT = { LOGGER_CONFIG_DICT = {
'version': 1, 'version': 1,
'disable_existing_loggers': False, 'disable_existing_loggers': False,
'formatters': { 'formatters': {
'simple': {'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s'} 'simple': {
'format': '%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s',
'datefmt': '%Y-%m-%d %H:%M:%S' # 补充日期格式,日志更易读
}
}, },
'handlers': { 'handlers': {
'console': { 'console': {
@@ -17,7 +22,7 @@ LOGGER_CONFIG_DICT = {
'class': 'logging.handlers.RotatingFileHandler', 'class': 'logging.handlers.RotatingFileHandler',
'level': 'INFO', 'level': 'INFO',
'formatter': 'simple', 'formatter': 'simple',
'filename': f'{settings.LOGS_PATH}info.log', 'filename': os.path.join(settings.LOGS_PATH, 'info.log'),
'maxBytes': 10485760, 'maxBytes': 10485760,
'backupCount': 50, 'backupCount': 50,
'encoding': 'utf8', 'encoding': 'utf8',
@@ -26,7 +31,7 @@ LOGGER_CONFIG_DICT = {
'class': 'logging.handlers.RotatingFileHandler', 'class': 'logging.handlers.RotatingFileHandler',
'level': 'ERROR', 'level': 'ERROR',
'formatter': 'simple', 'formatter': 'simple',
'filename': f'{settings.LOGS_PATH}error.log', 'filename': os.path.join(settings.LOGS_PATH, 'error.log'),
'maxBytes': 10485760, 'maxBytes': 10485760,
'backupCount': 20, 'backupCount': 20,
'encoding': 'utf8', 'encoding': 'utf8',
@@ -35,7 +40,7 @@ LOGGER_CONFIG_DICT = {
'class': 'logging.handlers.RotatingFileHandler', 'class': 'logging.handlers.RotatingFileHandler',
'level': 'DEBUG', 'level': 'DEBUG',
'formatter': 'simple', 'formatter': 'simple',
'filename': f'{settings.LOGS_PATH}debug.log', 'filename': os.path.join(settings.LOGS_PATH, 'debug.log'),
'maxBytes': 10485760, 'maxBytes': 10485760,
'backupCount': 50, 'backupCount': 50,
'encoding': 'utf8', 'encoding': 'utf8',
@@ -45,7 +50,7 @@ LOGGER_CONFIG_DICT = {
'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'} 'my_module': {'level': 'INFO', 'handlers': ['console'], 'propagate': 'no'}
}, },
'root': { 'root': {
'level': 'INFO', 'level': 'DEBUG',
'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'], 'handlers': ['error_file_handler', 'info_file_handler', 'debug_file_handler', 'console'],
}, },
} }