1
Some checks failed
CI / lint (push) Has been cancelled

This commit is contained in:
zcr
2026-03-18 10:12:06 +08:00
parent b56ac61450
commit d89a176b63
19 changed files with 2685 additions and 1 deletions

1
.gitignore vendored
View File

@@ -230,3 +230,4 @@ $RECYCLE.BIN/
# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
output/
*.png

10
.idea/.gitignore generated vendored Normal file
View File

@@ -0,0 +1,10 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 已忽略包含查询文件的默认文件夹
/queries/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/

15
.idea/flux2.iml generated Normal file
View File

@@ -0,0 +1,15 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<excludeFolder url="file://$MODULE_DIR$/.venv" />
</content>
<orderEntry type="jdk" jdkName="uv (flux2)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@@ -0,0 +1,39 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="GrazieInspection" enabled="false" level="GRAMMAR_ERROR" enabled_by_default="false" />
<inspection_tool class="HttpUrlsUsage" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="OutdatedRequirementInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="PyBroadExceptionInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="PyDeprecationInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="PyPep8NamingInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false">
<option name="ignoredErrors">
<list>
<option value="N801" />
<option value="N806" />
</list>
</option>
</inspection_tool>
<inspection_tool class="PyShadowingBuiltinsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<option name="ignoredNames">
<list>
<option value="input" />
<option value="type" />
<option value="object" />
<option value="id" />
</list>
</option>
</inspection_tool>
<inspection_tool class="PyShadowingNamesInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
<inspection_tool class="PyTypeCheckerInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
<option name="processCode" value="true" />
<option name="processLiterals" value="true" />
<option name="processComments" value="false" />
</inspection_tool>
<inspection_tool class="UnsatisfiedRequirementInspection" enabled="false" level="WARNING" enabled_by_default="false" />
<inspection_tool class="VulnerableLibrariesLocal" enabled="false" level="WARNING" enabled_by_default="false" />
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml generated Normal file
View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="uv (flux2)" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="uv (flux2)" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml generated Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/flux2.iml" filepath="$PROJECT_DIR$/.idea/flux2.iml" />
</modules>
</component>
</project>

6
.idea/vcs.xml generated Normal file
View File

@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
</component>
</project>

10
client.py Normal file
View File

@@ -0,0 +1,10 @@
# This file is auto-generated by LitServe.
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
import time
import requests

# Time a single prediction round-trip and print the elapsed seconds.
_started = time.time()
response = requests.post("http://127.0.0.1:8011/predict", json={"prompt": "变成一只猫"})
# print(f"Status: {response.status_code}\nResponse:\n {response.text}")
elapsed = time.time() - _started
print(elapsed)

110
litserve/app.py Normal file
View File

