"""LitServe endpoint that runs FLUX.2-klein image generation backed by MinIO storage."""

import io

import torch
import litserve as ls
from diffusers import Flux2KleinPipeline
from minio import Minio

from app.config.config import settings
from app.server.utils.minio_client import oss_get_image, oss_upload_image

# Bucket that receives generated images. It is also the first segment of the
# path returned to the caller, so both uses share this single definition.
_OUTPUT_BUCKET = "aida-users"

# Shared object-storage client, configured from the application settings.
minio_client = Minio(
    settings.MINIO_URL,
    access_key=settings.MINIO_ACCESS,
    secret_key=settings.MINIO_SECRET,
    secure=settings.MINIO_SECURE,
)


class Flux2KleinServer(ls.LitAPI):
    """LitAPI that edits an input image with FLUX.2-klein and stores the result.

    A request names a source image in object storage plus a prompt; the
    generated image is uploaded as a PNG and its storage path is returned.
    """

    def setup(self, device) -> None:
        """Load the FLUX.2-klein pipeline in bfloat16 and move it to ``device``.

        Args:
            device: Device handed over by LitServe; it is stored on the
                instance and used as the target of ``model.to``.
        """
        dtype = torch.bfloat16
        self.device = device
        self.model = Flux2KleinPipeline.from_pretrained(
            "black-forest-labs/FLUX.2-klein-4B",
            torch_dtype=dtype,
            is_distilled=False,
        )
        # Place the whole pipeline on the serving device. Note that this does
        # NOT offload anything to the CPU.
        self.model.to(device)

    def decode_request(self, request):
        """Return the incoming request payload unchanged."""
        return request

    def predict(self, request_data) -> str:
        """Generate one image for the request and upload it as a PNG.

        Args:
            request_data: Mapping read with ``.get``. Recognised keys and their
                defaults: ``image_path`` (""), ``prompt`` (""), ``height``
                (768), ``width`` (512), ``infer_step`` (4), ``tasks_id``
                ("test").

        Returns:
            The storage path of the uploaded image, formatted as
            ``aida-users/<user_id>/product_image/<tasks_id>.png``.
        """
        image_path = request_data.get("image_path", "")
        prompt = request_data.get("prompt", "")
        height = request_data.get("height", 768)
        width = request_data.get("width", 512)
        infer_step = request_data.get("infer_step", 4)
        tasks_id = request_data.get("tasks_id", "test")

        # The user id is whatever follows the last '-' in the task id; a task
        # id without any '-' is used as the user id in full.
        user_id = tasks_id.rpartition("-")[2]

        # presumably returns a PIL image given data_type='pil' -- confirm
        # against the helper's implementation.
        input_image = oss_get_image(
            oss_client=minio_client,
            path=image_path,
            data_type="pil",
        )

        with torch.no_grad():
            # Sampling is unseeded. For reproducible output a seeded generator
            # could be supplied, e.g.
            # generator=torch.Generator(device="cuda").manual_seed(3).
            images = self.model(
                image=input_image,
                prompt=prompt,
                height=height,
                width=width,
                guidance_scale=1.0,
                num_inference_steps=infer_step,
            )[0]

        # Only the first generated image is kept.
        image = images[0]

        # Encode the image to PNG entirely in memory.
        png_buffer = io.BytesIO()
        image.save(png_buffer, format="PNG")
        image_bytes = png_buffer.getvalue()

        object_name = f"{user_id}/product_image/{tasks_id}.png"
        # NOTE(review): the helper's return value is not inspected, so a failed
        # upload is only noticed if the helper raises -- confirm its contract.
        oss_upload_image(
            oss_client=minio_client,
            bucket=_OUTPUT_BUCKET,
            object_name=object_name,
            image_bytes=image_bytes,
        )

        return f"{_OUTPUT_BUCKET}/{object_name}"

    def encode_response(self, image_url) -> dict:
        """Wrap the storage path in the response payload."""
        return {"image_url": image_url}


# Starting the server
if __name__ == "__main__":
    # No accelerator/device argument is passed here, so LitServer's own
    # default device selection applies.
    api = Flux2KleinServer()
    server = ls.LitServer(api)
    server.run(port=8011)