diff --git a/app/api/api_generate_image.py b/app/api/api_generate_image.py index cce1300..82ad571 100644 --- a/app/api/api_generate_image.py +++ b/app/api/api_generate_image.py @@ -151,12 +151,14 @@ def generate_relight_image(request_item: GenerateRelightImageModel, background_t - **tasks_id**: 任务id 用于取消生成任务和获取生成结果 - **prompt**: 想要生成图片的描述词 - **image_url**: 被生成图片的S3或minio url地址 + - **direction**: 光源方向 Right Light Left Light Top Light Bottom Light 示例参数: { "tasks_id": "123-89", "prompt": "beautiful woman, detailed face, sunshine, outdoor, warm atmosphere", - "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png" + "image_url": "aida-results/result_0000b606-1902-11ef-9424-0242ac180002.png", + "direction": "Right Light" } """ try: diff --git a/app/schemas/generate_image.py b/app/schemas/generate_image.py index 29f34d6..2a16442 100644 --- a/app/schemas/generate_image.py +++ b/app/schemas/generate_image.py @@ -27,3 +27,4 @@ class GenerateRelightImageModel(BaseModel): tasks_id: str prompt: str image_url: str + direction: str diff --git a/app/service/generate_image/service_generate_relight_image.py b/app/service/generate_image/service_generate_relight_image.py index ca32c73..6f51435 100644 --- a/app/service/generate_image/service_generate_relight_image.py +++ b/app/service/generate_image/service_generate_relight_image.py @@ -7,16 +7,15 @@ @Date :2023/7/26 12:01:05 @detail : """ -import io import json import logging import time + import cv2 +import numpy as np import redis import tritonclient.grpc as grpcclient -import numpy as np -from PIL import Image, ImageOps -from minio import Minio +from PIL import Image from tritonclient.utils import np_to_triton_dtype from app.core.config import * @@ -40,7 +39,7 @@ class GenerateRelightImage: self.prompt = request_data.prompt self.seed = "1" self.negative_prompt = 'lowres, bad anatomy, bad hands, cropped, worst quality' - self.direction = "Right Light" + self.direction = request_data.direction self.image_url = request_data.image_url self.image = oss_get_image(bucket=self.image_url.split('/')[0], object_name=self.image_url[self.image_url.find('/') + 1:], data_type="cv2") self.tasks_id = request_data.tasks_id