@@ -0,0 +1,110 @@
import torch
import litserve as ls
from PIL import Image
import io
import base64
from diffusers import Flux2KleinPipeline
# 保持原有的辅助函数
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a "W:H" aspect string into pixel (width, height).

    The longer side is pinned to *base_long_edge*; both sides are floored to
    a multiple of 8 and clamped to at least 64 px.
    """
    w_part, h_part = aspect_ratio.split(":")
    ratio_w, ratio_h = float(w_part), float(h_part)
    short_edge = int(round(base_long_edge * (min(ratio_w, ratio_h) / max(ratio_w, ratio_h))))
    width, height = (base_long_edge, short_edge) if ratio_w >= ratio_h else (short_edge, base_long_edge)
    width = max(64, (width // 8) * 8)
    height = max(64, (height // 8) * 8)
    return width, height
class FluxKleinAPI(ls.LitAPI):
    """LitServe API serving FLUX.2-klein image editing over HTTP.

    Request JSON: {"image": <base64 PNG/JPEG>, "prompt": str, "steps": int,
    "guidance": float, "seed": int, "aspect_ratio": "W:H",
    "base_long_edge": int}.  Response JSON: {"image": <base64 PNG>}.
    """

    def setup(self, device):
        """Load the FLUX.2-klein pipeline once per worker and move it to *device*."""
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16
        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype
        )
        self.pipe.to(device)
        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        """Decode the JSON body into a payload dict.

        Raises ValueError when the "image" field is missing or empty.
        """
        image_data = request.get("image", None)
        if not image_data:
            raise ValueError("Request must contain 'image' field with base64 data.")
        # Strip an optional "data:image/png;base64," style data-URL prefix.
        if "," in image_data:
            image_data = image_data.split(",")[1]
        img_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024)
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the diffusion pipeline and return the generated PIL image."""
        img = payload.get("image", None)
        prompt = payload.get("prompt", "")
        W, H = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])
        # Bug fix: seed the generator — the decoded "seed" was previously
        # ignored, making results non-reproducible.
        gen = torch.Generator(device=self.device).manual_seed(payload["seed"])
        call_kwargs = dict(
            prompt=prompt,
            height=H,
            width=W,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        )
        if img:
            # Image-to-image path: simplified resize; see resize_pad_to_target
            # in litserve/cli.py for the aspect-preserving padded variant.
            call_kwargs["image"] = img.convert("RGB").resize((W, H), Image.Resampling.LANCZOS)
        # Bug fix: the original if/else branches were identical and the
        # no-image branch still passed image=None to the pipeline; the
        # text-to-image path now simply omits the argument.
        return self.pipe(**call_kwargs).images[0]

    def encode_response(self, output):
        """Encode the resulting PIL image as base64 PNG."""
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}
if __name__ == "__main__":
    # Start the LitServe server: one CUDA device, HTTP on port 8451.
    api = FluxKleinAPI()
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8451)

389
litserve/cli.py Normal file
View File

