feat(新功能):

fix(修复bug):
docs(文档变更):
refactor(重构):
test(增加测试): Agent generate test
This commit is contained in:
zhouchengrong
2025-03-13 15:14:19 +08:00
parent 2e717f0145
commit 00b8e9fb02
2 changed files with 19 additions and 16 deletions

View File

@@ -1,6 +1,8 @@
import io
import logging import logging
from fastapi import APIRouter, HTTPException from fastapi import APIRouter, HTTPException
from starlette.responses import StreamingResponse
from app.schemas.response_template import ResponseModel from app.schemas.response_template import ResponseModel
from app.service.generate_image.agent_generate import GenerateImage from app.service.generate_image.agent_generate import GenerateImage
@@ -11,10 +13,7 @@ logger = logging.getLogger()
@router.get("/agent_generate_image") @router.get("/agent_generate_image")
def generate_image(prompt: str): def generate_image(prompt: str):
try:
server = GenerateImage() server = GenerateImage()
data = server.get_result(prompt) byte_stream = server.get_result(prompt)
except Exception as e: # 返回流式响应
logger.warning(f"generate_image Run Exception @@@@@@:{e}") return StreamingResponse(byte_stream, media_type="image/png")
raise HTTPException(status_code=404, detail=str(e))
return data

View File

@@ -7,6 +7,7 @@
@Date 2023/7/26 12:01:05 @Date 2023/7/26 12:01:05
@detail @detail
""" """
import io
import logging import logging
from datetime import timedelta from datetime import timedelta
@@ -52,15 +53,18 @@ class GenerateImage:
image = result.as_numpy("generated_image") image = result.as_numpy("generated_image")
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR) image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
_, img_byte_array = cv2.imencode('.jpg', image_result) _, img_byte_array = cv2.imencode('.jpg', image_result)
object_name = f'test.jpg' byte_stream = io.BytesIO(img_byte_array)
req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array) byte_stream.seek(0)
url = self.minio_client.get_presigned_url(
"GET", # object_name = f'test.jpg'
"test", # req = oss_upload_image(bucket='test', object_name=object_name, image_bytes=img_byte_array)
object_name, # url = self.minio_client.get_presigned_url(
expires=timedelta(hours=2), # "GET",
) # "test",
return url # object_name,
# expires=timedelta(hours=2),
# )
return byte_stream
if __name__ == '__main__': if __name__ == '__main__':