Files
AiDA_Model_Litserve/app/server/generate_image/flux2_klein/server.py
2026-01-26 11:40:11 +08:00

69 lines
2.5 KiB
Python

import io
import os
import torch
import litserve as ls
from diffusers import Flux2KleinPipeline
from minio import Minio
from app.config.config import settings
from app.server.utils.minio_client import oss_get_image, oss_upload_image
# Module-level MinIO client shared by all requests for image download/upload.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)
class Flux2KleinServer(ls.LitAPI):
    """LitServe API wrapping the FLUX.2-klein image-to-image pipeline.

    Per request: downloads the input image from MinIO, runs the diffusion
    pipeline, uploads the generated PNG back to MinIO, and returns the
    bucket-qualified object path.
    """

    def setup(self, device):
        """Load the FLUX.2-klein pipeline and move it onto *device*."""
        dtype = torch.bfloat16
        self.device = device
        model_path = os.path.join(settings.FLUX2_KLEIN_MODEL_PATH, "FLUX.2-klein-4B")
        self.model = Flux2KleinPipeline.from_pretrained(
            model_path, torch_dtype=dtype, is_distilled=False
        )
        # Move the whole pipeline to the target device. NOTE: this does NOT
        # offload anything to CPU; if VRAM savings are needed, use
        # pipeline.enable_model_cpu_offload() instead.
        self.model.to(device)

    def decode_request(self, request):
        # The raw request dict is passed through unchanged to predict().
        return request

    def predict(self, request_data):
        """Generate an image and upload it to MinIO.

        Expected keys in *request_data*: image_path, prompt, height, width,
        infer_step, tasks_id (all optional, with defaults below).

        Returns:
            str: "aida-users/<user_id>/product_image/<tasks_id>.png"
        """
        image_path = request_data.get("image_path", "")
        prompt = request_data.get("prompt", "")
        height = request_data.get("height", 768)
        width = request_data.get("width", 512)
        infer_step = request_data.get("infer_step", 4)
        tasks_id = request_data.get("tasks_id", "test")
        # user_id is the suffix after the last '-' in the task id
        # (the whole id when no '-' is present, since rfind returns -1).
        user_id = tasks_id[tasks_id.rfind('-') + 1:]
        input_image = oss_get_image(oss_client=minio_client, path=image_path, data_type='pil')
        with torch.no_grad():
            images = self.model(
                image=input_image,
                prompt=prompt,
                height=height,
                width=width,
                guidance_scale=1.0,
                num_inference_steps=infer_step,
                # generator=torch.Generator(device='cuda').manual_seed(3)
            )[0]
        # Serialize the first generated image to PNG bytes for upload.
        image = images[0]
        image_data = io.BytesIO()
        image.save(image_data, format='PNG')
        image_bytes = image_data.getvalue()
        object_name = f'{user_id}/product_image/{tasks_id}.png'
        # Upload result; the helper's return value is not used here.
        oss_upload_image(oss_client=minio_client, bucket="aida-users", object_name=object_name, image_bytes=image_bytes)
        return f"aida-users/{object_name}"

    def encode_response(self, image_url):
        """Wrap the uploaded object's path in the JSON response body."""
        return {"image_url": image_url}
# Entry point: expose the API over HTTP via LitServe on port 8011.
if __name__ == "__main__":
    # LitServer picks the accelerator/device automatically.
    api = Flux2KleinServer()
    ls.LitServer(api).run(port=8011)