Merge remote-tracking branch 'origin/develop' into research
This commit is contained in:
File diff suppressed because it is too large
Load Diff
35
app/service/sketch2garment/callback.py
Normal file
35
app/service/sketch2garment/callback.py
Normal file
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
logger = logging.getLogger("app")
|
||||
|
||||
|
||||
async def notify_callback(callback_url: str, task_id: str, status: str, result: dict, ):
|
||||
"""
|
||||
调用客户端提供的回调接口
|
||||
"""
|
||||
try:
|
||||
payload = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"result": result
|
||||
}
|
||||
logger.info(payload)
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
str(callback_url),
|
||||
json=payload,
|
||||
headers={"Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
if 200 <= resp.status_code < 300:
|
||||
logger.info(f"回调成功 | task_id: {task_id} | status: {status} | url: {callback_url}")
|
||||
return True
|
||||
else:
|
||||
logger.warning(f"回调返回非2xx状态码 | task_id: {task_id} | status: {resp.status_code} | url: {callback_url}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"回调失败 | task_id: {task_id} | url: {callback_url} | error: {e}", exc_info=True)
|
||||
return False
|
||||
46
app/service/sketch2garment/celery_app.py
Normal file
46
app/service/sketch2garment/celery_app.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from celery import Celery
|
||||
from kombu import Queue, Exchange
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
celery_app = Celery(
|
||||
"sketch_to_garment",
|
||||
broker=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/2",
|
||||
backend=f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/{settings.REDIS_DB}",
|
||||
include=["app.service.sketch2garment.tasks"]
|
||||
)
|
||||
print(f"redis://{settings.REDIS_HOST}:{settings.REDIS_PORT}/3")
|
||||
print(f"celery_app: {celery_app}")
|
||||
|
||||
celery_app.conf.update(
|
||||
task_serializer="json",
|
||||
accept_content=["json"],
|
||||
result_serializer="json",
|
||||
timezone="Asia/Hong_Kong",
|
||||
enable_utc=True,
|
||||
task_track_started=True,
|
||||
task_time_limit=300, # 单个任务最长 5 分钟
|
||||
task_soft_time_limit=280,
|
||||
# 定义队列
|
||||
task_queues=(
|
||||
Queue("sketch_to_garment_queue",
|
||||
exchange=Exchange("sketch_to_garment_exchange", type="direct"),
|
||||
durable=True),
|
||||
|
||||
),
|
||||
|
||||
task_routes={
|
||||
'app.service.sketch2garment.tasks.sketch_to_garment':
|
||||
{
|
||||
'queue': 'sketch_to_garment_queue',
|
||||
'exchange': 'sketch_to_garment_exchange', # ← 修改这里
|
||||
},
|
||||
},
|
||||
task_default_queue="sketch_to_garment_queue",
|
||||
|
||||
worker_concurrency=1,
|
||||
worker_prefetch_multiplier=1,
|
||||
worker_max_tasks_per_child=1,
|
||||
task_acks_late=True,
|
||||
task_reject_on_worker_lost=True,
|
||||
)
|
||||
44
app/service/sketch2garment/server.py
Normal file
44
app/service/sketch2garment/server.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
|
||||
from app.service.sketch2garment.tasks import sketch_to_garment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def submit_sketch_to_garment_task(model: str = "single", task_id: str = "", callback_url: str = "", bucket_name: str = "test", user_id: str = "123", input_image_path: str = ""):
|
||||
"""提交 img_to_3D 任务(带队列长度限制)"""
|
||||
queue_name = "img_to_3d_queue"
|
||||
max_queue_length = 10
|
||||
|
||||
try:
|
||||
# current_length = get_queue_length(queue_name)
|
||||
|
||||
# if current_length >= max_queue_length:
|
||||
# return {
|
||||
# "state": "queue_full",
|
||||
# "message": "当前 3D 生成请求较多,请稍后重试。",
|
||||
# "queue_length": current_length,
|
||||
# "max_length": max_queue_length
|
||||
# }
|
||||
|
||||
# 提交任务
|
||||
task = sketch_to_garment.apply_async(
|
||||
args=(task_id, callback_url, bucket_name, input_image_path, user_id, model),
|
||||
task_id=task_id,
|
||||
queue="sketch_to_garment_queue")
|
||||
|
||||
# logger.info(f"img_to_3d_task 已提交 | task_id: {task_id} | 当前队列长度: {current_length}")
|
||||
|
||||
return {
|
||||
"state": "success",
|
||||
"task_id": task_id,
|
||||
"message": "任务已成功提交,正在后台处理...",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"提交 img_to_3d_task 失败: {e}", exc_info=True)
|
||||
return {
|
||||
"state": "fail",
|
||||
"message": "提交失败,请稍后重试。",
|
||||
"error": str(e)
|
||||
}
|
||||
57
app/service/sketch2garment/tasks.py
Normal file
57
app/service/sketch2garment/tasks.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
from app.core.config import settings
|
||||
from app.service.sketch2garment.callback import notify_callback
|
||||
import httpx
|
||||
|
||||
from app.service.sketch2garment.celery_app import celery_app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@celery_app.task(bind=True, queue="sketch_to_garment_queue", max_retries=3, name='app.service.sketch2garment.tasks.sketch_to_garment')
|
||||
def sketch_to_garment(self, task_id: str, callback_url: str, bucket_name: str, input_image_path: str, user_id: str, category: str = None):
|
||||
payload = {
|
||||
"bucket_name": bucket_name,
|
||||
"category": category or settings.DEFAULT_CATEGORY,
|
||||
"input_image_path": input_image_path,
|
||||
"user_id": user_id
|
||||
}
|
||||
logger.info(f"payload: {payload}")
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=300.0) as client: # 注意这里用 AsyncClient 配合 Celery
|
||||
# 如果你的 LitServe 是同步 endpoint,也可以用 httpx.Client()
|
||||
response = client.post(settings.SKETCH_TO_GARMENT_URL, json=payload)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
result_json = {
|
||||
"pattern": result[1],
|
||||
"texture": result[2],
|
||||
"glb": result[3],
|
||||
"texture_fabric": result[4]
|
||||
}
|
||||
asyncio.run(
|
||||
notify_callback(callback_url=callback_url, task_id=task_id, result=result_json, status="success")
|
||||
)
|
||||
else:
|
||||
asyncio.run(
|
||||
notify_callback(
|
||||
callback_url=callback_url,
|
||||
task_id=task_id,
|
||||
result={
|
||||
"status": "fail",
|
||||
"task_id": task_id,
|
||||
"message": "fail",
|
||||
"error": "fail"
|
||||
},
|
||||
status="fail")
|
||||
)
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "failed",
|
||||
"task_id": task_id,
|
||||
"input": payload,
|
||||
"error": str(e)
|
||||
}
|
||||
@@ -11,7 +11,12 @@ from minio import Minio
|
||||
from app.core.config import settings
|
||||
from app.service.utils.decorator import RunTime
|
||||
|
||||
minio_client = Minio(settings.MINIO_URL, access_key=settings.MINIO_ACCESS, secret_key=settings.MINIO_SECRET, secure=settings.MINIO_SECURE)
|
||||
minio_client = Minio(
|
||||
settings.MINIO_URL,
|
||||
access_key=settings.MINIO_ACCESS,
|
||||
secret_key=settings.MINIO_SECRET,
|
||||
secure=settings.MINIO_SECURE,
|
||||
)
|
||||
|
||||
|
||||
# 自定义 Retry 类
|
||||
@@ -30,7 +35,7 @@ http_client = urllib3.PoolManager(
|
||||
num_pools=10, # 设置连接池大小
|
||||
maxsize=10,
|
||||
timeout=timeout,
|
||||
cert_reqs='CERT_REQUIRED', # 需要证书验证
|
||||
cert_reqs="CERT_REQUIRED", # 需要证书验证
|
||||
retries=CustomRetry(
|
||||
total=5,
|
||||
backoff_factor=0.2,
|
||||
@@ -51,7 +56,7 @@ def oss_get_image(oss_client, bucket, object_name, data_type):
|
||||
image_array = np.frombuffer(image_bytes, np.uint8) # 转成8位无符号整型
|
||||
image_object = cv2.imdecode(image_array, cv2.IMREAD_UNCHANGED)
|
||||
if image_object.dtype == np.uint16:
|
||||
image_object = (image_object / 256).astype('uint8')
|
||||
image_object = (image_object / 256).astype("uint8")
|
||||
else:
|
||||
data_bytes = BytesIO(image_data.read())
|
||||
image_object = Image.open(data_bytes)
|
||||
@@ -63,13 +68,19 @@ def oss_get_image(oss_client, bucket, object_name, data_type):
|
||||
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
|
||||
req = None
|
||||
try:
|
||||
req = oss_client.put_object(bucket_name=bucket, object_name=object_name, data=io.BytesIO(image_bytes), length=len(image_bytes), content_type='image/png')
|
||||
req = oss_client.put_object(
|
||||
bucket_name=bucket,
|
||||
object_name=object_name,
|
||||
data=io.BytesIO(image_bytes),
|
||||
length=len(image_bytes),
|
||||
content_type="image/png",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f" | 上传图片出现异常 ######: {e}")
|
||||
return req
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
# url = "aida-results/result_0002186a-e631-11ee-86a6-b48351119060.png"
|
||||
# url = "aida-collection-element/11523/Moodboard/f60af0d2-94c2-48f9-90ff-74b8e8a481b5.jpg"
|
||||
# url = "aida-sys-image/images/female/outwear/0628000054.jpg"
|
||||
@@ -81,16 +92,26 @@ if __name__ == '__main__':
|
||||
# url = "aida-users/89/sketchboard/female/Dress/e6724ab7-8d3f-4677-abe0-c3e42ab7af85.jpeg"
|
||||
# url = "aida-users/87/print/956614a2-7e75-4fbe-9ed0-c1831e37a2c9-4-87.png"
|
||||
# url = "aida-users/89/single_logo/123-89.png"
|
||||
url = "aida-results/result_a7adcbd8-ef8d-11f0-8c92-0966ede33ab5.png"
|
||||
url = "aida-collection-element/26293/Sketchboard/b503d482-3334-46e7-9dee-44e380fb4294.png"
|
||||
|
||||
# url = "aida-collection-element/12148/Sketchboard/95ea577b-305b-4a62-b30a-39c0dd3ddb3f.png"
|
||||
read_type = "2"
|
||||
if read_type == "cv2":
|
||||
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||
img = oss_get_image(
|
||||
oss_client=minio_client,
|
||||
bucket=url.split("/")[0],
|
||||
object_name=url[url.find("/") + 1 :],
|
||||
data_type=read_type,
|
||||
)
|
||||
cv2.imshow("", img)
|
||||
cv2.waitKey(0)
|
||||
else:
|
||||
img = oss_get_image(oss_client=minio_client, bucket=url.split('/')[0], object_name=url[url.find('/') + 1:], data_type=read_type)
|
||||
img = oss_get_image(
|
||||
oss_client=minio_client,
|
||||
bucket=url.split("/")[0],
|
||||
object_name=url[url.find("/") + 1 :],
|
||||
data_type=read_type,
|
||||
)
|
||||
draw = ImageDraw.Draw(img)
|
||||
# 获取图片尺寸
|
||||
width, height = img.size
|
||||
@@ -103,7 +124,7 @@ if __name__ == '__main__':
|
||||
draw.line(
|
||||
[(center_x, 0), (center_x, height)], # 从顶部到底部的垂直线
|
||||
fill=(255, 0, 0), # 红色 (R, G, B)
|
||||
width=2 # 线宽
|
||||
width=2, # 线宽
|
||||
)
|
||||
|
||||
img.show()
|
||||
|
||||
Reference in New Issue
Block a user