feat: brand dna logo生成替换flux2klein ; fix:
This commit is contained in:
@@ -1,19 +1,10 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import tritonclient.grpc as grpcclient
|
||||
import uuid
|
||||
import httpx
|
||||
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
|
||||
from langchain_community.chat_models import ChatTongyi
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from minio import Minio
|
||||
from tritonclient.utils import np_to_triton_dtype
|
||||
|
||||
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
|
||||
from app.schemas.brand_dna import GenerateBrandModel
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
from app.service.utils.new_oss_client import oss_upload_image
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
|
||||
@@ -26,14 +17,9 @@ class GenerateBrandInfo:
|
||||
# user info init
|
||||
self.user_id = request_data.user_id
|
||||
self.category = "brand_logo"
|
||||
# generate logo init
|
||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
||||
self.batch_size = 1
|
||||
self.mode = 'txt2img'
|
||||
|
||||
# llm generate brand info init
|
||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=settings.QWEN_API_KEY)
|
||||
|
||||
self.response_schemas = [
|
||||
ResponseSchema(name="brand_name", description="Brand name."),
|
||||
@@ -63,38 +49,20 @@ class GenerateBrandInfo:
|
||||
self.generate_logo_prompt = brand_data['brand_logo_prompt']
|
||||
|
||||
def generate_brand_logo(self):
|
||||
prompts = [self.generate_logo_prompt] * self.batch_size
|
||||
modes = [self.mode] * self.batch_size
|
||||
images = [self.image.astype(np.float16)] * self.batch_size
|
||||
|
||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
||||
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
||||
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
||||
|
||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
||||
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
||||
|
||||
input_text.set_data_from_numpy(text_obj)
|
||||
input_image.set_data_from_numpy(image_obj)
|
||||
input_mode.set_data_from_numpy(mode_obj)
|
||||
|
||||
inputs = [input_text, input_image, input_mode]
|
||||
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
|
||||
image = result.as_numpy("generated_image")
|
||||
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
||||
logo_url = self.upload_logo_image(image_result, generate_uuid())
|
||||
self.result_data['brand_logo'] = logo_url
|
||||
|
||||
def upload_logo_image(self, image, object_name):
|
||||
try:
|
||||
_, img_byte_array = cv2.imencode('.jpg', image)
|
||||
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
|
||||
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
||||
image_url = f"aida-users/{object_name}"
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
||||
request_item = {
|
||||
"bucket_name": "aida-users",
|
||||
"object_name": f'{self.user_id}/{self.category}/{uuid.uuid4().hex}.png',
|
||||
"prompt": self.generate_logo_prompt,
|
||||
"height": 1024,
|
||||
"width": 1024
|
||||
}
|
||||
with httpx.Client(timeout=120) as client:
|
||||
resp = client.post(
|
||||
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
|
||||
json=request_item,
|
||||
)
|
||||
result = resp.json()
|
||||
self.result_data['brand_logo'] = result.get("output_path", "")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user