feat: brand dna logo生成替换flux2klein ; fix:
This commit is contained in:
@@ -1,19 +1,10 @@
|
|||||||
import logging
|
import uuid
|
||||||
|
import httpx
|
||||||
import cv2
|
|
||||||
import numpy as np
|
|
||||||
import tritonclient.grpc as grpcclient
|
|
||||||
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
|
from langchain_classic.output_parsers import ResponseSchema, StructuredOutputParser
|
||||||
from langchain_community.chat_models import ChatTongyi
|
from langchain_community.chat_models import ChatTongyi
|
||||||
from langchain_core.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from minio import Minio
|
from minio import Minio
|
||||||
from tritonclient.utils import np_to_triton_dtype
|
|
||||||
|
|
||||||
from app.core.config import GI_MODEL_URL, GI_MODEL_NAME
|
|
||||||
from app.schemas.brand_dna import GenerateBrandModel
|
from app.schemas.brand_dna import GenerateBrandModel
|
||||||
from app.service.utils.generate_uuid import generate_uuid
|
|
||||||
from app.service.utils.new_oss_client import oss_upload_image
|
|
||||||
|
|
||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
|
|
||||||
|
|
||||||
@@ -26,14 +17,9 @@ class GenerateBrandInfo:
|
|||||||
# user info init
|
# user info init
|
||||||
self.user_id = request_data.user_id
|
self.user_id = request_data.user_id
|
||||||
self.category = "brand_logo"
|
self.category = "brand_logo"
|
||||||
# generate logo init
|
|
||||||
self.grpc_client = grpcclient.InferenceServerClient(url=GI_MODEL_URL)
|
|
||||||
self.image = np.random.randint(0, 256, (1024, 1024, 3), dtype=np.uint8)
|
|
||||||
self.batch_size = 1
|
|
||||||
self.mode = 'txt2img'
|
|
||||||
|
|
||||||
# llm generate brand info init
|
# llm generate brand info init
|
||||||
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key="sk-7658298c6b99443c98184a5e634fe6ab")
|
self.model = ChatTongyi(model="qwen2.5-14b-instruct", api_key=settings.QWEN_API_KEY)
|
||||||
|
|
||||||
self.response_schemas = [
|
self.response_schemas = [
|
||||||
ResponseSchema(name="brand_name", description="Brand name."),
|
ResponseSchema(name="brand_name", description="Brand name."),
|
||||||
@@ -63,38 +49,20 @@ class GenerateBrandInfo:
|
|||||||
self.generate_logo_prompt = brand_data['brand_logo_prompt']
|
self.generate_logo_prompt = brand_data['brand_logo_prompt']
|
||||||
|
|
||||||
def generate_brand_logo(self):
|
def generate_brand_logo(self):
|
||||||
prompts = [self.generate_logo_prompt] * self.batch_size
|
request_item = {
|
||||||
modes = [self.mode] * self.batch_size
|
"bucket_name": "aida-users",
|
||||||
images = [self.image.astype(np.float16)] * self.batch_size
|
"object_name": f'{self.user_id}/{self.category}/{uuid.uuid4().hex}.png',
|
||||||
|
"prompt": self.generate_logo_prompt,
|
||||||
text_obj = np.array(prompts, dtype="object").reshape((-1, 1))
|
"height": 1024,
|
||||||
mode_obj = np.array(modes, dtype="object").reshape((-1, 1))
|
"width": 1024
|
||||||
image_obj = np.array(images, dtype=np.float16).reshape((-1, 1024, 1024, 3))
|
}
|
||||||
|
with httpx.Client(timeout=120) as client:
|
||||||
input_text = grpcclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))
|
resp = client.post(
|
||||||
input_image = grpcclient.InferInput("input_image", image_obj.shape, np_to_triton_dtype(image_obj.dtype))
|
f"http://{settings.FLUX2_GEN_IMG_MODEL_URL}/predict",
|
||||||
input_mode = grpcclient.InferInput("mode", mode_obj.shape, np_to_triton_dtype(mode_obj.dtype))
|
json=request_item,
|
||||||
|
)
|
||||||
input_text.set_data_from_numpy(text_obj)
|
result = resp.json()
|
||||||
input_image.set_data_from_numpy(image_obj)
|
self.result_data['brand_logo'] = result.get("output_path", "")
|
||||||
input_mode.set_data_from_numpy(mode_obj)
|
|
||||||
|
|
||||||
inputs = [input_text, input_image, input_mode]
|
|
||||||
result = self.grpc_client.infer(model_name=GI_MODEL_NAME, inputs=inputs)
|
|
||||||
image = result.as_numpy("generated_image")
|
|
||||||
image_result = cv2.cvtColor(np.squeeze(image.astype(np.uint8)), cv2.COLOR_RGB2BGR)
|
|
||||||
logo_url = self.upload_logo_image(image_result, generate_uuid())
|
|
||||||
self.result_data['brand_logo'] = logo_url
|
|
||||||
|
|
||||||
def upload_logo_image(self, image, object_name):
|
|
||||||
try:
|
|
||||||
_, img_byte_array = cv2.imencode('.jpg', image)
|
|
||||||
object_name = f'{self.user_id}/{self.category}/{object_name}.jpg'
|
|
||||||
oss_upload_image(oss_client=self.minio_client, bucket="aida-users", object_name=object_name, image_bytes=img_byte_array)
|
|
||||||
image_url = f"aida-users/{object_name}"
|
|
||||||
return image_url
|
|
||||||
except Exception as e:
|
|
||||||
logging.warning(f"upload_png_mask runtime exception : {e}")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Reference in New Issue
Block a user