@@ -0,0 +1,389 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
FLUX.2 [klein] image-to-image editing with CLI args (single or batch).
Prereqs:
pip install -U "diffusers>=0.30.0" transformers accelerate safetensors pillow
# and you have CUDA set up for torch.
Examples:
# Single image
python flux2_klein_cli.py \
--image test_sketch/sofa_2.png \
--user_input "caramel brown leather, matte black legs" \
--out outputs/sofa_2_to_real.png
# Batch: one prompt applied to all images
python flux2_klein_cli.py \
--image_dir test_sketch \
--user_input "pure white background, photorealistic materials" \
--out_dir outputs/batch
# Batch: a prompt per image using a prompt file (mapping or list)
python flux2_klein_cli.py \
--image_dir test_sketch \
--prompt_file prompt.txt \
--out_dir outputs/batch
# Optional knobs
python flux2_klein_cli.py \
--image test_sketch/a.png \
--user_input "..." \
--out outputs/a.png \
--aspect_ratio 4:3 \
--base_long_edge 1024 \
--steps 28 \
--guidance 4.0 \
--seed 0 \
--cpu_offload
"""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional
import torch
from PIL import Image, ImageOps
from diffusers import Flux2KleinPipeline
from diffusers.utils import load_image
DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
DEFAULT_DEVICE = "cuda"
DEFAULT_TORCH_DTYPE = "bfloat16"
DEFAULT_SYSTEM_PROMPT = (
"Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment, no interior, "
"no floor texture, and no visible floor/wall lines. Studio product photography lighting with soft "
"diffused light; ultra-clean industrial design presentation. Render a single furniture object only, "
"isolated and perfectly centered on a borderless canvas. Strictly preserve the geometry, silhouette, "
"proportions, and key structural details of the reference sketch (do not change shape, do not add/remove parts). "
"Realistic scale, high material fidelity, sharp focus, 4K quality. No people, no props, no text, no logos, no watermark."
)
# -----------------------------
# Prompt parsing (optional)
# -----------------------------
def load_prompts(prompt_txt: Path) -> Tuple[Dict[str, str], List[str]]:
    """
    Supports:
      A) "sofa_2.png: caramel brown leather ..."   (filename: user_input)
      B) "caramel brown leather ..."               (user_input only; matched by sorted image filenames)
    """
    if not prompt_txt.exists():
        raise FileNotFoundError(f"prompt file not found: {prompt_txt}")
    # Keep non-empty, non-comment lines, stripped of surrounding whitespace.
    lines: List[str] = [
        stripped
        for stripped in (raw.strip() for raw in prompt_txt.read_text(encoding="utf-8").splitlines())
        if stripped and not stripped.startswith("#")
    ]

    def _is_mapping_line(line: str) -> bool:
        # A mapping line looks like "<image-filename>: <prompt>".
        if ":" not in line:
            return False
        head = line.split(":", 1)[0].strip().lower()
        return head.endswith((".png", ".jpg", ".jpeg", ".webp"))

    mapping: Dict[str, str] = {}
    list_only: List[str] = []
    if any(_is_mapping_line(ln) for ln in lines):
        for ln in lines:
            if ":" not in ln:
                continue
            key, val = (part.strip() for part in ln.split(":", 1))
            if key and val:
                mapping[key] = val
    else:
        list_only = lines
    return mapping, list_only
def list_images(image_dir: Path) -> List[Path]:
    """Return image files (png/jpg/jpeg/webp) in *image_dir*, sorted by name."""
    allowed = {".png", ".jpg", ".jpeg", ".webp"}
    found = [
        entry
        for entry in sorted(image_dir.iterdir())
        if entry.is_file() and entry.suffix.lower() in allowed
    ]
    if not found:
        raise RuntimeError(f"No images found in: {image_dir}")
    return found
def build_final_prompt(system_prompt: str, user_input: str) -> str:
    """Append the user's requirements to the system prompt.

    Falls back to the bare system prompt when *user_input* is empty/None.
    """
    cleaned = (user_input or "").strip()
    return f"{system_prompt} User requirements: {cleaned}" if cleaned else system_prompt
# -----------------------------
# Aspect ratio utils (resize+pad)
# -----------------------------
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Turn a "W:H" ratio string into (width, height) in pixels.

    The long edge equals *base_long_edge*; both edges are floored to a
    multiple of 8 and never drop below 64.
    """
    num, den = aspect_ratio.split(":")
    rw, rh = float(num), float(den)
    if rw >= rh:
        width, height = base_long_edge, int(round(base_long_edge * (rh / rw)))
    else:
        width, height = int(round(base_long_edge * (rw / rh))), base_long_edge
    return max(64, (width // 8) * 8), max(64, (height // 8) * 8)
def resize_pad_to_target(img: Image.Image, width: int, height: int, fill=(255, 255, 255)) -> Image.Image:
    """Fit *img* inside (width, height) preserving aspect ratio, then pad with *fill* to the exact size."""
    fitted = ImageOps.contain(img.convert("RGB"), (width, height), method=Image.Resampling.LANCZOS)
    extra_w = width - fitted.size[0]
    extra_h = height - fitted.size[1]
    # Center the fitted image: any odd pixel goes to the right/bottom edge.
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(fitted, border=border, fill=fill)
def resize_to_long_edge(img: Image.Image, base_long_edge: int) -> Image.Image:
    """Scale *img* so its long edge equals *base_long_edge*.

    Resized sides snap down to multiples of 8 (min 64); if the long edge
    already matches, the RGB-converted image is returned unscaled.
    """
    rgb = img.convert("RGB")
    w, h = rgb.size
    if max(w, h) == base_long_edge:
        return rgb
    if w >= h:
        new_w, new_h = base_long_edge, int(round(base_long_edge * h / w))
    else:
        new_w, new_h = int(round(base_long_edge * w / h)), base_long_edge
    new_w = max(64, (new_w // 8) * 8)
    new_h = max(64, (new_h // 8) * 8)
    return rgb.resize((new_w, new_h), Image.Resampling.LANCZOS)
# -----------------------------
# Run one
# -----------------------------
@torch.inference_mode()
def run_one(
    pipe: Flux2KleinPipeline,
    device: str,
    image_path: Path,
    final_prompt: str,
    out_path: Path,
    aspect_ratio: Optional[str],
    base_long_edge: int,
    steps: int,
    guidance: float,
    seed: int,
) -> None:
    """Run a single img2img generation and save the result to *out_path*.

    The reference image is either resize+padded to the requested aspect
    ratio (when *aspect_ratio* is given) or scaled so its long edge equals
    *base_long_edge*.  Output directories are created as needed.
    """
    ref = load_image(str(image_path))  # PIL.Image
    if aspect_ratio:
        W, H = aspect_to_wh(aspect_ratio, base_long_edge)
        ref = resize_pad_to_target(ref, W, H)
    else:
        ref = resize_to_long_edge(ref, base_long_edge)
        W, H = ref.size
    # Seeded generator makes runs reproducible for a fixed seed.
    gen = torch.Generator(device=device).manual_seed(seed)
    result = pipe(
        image=ref,
        prompt=final_prompt,
        height=H,
        width=W,
        guidance_scale=guidance,
        num_inference_steps=steps,
        generator=gen,
    ).images[0]
    out_path.parent.mkdir(parents=True, exist_ok=True)
    result.save(out_path)
    print(f"[OK] saved: {out_path}")
def parse_args() -> argparse.Namespace:
    """Parse and validate CLI arguments.

    Exactly one of --image (single mode, requires --out) or --image_dir
    (batch mode, requires --out_dir plus a prompt source) must be given;
    violations exit via SystemExit.
    """
    ap = argparse.ArgumentParser(description="FLUX.2 [klein] img2img editing with CLI args.")
    io = ap.add_argument_group("I/O")
    io.add_argument("--image", type=str, default="", help="Single input image path.")
    io.add_argument("--image_dir", type=str, default="", help="Directory of images for batch mode.")
    io.add_argument("--out", type=str, default="", help="Single output image path (for --image).")
    io.add_argument("--out_dir", type=str, default="", help="Output directory (for --image_dir).")
    prompt = ap.add_argument_group("Prompt")
    prompt.add_argument("--user_input", type=str, default="", help="User prompt (applied to single image or all images).")
    prompt.add_argument(
        "--prompt_file",
        type=str,
        default="",
        help='Optional prompt file for batch mode. Supports "filename: prompt" mapping or one-prompt-per-line list.',
    )
    prompt.add_argument("--system_prompt", type=str, default=DEFAULT_SYSTEM_PROMPT, help="System prompt string.")
    model = ap.add_argument_group("Model/Runtime")
    model.add_argument("--repo_id", type=str, default=DEFAULT_REPO_ID, help=f"HF repo id (default: {DEFAULT_REPO_ID}).")
    model.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='Device, e.g. "cuda" or "cpu".')
    model.add_argument(
        "--torch_dtype",
        type=str,
        default=DEFAULT_TORCH_DTYPE,
        choices=["float16", "bfloat16", "float32"],
        help="torch dtype for loading weights.",
    )
    model.add_argument("--cpu_offload", action="store_true", help="Enable model CPU offload (saves VRAM).")
    gen = ap.add_argument_group("Generation")
    gen.add_argument("--aspect_ratio", type=str, default="4:3", help='e.g. "16:9", "1:1". Set "" to disable.')
    gen.add_argument("--base_long_edge", type=int, default=1024, help="Target long edge for output/resized reference.")
    gen.add_argument("--steps", type=int, default=28, help="num_inference_steps.")
    gen.add_argument("--guidance", type=float, default=4.0, help="guidance_scale.")
    gen.add_argument("--seed", type=int, default=0, help="Random seed.")
    args = ap.parse_args()
    # Validate mode: specifying both or neither of --image/--image_dir is an error.
    if bool(args.image) == bool(args.image_dir):
        raise SystemExit("Provide exactly one of --image or --image_dir.")
    if args.image:
        if not args.out:
            raise SystemExit("In single-image mode, --out is required.")
        if not args.user_input and not args.system_prompt:
            raise SystemExit("Provide --user_input and/or --system_prompt.")
    else:
        if not args.out_dir:
            raise SystemExit("In batch mode, --out_dir is required.")
        if not args.prompt_file and not args.user_input:
            raise SystemExit("Batch mode requires either --prompt_file or --user_input.")
    return args
def resolve_torch_dtype(dtype_str: str) -> torch.dtype:
    """Map a dtype name ("float16"/"bfloat16") to the torch dtype; anything else yields float32."""
    return {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(dtype_str, torch.float32)
def main() -> None:
    """Entry point: load the pipeline, then run single-image or batch mode.

    Batch prompts come from (in priority order): a "filename: prompt"
    mapping file, a one-prompt-per-line list file, or a single --user_input
    applied to every image.
    """
    args = parse_args()
    # Normalize aspect_ratio: an empty string means "keep the source ratio".
    aspect_ratio = (args.aspect_ratio or "").strip()
    if aspect_ratio == "":
        aspect_ratio = None
    torch_dtype = resolve_torch_dtype(args.torch_dtype)
    pipe = Flux2KleinPipeline.from_pretrained(args.repo_id, torch_dtype=torch_dtype)
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(args.device)
    system_prompt = args.system_prompt
    if args.image:
        # ---- Single-image mode ----
        img_path = Path(args.image)
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")
        out_path = Path(args.out)
        final_prompt = build_final_prompt(system_prompt, args.user_input)
        run_one(
            pipe=pipe,
            device=args.device,
            image_path=img_path,
            final_prompt=final_prompt,
            out_path=out_path,
            aspect_ratio=aspect_ratio,
            base_long_edge=args.base_long_edge,
            steps=args.steps,
            guidance=args.guidance,
            seed=args.seed,
        )
        return
    # ---- Batch mode ----
    image_dir = Path(args.image_dir)
    if not image_dir.exists():
        raise FileNotFoundError(f"Image dir not found: {image_dir}")
    images = list_images(image_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    mapping: Dict[str, str] = {}
    list_only: List[str] = []
    if args.prompt_file:
        mapping, list_only = load_prompts(Path(args.prompt_file))
    if mapping:
        # Mapping mode: filename -> prompt (images without an entry get "").
        for img_path in images:
            user_input = mapping.get(img_path.name, "")
            final_prompt = build_final_prompt(system_prompt, user_input)
            out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png"
            run_one(
                pipe=pipe,
                device=args.device,
                image_path=img_path,
                final_prompt=final_prompt,
                out_path=out_path,
                aspect_ratio=aspect_ratio,
                base_long_edge=args.base_long_edge,
                steps=args.steps,
                guidance=args.guidance,
                seed=args.seed,
            )
    elif list_only:
        # List mode: one prompt per line, matched to images sorted by name.
        n = min(len(list_only), len(images))
        if len(list_only) != len(images):
            print(
                f"[WARN] prompt lines ({len(list_only)}) != images ({len(images)}). "
                "Will match by min length and leave the rest empty."
            )
        for i, img_path in enumerate(images):
            user_input = list_only[i] if i < n else ""
            final_prompt = build_final_prompt(system_prompt, user_input)
            out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png"
            run_one(
                pipe=pipe,
                device=args.device,
                image_path=img_path,
                final_prompt=final_prompt,
                out_path=out_path,
                aspect_ratio=aspect_ratio,
                base_long_edge=args.base_long_edge,
                steps=args.steps,
                guidance=args.guidance,
                seed=args.seed,
            )
    else:
        # Fallback: apply the single --user_input to every image.
        for img_path in images:
            final_prompt = build_final_prompt(system_prompt, args.user_input)
            out_path = out_dir / f"{img_path.stem}_to_real_flux2klein.png"
            run_one(
                pipe=pipe,
                device=args.device,
                image_path=img_path,
                final_prompt=final_prompt,
                out_path=out_path,
                aspect_ratio=aspect_ratio,
                base_long_edge=args.base_long_edge,
                steps=args.steps,
                guidance=args.guidance,
                seed=args.seed,
            )


if __name__ == "__main__":
    main()

17
litserve/client.py Normal file
View File

@@ -0,0 +1,17 @@
import requests
import base64
# Encode the local test image as base64.
with open("/mnt/data/workspace/Code/flux2/20260123_152354_2steps.png", "rb") as f:
    img_base64 = base64.b64encode(f.read()).decode("utf-8")
# NOTE(review): the "image" field is commented out even though it was just
# prepared above — confirm whether the target server accepts image-less
# requests before relying on this.
response = requests.post("http://localhost:8451/predict", json={
    # "image": img_base64,
    "prompt": "紫色实木窗帘",
    "aspect_ratio": "1:1",
    "steps": 4
})
# Decode the base64 image from the response and save it locally.
with open("result.png", "wb") as f:
    f.write(base64.b64decode(response.json()["image"]))

84
litserve/model.py Normal file
View File

@@ -0,0 +1,84 @@
# model.py
import torch
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
from PIL import Image
import os
from typing import List, Optional, Dict, Any
class Flux2KleinModel:
    """Thin wrapper around a diffusers Flux pipeline for single and batch img2img inference."""

    def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
        """Load the pipeline with model CPU offload enabled (saves VRAM)."""
        print(f"Loading Flux model on {device} ...")
        self.pipe = FluxPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16
        )
        self.pipe.enable_model_cpu_offload()  # or .to(device), depending on available VRAM
        # self.pipe.vae.enable_slicing()  # optional, saves VRAM
        self.device = device

    def preprocess_image(self, image_path: str) -> Image.Image:
        """Open *image_path* and normalize it to RGB."""
        img = Image.open(image_path).convert("RGB")
        return img

    def run_inference(
        self,
        image: Image.Image,
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs
    ) -> Image.Image:
        """Generate one image; width/height derive from *aspect_ratio* and *base_long_edge*.

        Extra keyword arguments are forwarded to the pipeline call.
        """
        # Derive width/height from the "W:H" ratio (falls back to square).
        if ":" in aspect_ratio:
            w, h = map(int, aspect_ratio.split(":"))
            ratio = w / h
        else:
            ratio = 1.0
        if ratio >= 1:
            width = base_long_edge
            height = int(base_long_edge / ratio)
        else:
            height = base_long_edge
            width = int(base_long_edge * ratio)
        # Snap to multiples of 8 (common Flux requirement).  Bug fix: clamp to
        # >= 64 so an extreme ratio or tiny base_long_edge can never produce a
        # zero-sized edge — consistent with aspect_to_wh elsewhere in this repo.
        width = max(64, (width // 8) * 8)
        height = max(64, (height // 8) * 8)
        generator = torch.Generator(device=self.device).manual_seed(seed)
        result = self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs
        ).images[0]
        return result

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs
    ):
        """Best-effort batch run: failures are logged and skipped, not raised."""
        assert len(image_paths) == len(prompts) == len(out_paths)
        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                # Bug fix: dirname is "" for bare filenames, and makedirs("")
                # raises — fall back to the current directory.
                os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
                result.save(out_path)
                print(f"Saved: {out_path}")
            except Exception as e:
                print(f"Failed {img_path}: {e}")

156
litserve_serve.py Normal file
View File

@@ -0,0 +1,156 @@
import uuid
import torch
from minio import Minio
import litserve as ls
from PIL import Image
import io
import base64
from diffusers import Flux2KleinPipeline
from utils.new_oss_client import oss_get_image, oss_upload_image, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# 保持原有的辅助函数
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Map a "W:H" ratio plus long-edge length to (width, height), 8-px aligned, min 64."""
    ratio_parts = aspect_ratio.split(":")
    rw = float(ratio_parts[0])
    rh = float(ratio_parts[1])

    def _snap(value: int) -> int:
        # Floor to a multiple of 8, never below 64.
        return max(64, (value // 8) * 8)

    if rw >= rh:
        return _snap(base_long_edge), _snap(int(round(base_long_edge * (rh / rw))))
    return _snap(int(round(base_long_edge * (rw / rh)))), _snap(base_long_edge)
class FluxKleinAPI(ls.LitAPI):
    """LitServe API: FLUX.2-klein generation with images pulled from / pushed to MinIO OSS."""

    def setup(self, device):
        """Load the pipeline once per worker and move it to *device*."""
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16
        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype
        )
        self.pipe.to(device)

    def decode_request(self, request):
        """Parse the request and load any reference images from OSS.

        Request fields (all optional):
        - input_image_paths : list[str] of "bucket/object_name" OSS paths;
          each is downloaded via oss_get_image().
        - width / height    : target size in pixels (default 512 each).
        - bucket_name / object_name : where to upload the result.
        - prompt (str, default ""), steps (int, default 4),
          guidance (float, default 4.0), seed (int, default 42).

        Returns a payload dict with the loaded PIL images under "images".
        Note: oss_get_image() may fail for nonexistent paths.
        """
        # Bug fix: tolerate a missing/None "input_image_paths" instead of
        # iterating over None.
        input_image_paths = request.get("input_image_paths") or []
        W = request.get("width", 512)
        H = request.get("height", 512)
        images = []
        for path in input_image_paths:
            # "test/a/b.png" -> bucket "test", object "a/b.png"
            bucket, _, object_name = path.partition("/")
            image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name)
            images.append(image)
        return {
            "bucket_name": request.get("bucket_name", None),
            "object_name": request.get("object_name", None),
            "images": images,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 4),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "height": H,
            "width": W
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run generation, upload the PNG result to OSS, return its "bucket/object" path."""
        images = payload.get("images", [])
        prompt = payload.get("prompt", "")
        # Bug fix: use the decoded width/height.  The original read
        # payload["aspect_ratio"] / payload["base_long_edge"], which
        # decode_request never produces, so every request raised KeyError.
        W, H = payload["width"], payload["height"]
        # Bug fix: seed the generator — the decoded "seed" was unused.
        gen = torch.Generator(device=self.device).manual_seed(payload["seed"])
        call_kwargs = dict(
            prompt=prompt,
            height=H,
            width=W,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        )
        if images:
            call_kwargs["image"] = images
        # Bug fix: the original stored the img2img result in a dict
        # (output['im'] = ...) and then called output.save(...), which would
        # raise AttributeError; both paths now yield the PIL image directly.
        result = self.pipe(**call_kwargs).images[0]
        image_data = io.BytesIO()
        result.save(image_data, format='PNG')
        image_bytes = image_data.getvalue()
        # Bug fix: "bucket_name"/"object_name" exist in the payload as None,
        # so dict.get() defaults never applied; use `or` fallbacks instead.
        req = oss_upload_image(
            oss_client=minio_client,
            bucket=payload.get("bucket_name") or "test",
            object_name=payload.get("object_name") or f"fida_generate_image/{uuid.uuid4().hex}.png",
            image_bytes=image_bytes,
        )
        output_path = req.bucket_name + "/" + req.object_name
        return output_path

    def encode_response(self, output_path):
        """Wrap the OSS path of the uploaded result in the response JSON."""
        return {"output_path": output_path}
if __name__ == "__main__":
    # Start the LitServe server: one CUDA device, HTTP on port 8451.
    api = FluxKleinAPI()
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8451)

View File

@@ -18,6 +18,9 @@ dependencies = [
"fire==0.7.1",
"openai==2.8.1",
"accelerate==1.12.0",
"datetime>=6.0",
"litserve>=0.2.17",
"minio>=7.2.20",
]
[project.optional-dependencies]

BIN
result1.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 963 KiB

31
test.py Normal file
View File

@@ -0,0 +1,31 @@
import datetime
import time
import torch
from PIL import Image
from diffusers import Flux2KleinPipeline

# Smoke test: run a few-step img2img pass on a local file and time it.
device = "cuda"
dtype = torch.bfloat16
pipe = Flux2KleinPipeline.from_pretrained("black-forest-labs/FLUX.2-klein-4B", torch_dtype=dtype, is_distilled=False)
pipe.to(device)  # NOTE(review): .to(device) moves the whole pipeline to the GPU; the old comment incorrectly described this as CPU offload
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
prompt = ""
num_inference_steps = 4
input_image = Image.open("result1.png")
start_time = time.time()
image = pipe(
    image=input_image,
    prompt=prompt,
    height=768,
    width=512,
    guidance_scale=1.0,
    num_inference_steps=num_inference_steps,
    # generator=torch.Generator(device=device).manual_seed(3)
).images[0]
image.save(f"{timestamp}_{num_inference_steps}steps.png")
print(f"infer time : {time.time() - start_time}")

68
utils/new_oss_client.py Normal file
View File

@@ -0,0 +1,68 @@
import io
import logging
from io import BytesIO
import urllib3
from PIL import Image
from minio import Minio
# MinIO connection settings.
# SECURITY: hard-coded credentials committed to source control — rotate these
# keys and load them from environment variables / a secrets manager instead.
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = "vXKFLSJkYeEq2DrSZvkB"
MINIO_SECRET = "uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR"
MINIO_SECURE = True
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
# Custom urllib3 Retry subclass that logs every retry attempt.
class CustomRetry(urllib3.Retry):
    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
        # Delegate to urllib3.Retry to compute the next retry state.
        new_retry = super(CustomRetry, self).increment(method, url, response, error, **kwargs)
        # Log which request is being retried, the error, and the attempt count.
        # NOTE: `logger` is defined later at module level; it is looked up at
        # call time, so this works despite the definition order.
        logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
        return new_retry
logger = logging.getLogger()
# Connect timeout 1 s, read timeout 10 s.
timeout = urllib3.Timeout(connect=1, read=10.0)
# Pooled HTTP client with TLS verification and retry-with-backoff on common
# transient server errors.
# NOTE(review): this pool manager is never passed to the Minio(...) client
# created above — presumably it was meant to be supplied as http_client=...;
# confirm and wire it in, otherwise these timeout/retry settings are unused.
http_client = urllib3.PoolManager(
    num_pools=10,  # connection pool size
    maxsize=10,
    timeout=timeout,
    cert_reqs='CERT_REQUIRED',  # verify TLS certificates
    retries=CustomRetry(
        total=5,
        backoff_factor=0.2,
        status_forcelist=[500, 502, 503, 504],
    ),
)
# Fetch an image object from OSS.
def oss_get_image(oss_client, bucket, object_name):
    """Download *object_name* from *bucket* and open it as a PIL image.

    Returns None (after logging a warning) on any failure.
    """
    try:
        payload = oss_client.get_object(bucket_name=bucket, object_name=object_name)
        buffer = BytesIO(payload.read())
        return Image.open(buffer)
    except Exception as err:
        logger.warning(f" | 获取图片出现异常 ######: {err}")
        return None
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
    """Upload PNG bytes to OSS.

    Returns the put_object result, or None on failure (a warning is logged).
    """
    try:
        return oss_client.put_object(
            bucket_name=bucket,
            object_name=object_name,
            data=io.BytesIO(image_bytes),
            length=len(image_bytes),
            content_type='image/png',
        )
    except Exception as err:
        logger.warning(f" | 上传图片出现异常 ######: {err}")
        return None
if __name__ == '__main__':
    # Smoke test: download one object, display it, and save a local copy.
    url = "fida-test/furniture/sketches/4449a66d-6267-43f7-86a2-1e42bd19ec61.png"
    bucket, _, object_name = url.partition('/')
    # Bug fix: oss_get_image() takes no `data_type` parameter — the previous
    # call passed data_type=read_type and always raised TypeError.
    img = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name)
    img.show()
    img.save("result.png")

1724
uv.lock generated Normal file

File diff suppressed because it is too large Load Diff