From 00b8e9fb02b6db39e60186e975f52f74f25ccb06 Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Thu, 13 Mar 2025 15:14:19 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=88=E6=96=B0=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=89:=20fix=EF=BC=88=E4=BF=AE=E5=A4=8Dbug=EF=BC=89:=20docs?= =?UTF-8?q?=EF=BC=88=E6=96=87=E6=A1=A3=E5=8F=98=E6=9B=B4=EF=BC=89:=20refac?= =?UTF-8?q?tor=EF=BC=88=E9=87=8D=E6=9E=84=EF=BC=89:=20test(=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E6=B5=8B=E8=AF=95):=20Agent=20generate=20test?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api/api_agent_generate_image.py | 13 ++++++------ app/service/generate_image/agent_generate.py | 22 ++++++++++++-------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/app/api/api_agent_generate_image.py b/app/api/api_agent_generate_image.py index 9aeda19..e4001ff 100644 --- a/app/api/api_agent_generate_image.py +++ b/app/api/api_agent_generate_image.py @@ -1,6 +1,8 @@ +import io import logging from fastapi import APIRouter, HTTPException +from starlette.responses import StreamingResponse from app.schemas.response_template import ResponseModel from app.service.generate_image.agent_generate import GenerateImage @@ -11,10 +13,7 @@ logger = logging.getLogger() @router.get("/agent_generate_image") def generate_image(prompt: str): - try: - server = GenerateImage() - data = server.get_result(prompt) - except Exception as e: - logger.warning(f"generate_image Run Exception @@@@@@:{e}") - raise HTTPException(status_code=404, detail=str(e)) - return data + server = GenerateImage() + byte_stream = server.get_result(prompt) + # 返回流式响应 + return StreamingResponse(byte_stream, media_type="image/png") diff --git a/app/service/generate_image/agent_generate.py b/app/service/generate_image/agent_generate.py index 24623dc..58ac869 100644 --- a/app/service/generate_image/agent_generate.py +++ b/app/service/generate_image/agent_generate.py @@ -7,6 +7,7 @@ @Date :2023/7/26 12:01:05 @detail : """ +import io import logging from datetime import timedelta @@ -52,15 +53,18 @@ class GenerateImage: image = result.as_numpy("generated_image") image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) _, img_byte_array = cv2.imencode('.jpg', image_result) - object_name = f'test.jpg' - req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array) - url = self.minio_client.get_presigned_url( - "GET", - "test", - object_name, - expires=timedelta(hours=2), - ) - return url + byte_stream = io.BytesIO(img_byte_array) + byte_stream.seek(0) + + # object_name = f'test.jpg' + # req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array) + # url = self.minio_client.get_presigned_url( + # "GET", + # "test", + # object_name, + # expires=timedelta(hours=2), + # ) + return byte_stream if __name__ == '__main__':