1
.gitignore
vendored
1
.gitignore
vendored
@@ -230,3 +230,4 @@ $RECYCLE.BIN/
|
|||||||
# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
|
# End of https://www.toptal.com/developers/gitignore/api/linux,windows,macos,visualstudiocode,python
|
||||||
|
|
||||||
output/
|
output/
|
||||||
|
*.png
|
||||||
10
.idea/.gitignore
generated
vendored
Normal file
10
.idea/.gitignore
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 已忽略包含查询文件的默认文件夹
|
||||||
|
/queries/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
15
.idea/flux2.iml
generated
Normal file
15
.idea/flux2.iml
generated
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$">
|
||||||
|
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||||
|
<excludeFolder url="file://$MODULE_DIR$/.venv" />
|
||||||
|
</content>
|
||||||
|
<orderEntry type="jdk" jdkName="uv (flux2)" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="PLAIN" />
|
||||||
|
<option name="myDocStringFormat" value="Plain" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
39
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
39
.idea/inspectionProfiles/Project_Default.xml
generated
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="DuplicatedCode" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||||
|
<inspection_tool class="GrazieInspection" enabled="false" level="GRAMMAR_ERROR" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="HttpUrlsUsage" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="OutdatedRequirementInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="PyBroadExceptionInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="PyDeprecationInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="PyPep8NamingInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false">
|
||||||
|
<option name="ignoredErrors">
|
||||||
|
<list>
|
||||||
|
<option value="N801" />
|
||||||
|
<option value="N806" />
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="PyShadowingBuiltinsInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredNames">
|
||||||
|
<list>
|
||||||
|
<option value="input" />
|
||||||
|
<option value="type" />
|
||||||
|
<option value="object" />
|
||||||
|
<option value="id" />
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="PyShadowingNamesInspection" enabled="false" level="WEAK WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="PyTypeCheckerInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="SpellCheckingInspection" enabled="false" level="TYPO" enabled_by_default="false">
|
||||||
|
<option name="processCode" value="true" />
|
||||||
|
<option name="processLiterals" value="true" />
|
||||||
|
<option name="processComments" value="false" />
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="UnsatisfiedRequirementInspection" enabled="false" level="WARNING" enabled_by_default="false" />
|
||||||
|
<inspection_tool class="VulnerableLibrariesLocal" enabled="false" level="WARNING" enabled_by_default="false" />
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
6
.idea/inspectionProfiles/profiles_settings.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
7
.idea/misc.xml
generated
Normal file
7
.idea/misc.xml
generated
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="uv (flux2)" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="uv (flux2)" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
8
.idea/modules.xml
generated
Normal file
8
.idea/modules.xml
generated
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/flux2.iml" filepath="$PROJECT_DIR$/.idea/flux2.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
6
.idea/vcs.xml
generated
Normal file
6
.idea/vcs.xml
generated
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
10
client.py
Normal file
10
client.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# This file is auto-generated by LitServe.
# Disable auto-generation by setting `generate_client_file=False` in `LitServer.run()`.
import time

import requests

# Time a single blocking request to the local /predict endpoint.
start_time = time.time()
# NOTE(review): only a prompt is sent (no image); assumes the server tolerates
# a missing "image" field — confirm against the serving code.
response = requests.post("http://127.0.0.1:8011/predict", json={"prompt": "变成一只猫"})
# print(f"Status: {response.status_code}\nResponse:\n {response.text}")
# Report the end-to-end request latency in seconds.
print(time.time() - start_time)
|
||||||
110
litserve/app.py
Normal file
110
litserve/app.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
import torch
|
||||||
|
import litserve as ls
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
from diffusers import Flux2KleinPipeline
|
||||||
|
|
||||||
|
|
||||||
|
# 保持原有的辅助函数
|
||||||
|
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Convert a "W:H" aspect-ratio string into concrete pixel dimensions.

    The long edge is fixed at ``base_long_edge`` and the short edge derived
    from the ratio.  Both edges are floored to a multiple of 8 and clamped
    to a minimum of 64.
    """
    ratio_w_str, ratio_h_str = aspect_ratio.split(":")
    ratio_w, ratio_h = float(ratio_w_str), float(ratio_h_str)
    if ratio_w >= ratio_h:
        width = base_long_edge
        height = int(round(base_long_edge * (ratio_h / ratio_w)))
    else:
        height = base_long_edge
        width = int(round(base_long_edge * (ratio_w / ratio_h)))

    def _snap(edge: int) -> int:
        # Floor to a multiple of 8, never below 64.
        return max(64, (edge // 8) * 8)

    return _snap(width), _snap(height)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxKleinAPI(ls.LitAPI):
    """LitServe API wrapping the FLUX.2 klein pipeline for base64 img2img."""

    def setup(self, device):
        # Model initialization: load the pipeline once per worker.
        self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
        self.device = device
        self.dtype = torch.bfloat16

        self.pipe = Flux2KleinPipeline.from_pretrained(
            self.repo_id,
            torch_dtype=self.dtype
        )
        self.pipe.to(device)

        self.default_system_prompt = (
            "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment... "
            "Realistic scale, high material fidelity, sharp focus, 4K quality."
        )

    def decode_request(self, request):
        """Decode the JSON body: a required base64 image plus generation knobs."""
        # 1. Fetch the base64 string.
        image_data = request.get("image", None)
        if not image_data:
            raise ValueError("Request must contain 'image' field with base64 data.")

        # 2. Strip an optional "data:image/png;base64," style prefix.
        if "," in image_data:
            image_data = image_data.split(",")[1]

        # 3. Decode the bytes into an RGB PIL image.
        img_bytes = base64.b64decode(image_data)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        return {
            "image": img,
            "prompt": request.get("prompt", ""),
            "steps": request.get("steps", 28),
            "guidance": request.get("guidance", 4.0),
            "seed": request.get("seed", 42),
            "aspect_ratio": request.get("aspect_ratio", "4:3"),
            "base_long_edge": request.get("base_long_edge", 1024)
        }

    @torch.inference_mode()
    def predict(self, payload):
        """Run the pipeline on the decoded payload and return a PIL image."""
        img = payload.get("image", None)
        prompt = payload.get("prompt", "")
        W, H = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])
        # BUG FIX: the decoded "seed" was previously never applied, so results
        # were non-reproducible despite the parameter being accepted.
        gen = torch.Generator(device=self.device).manual_seed(payload["seed"])

        if img is not None:
            # Simplified resize; a full implementation could pad instead
            # (see resize_pad_to_target in the CLI variant).
            img = img.convert("RGB").resize((W, H), Image.Resampling.LANCZOS)

        # The original code duplicated this call in byte-identical if/else
        # branches; a single call has the same behavior.
        output = self.pipe(
            image=img,
            prompt=prompt,
            height=H,
            width=W,
            guidance_scale=payload["guidance"],
            num_inference_steps=payload["steps"],
            generator=gen,
        ).images[0]
        return output

    def encode_response(self, output):
        """Encode the generated PIL image as base64 PNG."""
        buffered = io.BytesIO()
        output.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return {"image": img_str}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Start the inference server on a single CUDA device.
    api = FluxKleinAPI()
    server = ls.LitServer(api, accelerator="cuda", devices=1)
    server.run(port=8011)
|
||||||
389
litserve/cli.py
Normal file
389
litserve/cli.py
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
FLUX.2 [klein] image-to-image editing with CLI args (single or batch).
|
||||||
|
|
||||||
|
Prereqs:
|
||||||
|
pip install -U "diffusers>=0.30.0" transformers accelerate safetensors pillow
|
||||||
|
# and you have CUDA set up for torch.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
# Single image
|
||||||
|
python flux2_klein_cli.py \
|
||||||
|
--image test_sketch/sofa_2.png \
|
||||||
|
--user_input "caramel brown leather, matte black legs" \
|
||||||
|
--out outputs/sofa_2_to_real.png
|
||||||
|
|
||||||
|
# Batch: one prompt applied to all images
|
||||||
|
python flux2_klein_cli.py \
|
||||||
|
--image_dir test_sketch \
|
||||||
|
--user_input "pure white background, photorealistic materials" \
|
||||||
|
--out_dir outputs/batch
|
||||||
|
|
||||||
|
# Batch: a prompt per image using a prompt file (mapping or list)
|
||||||
|
python flux2_klein_cli.py \
|
||||||
|
--image_dir test_sketch \
|
||||||
|
--prompt_file prompt.txt \
|
||||||
|
--out_dir outputs/batch
|
||||||
|
|
||||||
|
# Optional knobs
|
||||||
|
python flux2_klein_cli.py \
|
||||||
|
--image test_sketch/a.png \
|
||||||
|
--user_input "..." \
|
||||||
|
--out outputs/a.png \
|
||||||
|
--aspect_ratio 4:3 \
|
||||||
|
--base_long_edge 1024 \
|
||||||
|
--steps 28 \
|
||||||
|
--guidance 4.0 \
|
||||||
|
--seed 0 \
|
||||||
|
--cpu_offload
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Tuple, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from PIL import Image, ImageOps
|
||||||
|
from diffusers import Flux2KleinPipeline
|
||||||
|
from diffusers.utils import load_image
|
||||||
|
|
||||||
|
# Model / runtime defaults used by the CLI flags in parse_args().
DEFAULT_REPO_ID = "black-forest-labs/FLUX.2-klein-4B"
DEFAULT_DEVICE = "cuda"
DEFAULT_TORCH_DTYPE = "bfloat16"

# System prompt prepended to every user prompt (see build_final_prompt).
DEFAULT_SYSTEM_PROMPT = (
    "Pure white background (#FFFFFF) with a seamless white studio backdrop; no environment, no interior, "
    "no floor texture, and no visible floor/wall lines. Studio product photography lighting with soft "
    "diffused light; ultra-clean industrial design presentation. Render a single furniture object only, "
    "isolated and perfectly centered on a borderless canvas. Strictly preserve the geometry, silhouette, "
    "proportions, and key structural details of the reference sketch (do not change shape, do not add/remove parts). "
    "Realistic scale, high material fidelity, sharp focus, 4K quality. No people, no props, no text, no logos, no watermark."
)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Prompt parsing (optional)
|
||||||
|
# -----------------------------
|
||||||
|
def load_prompts(prompt_txt: Path) -> Tuple[Dict[str, str], List[str]]:
    """Parse a prompt file into a filename→prompt mapping or a plain list.

    Supports:
      A) "sofa_2.png: caramel brown leather ..." (filename: user_input)
      B) "caramel brown leather ..." (user_input only; matched by sorted image filenames)

    Returns (mapping, list_only); exactly one of the two is populated.
    """
    if not prompt_txt.exists():
        raise FileNotFoundError(f"prompt file not found: {prompt_txt}")

    # Keep non-empty, non-comment lines, stripped of surrounding whitespace.
    lines: List[str] = [
        stripped
        for stripped in (raw.strip() for raw in prompt_txt.read_text(encoding="utf-8").splitlines())
        if stripped and not stripped.startswith("#")
    ]

    image_suffixes = (".png", ".jpg", ".jpeg", ".webp")

    def _is_mapping_line(line: str) -> bool:
        # A mapping line looks like "<image filename>: <prompt>".
        return ":" in line and line.split(":", 1)[0].strip().lower().endswith(image_suffixes)

    mapping: Dict[str, str] = {}
    list_only: List[str] = []

    if any(_is_mapping_line(ln) for ln in lines):
        for entry in lines:
            if ":" not in entry:
                continue
            name_part, _, text_part = entry.partition(":")
            name, text = name_part.strip(), text_part.strip()
            if name and text:
                mapping[name] = text
    else:
        list_only = lines

    return mapping, list_only
|
||||||
|
|
||||||
|
|
||||||
|
def list_images(image_dir: Path) -> List[Path]:
    """Return the image files in *image_dir* sorted by name; raise if none."""
    allowed_suffixes = {".png", ".jpg", ".jpeg", ".webp"}
    found = sorted(
        entry
        for entry in image_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in allowed_suffixes
    )
    if not found:
        raise RuntimeError(f"No images found in: {image_dir}")
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def build_final_prompt(system_prompt: str, user_input: str) -> str:
    """Combine the system prompt with optional user requirements."""
    cleaned = (user_input or "").strip()
    if cleaned:
        return f"{system_prompt} User requirements: {cleaned}"
    return system_prompt
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Aspect ratio utils (resize+pad)
|
||||||
|
# -----------------------------
|
||||||
|
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Map a "W:H" ratio string to (width, height) whose long edge is base_long_edge.

    Both edges are floored to multiples of 8 and clamped to at least 64.
    """
    w_part, h_part = aspect_ratio.split(":")
    ratio_w, ratio_h = float(w_part), float(h_part)

    if ratio_w >= ratio_h:
        # Landscape (or square): width is the long edge.
        width, height = base_long_edge, int(round(base_long_edge * (ratio_h / ratio_w)))
    else:
        # Portrait: height is the long edge.
        width, height = int(round(base_long_edge * (ratio_w / ratio_h))), base_long_edge

    return max(64, width // 8 * 8), max(64, height // 8 * 8)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_pad_to_target(img: Image.Image, width: int, height: int, fill=(255, 255, 255)) -> Image.Image:
    """Fit *img* inside (width, height), centering it on a *fill*-colored canvas.

    The image is scaled to fit entirely within the target box with its aspect
    ratio preserved, then padded symmetrically to reach the exact size.
    """
    fitted = ImageOps.contain(img.convert("RGB"), (width, height), method=Image.Resampling.LANCZOS)
    extra_w = width - fitted.size[0]
    extra_h = height - fitted.size[1]
    left, top = extra_w // 2, extra_h // 2
    # (left, top, right, bottom) — odd leftovers go to the right/bottom edge.
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(fitted, border=border, fill=fill)
|
||||||
|
|
||||||
|
|
||||||
|
def resize_to_long_edge(img: Image.Image, base_long_edge: int) -> Image.Image:
    """Scale *img* so its long edge equals *base_long_edge*, keeping aspect ratio.

    Output edges are floored to multiples of 8 and clamped to at least 64.
    Returns the RGB image unchanged when the long edge already matches.
    """
    img = img.convert("RGB")
    w, h = img.size
    if max(w, h) == base_long_edge:
        return img

    if w >= h:
        new_w, new_h = base_long_edge, int(round(base_long_edge * h / w))
    else:
        new_w, new_h = int(round(base_long_edge * w / h)), base_long_edge

    new_w = max(64, (new_w // 8) * 8)
    new_h = max(64, (new_h // 8) * 8)
    return img.resize((new_w, new_h), Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
|
|
||||||
|
# -----------------------------
|
||||||
|
# Run one
|
||||||
|
# -----------------------------
|
||||||
|
@torch.inference_mode()
def run_one(
    pipe: Flux2KleinPipeline,
    device: str,
    image_path: Path,
    final_prompt: str,
    out_path: Path,
    aspect_ratio: Optional[str],
    base_long_edge: int,
    steps: int,
    guidance: float,
    seed: int,
) -> None:
    """Run one img2img generation and save the result to *out_path*.

    The reference image is either padded to the requested aspect ratio
    (when *aspect_ratio* is given) or merely scaled to *base_long_edge*.
    The pipeline output size matches the prepared reference exactly.
    """
    ref = load_image(str(image_path))  # PIL.Image

    if aspect_ratio:
        # Fixed-ratio mode: compute the target canvas and letterbox into it.
        W, H = aspect_to_wh(aspect_ratio, base_long_edge)
        ref = resize_pad_to_target(ref, W, H)
    else:
        # Free mode: keep the source ratio, only normalize the long edge.
        ref = resize_to_long_edge(ref, base_long_edge)
        W, H = ref.size

    # Seeded generator so runs with the same arguments are reproducible.
    gen = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        image=ref,
        prompt=final_prompt,
        height=H,
        width=W,
        guidance_scale=guidance,
        num_inference_steps=steps,
        generator=gen,
    ).images[0]

    # Create parent directories on demand so callers can pass fresh paths.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    result.save(out_path)
    print(f"[OK] saved: {out_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_args() -> argparse.Namespace:
    """Define, parse, and validate the CLI arguments.

    Exactly one of --image (single mode) or --image_dir (batch mode) must be
    given; each mode has its own required companion flags, enforced below.
    """
    ap = argparse.ArgumentParser(description="FLUX.2 [klein] img2img editing with CLI args.")

    io_group = ap.add_argument_group("I/O")
    io_group.add_argument("--image", type=str, default="", help="Single input image path.")
    io_group.add_argument("--image_dir", type=str, default="", help="Directory of images for batch mode.")
    io_group.add_argument("--out", type=str, default="", help="Single output image path (for --image).")
    io_group.add_argument("--out_dir", type=str, default="", help="Output directory (for --image_dir).")

    prompt_group = ap.add_argument_group("Prompt")
    prompt_group.add_argument("--user_input", type=str, default="", help="User prompt (applied to single image or all images).")
    prompt_group.add_argument(
        "--prompt_file",
        type=str,
        default="",
        help='Optional prompt file for batch mode. Supports "filename: prompt" mapping or one-prompt-per-line list.',
    )
    prompt_group.add_argument("--system_prompt", type=str, default=DEFAULT_SYSTEM_PROMPT, help="System prompt string.")

    model_group = ap.add_argument_group("Model/Runtime")
    model_group.add_argument("--repo_id", type=str, default=DEFAULT_REPO_ID, help=f"HF repo id (default: {DEFAULT_REPO_ID}).")
    model_group.add_argument("--device", type=str, default=DEFAULT_DEVICE, help='Device, e.g. "cuda" or "cpu".')
    model_group.add_argument(
        "--torch_dtype",
        type=str,
        default=DEFAULT_TORCH_DTYPE,
        choices=["float16", "bfloat16", "float32"],
        help="torch dtype for loading weights.",
    )
    model_group.add_argument("--cpu_offload", action="store_true", help="Enable model CPU offload (saves VRAM).")

    gen_group = ap.add_argument_group("Generation")
    gen_group.add_argument("--aspect_ratio", type=str, default="4:3", help='e.g. "16:9", "1:1". Set "" to disable.')
    gen_group.add_argument("--base_long_edge", type=int, default=1024, help="Target long edge for output/resized reference.")
    gen_group.add_argument("--steps", type=int, default=28, help="num_inference_steps.")
    gen_group.add_argument("--guidance", type=float, default=4.0, help="guidance_scale.")
    gen_group.add_argument("--seed", type=int, default=0, help="Random seed.")

    args = ap.parse_args()

    # Validate mode: --image and --image_dir are mutually exclusive and one is required.
    single_mode = bool(args.image)
    if single_mode == bool(args.image_dir):
        raise SystemExit("Provide exactly one of --image or --image_dir.")

    if single_mode:
        if not args.out:
            raise SystemExit("In single-image mode, --out is required.")
        if not args.user_input and not args.system_prompt:
            raise SystemExit("Provide --user_input and/or --system_prompt.")
    else:
        if not args.out_dir:
            raise SystemExit("In batch mode, --out_dir is required.")
        if not args.prompt_file and not args.user_input:
            raise SystemExit("Batch mode requires either --prompt_file or --user_input.")

    return args
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_torch_dtype(dtype_str: str) -> torch.dtype:
    """Translate a dtype name from the CLI into a torch dtype object.

    Any name other than "float16"/"bfloat16" falls back to float32,
    matching the original if-chain behavior.
    """
    named = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return named.get(dtype_str, torch.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """CLI entry point: load the pipeline, then run single or batch editing.

    Single mode edits one image to --out; batch mode edits every image in
    --image_dir, with prompts taken from a mapping file, a list file, or
    --user_input applied uniformly.
    """
    args = parse_args()

    # Normalize aspect_ratio: an empty string disables fixed-ratio padding.
    aspect_ratio = (args.aspect_ratio or "").strip() or None

    torch_dtype = resolve_torch_dtype(args.torch_dtype)

    pipe = Flux2KleinPipeline.from_pretrained(args.repo_id, torch_dtype=torch_dtype)
    if args.cpu_offload:
        pipe.enable_model_cpu_offload()  # trades speed for lower VRAM use
    else:
        pipe.to(args.device)

    system_prompt = args.system_prompt

    def _run(image_path: Path, user_input: str, out_path: Path) -> None:
        # Single place for the run_one invocation; the original body
        # copy-pasted this call (and its ten keyword arguments) four times.
        final_prompt = build_final_prompt(system_prompt, user_input)
        run_one(
            pipe=pipe,
            device=args.device,
            image_path=image_path,
            final_prompt=final_prompt,
            out_path=out_path,
            aspect_ratio=aspect_ratio,
            base_long_edge=args.base_long_edge,
            steps=args.steps,
            guidance=args.guidance,
            seed=args.seed,
        )

    if args.image:
        # Single-image mode.
        img_path = Path(args.image)
        if not img_path.exists():
            raise FileNotFoundError(f"Image not found: {img_path}")
        _run(img_path, args.user_input, Path(args.out))
        return

    # Batch mode
    image_dir = Path(args.image_dir)
    if not image_dir.exists():
        raise FileNotFoundError(f"Image dir not found: {image_dir}")

    images = list_images(image_dir)
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    mapping: Dict[str, str] = {}
    list_only: List[str] = []
    if args.prompt_file:
        mapping, list_only = load_prompts(Path(args.prompt_file))

    # Resolve one user prompt per image, depending on prompt-file shape.
    if mapping:
        # Mapping mode: filename -> prompt; unmapped images get an empty prompt.
        prompts = [mapping.get(p.name, "") for p in images]
    elif list_only:
        # List mode: one prompt per line, matched against sorted filenames.
        if len(list_only) != len(images):
            print(
                f"[WARN] prompt lines ({len(list_only)}) != images ({len(images)}). "
                "Will match by min length and leave the rest empty."
            )
        prompts = [list_only[i] if i < len(list_only) else "" for i in range(len(images))]
    else:
        # No prompt file parsed (or no prompt file): apply --user_input to all.
        prompts = [args.user_input] * len(images)

    for img_path, user_input in zip(images, prompts):
        _run(img_path, user_input, out_dir / f"{img_path.stem}_to_real_flux2klein.png")
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point.
    main()
|
||||||
17
litserve/client.py
Normal file
17
litserve/client.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
import requests
import base64

# Encode the local test image as base64 (currently unused; the "image"
# field below is commented out).
with open("/mnt/data/workspace/Code/flux2/20260123_152354_2steps.png", "rb") as f:
    img_base64 = base64.b64encode(f.read()).decode("utf-8")

# Send an edit request to the local LitServe /predict endpoint.
response = requests.post("http://localhost:8451/predict", json={
    # "image": img_base64,
    "prompt": "紫色实木窗帘",
    "aspect_ratio": "1:1",
    "steps": 4
})

# Save the returned base64-encoded PNG to disk.
with open("result.png", "wb") as f:
    f.write(base64.b64decode(response.json()["image"]))
|
||||||
84
litserve/model.py
Normal file
84
litserve/model.py
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
# model.py
|
||||||
|
import torch
|
||||||
|
from diffusers import FluxPipeline # 假设你用的是 flux from diffusers
|
||||||
|
from PIL import Image
|
||||||
|
import os
|
||||||
|
from typing import List, Optional, Dict, Any
|
||||||
|
|
||||||
|
class Flux2KleinModel:
    """Thin wrapper around a diffusers Flux pipeline for img2img editing.

    NOTE(review): despite the class name, the default model_id points at
    FLUX.1-dev — confirm which checkpoint callers actually intend.
    """

    def __init__(self, model_id="black-forest-labs/FLUX.1-dev", device="cuda"):
        print(f"Loading Flux model on {device} ...")
        self.pipe = FluxPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16
        )
        self.pipe.enable_model_cpu_offload()  # or .to(device), depending on available VRAM
        # self.pipe.vae.enable_slicing()  # optional, saves VRAM
        self.device = device

    def preprocess_image(self, image_path: str) -> Image.Image:
        """Load an image from disk as RGB."""
        img = Image.open(image_path).convert("RGB")
        return img

    def run_inference(
        self,
        image: Image.Image,
        prompt: str,
        aspect_ratio: str = "1:1",
        base_long_edge: int = 1024,
        steps: int = 28,
        guidance: float = 4.0,
        seed: int = 0,
        **kwargs
    ) -> Image.Image:
        """Run one seeded img2img generation and return the resulting image."""
        # Derive width/height from aspect_ratio (long edge = base_long_edge).
        if ":" in aspect_ratio:
            w, h = map(int, aspect_ratio.split(":"))
            ratio = w / h
        else:
            ratio = 1.0

        if ratio >= 1:
            width = base_long_edge
            height = int(base_long_edge / ratio)
        else:
            height = base_long_edge
            width = int(base_long_edge * ratio)

        # Round down to multiples of 8 (a common Flux pipeline requirement).
        width = (width // 8) * 8
        height = (height // 8) * 8

        generator = torch.Generator(device=self.device).manual_seed(seed)

        result = self.pipe(
            prompt=prompt,
            image=image,
            height=height,
            width=width,
            num_inference_steps=steps,
            guidance_scale=guidance,
            generator=generator,
            **kwargs
        ).images[0]

        return result

    def batch_run(
        self,
        image_paths: List[str],
        prompts: List[str],
        out_paths: List[str],
        **common_kwargs
    ):
        """Best-effort batch: per-item failures are logged, not raised."""
        assert len(image_paths) == len(prompts) == len(out_paths)

        for img_path, prompt, out_path in zip(image_paths, prompts, out_paths):
            try:
                img = self.preprocess_image(img_path)
                result = self.run_inference(img, prompt, **common_kwargs)
                # BUG FIX: os.path.dirname("") is "" for bare filenames and
                # os.makedirs("") raises, spuriously failing the item.
                os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
                result.save(out_path)
                print(f"Saved: {out_path}")
            except Exception as e:
                print(f"Failed {img_path}: {e}")
|
||||||
156
litserve_serve.py
Normal file
156
litserve_serve.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
import litserve as ls
|
||||||
|
from PIL import Image
|
||||||
|
import io
|
||||||
|
import base64
|
||||||
|
from diffusers import Flux2KleinPipeline
|
||||||
|
|
||||||
|
from utils.new_oss_client import oss_get_image, oss_upload_image, MINIO_URL, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE
|
||||||
|
|
||||||
|
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
# 保持原有的辅助函数
|
||||||
|
def aspect_to_wh(aspect_ratio: str, base_long_edge: int) -> tuple[int, int]:
    """Turn a "W:H" ratio string into pixel (width, height).

    The longer side equals *base_long_edge*; both sides are floored to
    multiples of 8 and never drop below 64.
    """
    parts = aspect_ratio.split(":")
    ratio_w, ratio_h = float(parts[0]), float(parts[1])
    landscape = ratio_w >= ratio_h
    if landscape:
        width = base_long_edge
        height = int(round(base_long_edge * (ratio_h / ratio_w)))
    else:
        width = int(round(base_long_edge * (ratio_w / ratio_h)))
        height = base_long_edge
    width = max(64, (width // 8) * 8)
    height = max(64, (height // 8) * 8)
    return width, height
|
||||||
|
|
||||||
|
|
||||||
|
class FluxKleinAPI(ls.LitAPI):
|
||||||
|
def setup(self, device):
    # Model initialization: load the FLUX.2 klein pipeline onto the target
    # device once per worker. Called by LitServe before serving requests.
    self.repo_id = "black-forest-labs/FLUX.2-klein-4B"
    self.device = device
    self.dtype = torch.bfloat16

    self.pipe = Flux2KleinPipeline.from_pretrained(
        self.repo_id,
        torch_dtype=self.dtype
    )
    self.pipe.to(device)
|
||||||
|
|
||||||
|
def decode_request(self, request):
|
||||||
|
"""
|
||||||
|
解析请求参数并加载OSS图片的接口函数
|
||||||
|
|
||||||
|
接口入参说明(request字典结构):
|
||||||
|
----------
|
||||||
|
request : dict
|
||||||
|
核心请求参数字典,各字段说明如下:
|
||||||
|
- input_image_paths : list[str] | None (可选)
|
||||||
|
OSS图片路径列表,格式为 "bucket/object_name"(如 "test/typical_b/uildi/ng_space_station.png")
|
||||||
|
若不传则为None,会导致后续图片加载失败,建议必传
|
||||||
|
- width : int (可选,默认值512)
|
||||||
|
图片宽度,默认512像素
|
||||||
|
- height : int (可选,默认值512)
|
||||||
|
图片高度,默认512像素
|
||||||
|
- bucket_name : str | None (可选)
|
||||||
|
OSS桶名,不传则为None
|
||||||
|
- object_name : str | None (可选)
|
||||||
|
OSS对象名(文件路径),不传则为None
|
||||||
|
- prompt : str (可选,默认值空字符串)
|
||||||
|
文本提示词,用于模型推理等场景
|
||||||
|
- steps : int (可选,默认值28)
|
||||||
|
推理步数,控制模型生成过程的迭代次数
|
||||||
|
- guidance : float (可选,默认值4.0)
|
||||||
|
引导系数,调节提示词对生成结果的影响程度
|
||||||
|
- seed : int (可选,默认值42)
|
||||||
|
随机种子,保证生成结果的可复现性
|
||||||
|
|
||||||
|
返回值说明
|
||||||
|
-------
|
||||||
|
dict
|
||||||
|
解析后的参数字典,包含:
|
||||||
|
- bucket_name: 请求中的桶名(None/字符串)
|
||||||
|
- object_name: 请求中的对象名(None/字符串)
|
||||||
|
- images: 从OSS加载的图片列表(按input_image_paths顺序)
|
||||||
|
- prompt: 文本提示词(默认空字符串)
|
||||||
|
- steps: 推理步数(默认28)
|
||||||
|
- guidance: 引导系数(默认4.0)
|
||||||
|
- seed: 随机种子(默认42)
|
||||||
|
- height: 图片高度(默认512)
|
||||||
|
- width: 图片宽度(默认512)
|
||||||
|
|
||||||
|
异常说明
|
||||||
|
-------
|
||||||
|
- 若input_image_paths非None但格式错误(无"/"分割且非空),可能导致rest[0]索引错误
|
||||||
|
- 若OSS图片加载失败(如路径不存在),oss_get_image会抛出对应异常
|
||||||
|
"""
|
||||||
|
input_image_paths = request.get("input_image_paths", None)
|
||||||
|
W = request.get("width", 512)
|
||||||
|
H = request.get("height", 512)
|
||||||
|
images = []
|
||||||
|
for path in input_image_paths:
|
||||||
|
bucket, *rest = path.split("/", 1) # 拆分为 ["test", "typical_b/uildi/ng_space_station.png"]
|
||||||
|
object_name = rest[0] if rest else ""
|
||||||
|
image = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name)
|
||||||
|
images.append(image)
|
||||||
|
return {
|
||||||
|
"bucket_name": request.get("bucket_name", None),
|
||||||
|
"object_name": request.get("object_name", None),
|
||||||
|
"images": images,
|
||||||
|
"prompt": request.get("prompt", ""),
|
||||||
|
"steps": request.get("steps", 4),
|
||||||
|
"guidance": request.get("guidance", 4.0),
|
||||||
|
"seed": request.get("seed", 42),
|
||||||
|
"height": H,
|
||||||
|
"width": W
|
||||||
|
}
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def predict(self, payload):
|
||||||
|
# 3. 执行推理逻辑
|
||||||
|
images = payload.get("images", [])
|
||||||
|
prompt = payload.get("prompt", "")
|
||||||
|
W, H = aspect_to_wh(payload["aspect_ratio"], payload["base_long_edge"])
|
||||||
|
gen = torch.Generator(device=self.device)
|
||||||
|
output = {}
|
||||||
|
if images:
|
||||||
|
output['im'] = self.pipe(
|
||||||
|
image=images,
|
||||||
|
prompt=prompt,
|
||||||
|
height=H,
|
||||||
|
width=W,
|
||||||
|
guidance_scale=payload["guidance"],
|
||||||
|
num_inference_steps=payload["steps"],
|
||||||
|
generator=gen,
|
||||||
|
).images[0]
|
||||||
|
else:
|
||||||
|
output = self.pipe(
|
||||||
|
prompt=prompt,
|
||||||
|
height=H,
|
||||||
|
width=W,
|
||||||
|
guidance_scale=payload["guidance"],
|
||||||
|
num_inference_steps=payload["steps"],
|
||||||
|
generator=gen,
|
||||||
|
).images[0]
|
||||||
|
image_data = io.BytesIO()
|
||||||
|
output.save(image_data, format='PNG')
|
||||||
|
image_data.seek(0)
|
||||||
|
image_bytes = image_data.read()
|
||||||
|
req = oss_upload_image(oss_client=minio_client, bucket=payload.get("bucket_name", "test"), object_name=payload.get("object_name", f"fida_generate_image/{uuid.uuid4().hex}.png"), image_bytes=image_bytes)
|
||||||
|
output_path = req.bucket_name + "/" + req.object_name
|
||||||
|
return output_path
|
||||||
|
|
||||||
|
def encode_response(self, output_path):
|
||||||
|
return {"output_path": output_path}
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Entry point: serve the FLUX.2-klein API on one CUDA device, port 8451.
    server = ls.LitServer(FluxKleinAPI(), accelerator="cuda", devices=1)
    server.run(port=8451)
|
||||||
@@ -18,6 +18,9 @@ dependencies = [
|
|||||||
"fire==0.7.1",
|
"fire==0.7.1",
|
||||||
"openai==2.8.1",
|
"openai==2.8.1",
|
||||||
"accelerate==1.12.0",
|
"accelerate==1.12.0",
|
||||||
|
"datetime>=6.0",
|
||||||
|
"litserve>=0.2.17",
|
||||||
|
"minio>=7.2.20",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
@@ -42,4 +45,4 @@ ignore = [
|
|||||||
quote-style = "double"
|
quote-style = "double"
|
||||||
indent-style = "space"
|
indent-style = "space"
|
||||||
line-ending = "auto"
|
line-ending = "auto"
|
||||||
docstring-code-format = true
|
docstring-code-format = true
|
||||||
|
|||||||
BIN
result1.png
Normal file
BIN
result1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 963 KiB |
31
test.py
Normal file
31
test.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Smoke-test script: one image-to-image pass through FLUX.2-klein.

Loads the pipeline onto the GPU, runs a 4-step edit of ``result1.png`` with
an empty prompt, saves the result under a timestamped filename, and prints
the inference wall-clock time.
"""
import datetime
import time

import torch
from PIL import Image
from diffusers import Flux2KleinPipeline

device = "cuda"
dtype = torch.bfloat16

# is_distilled=False: run the full (non-distilled) sampling path.
pipe = Flux2KleinPipeline.from_pretrained("black-forest-labs/FLUX.2-klein-4B", torch_dtype=dtype, is_distilled=False)
pipe.to(device) # move the whole pipeline onto the GPU (no CPU offload here)

timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
prompt = ""
num_inference_steps = 4
input_image = Image.open("result1.png")

start_time = time.time()
image = pipe(
    image=input_image,
    prompt=prompt,
    height=768,
    width=512,
    guidance_scale=1.0,
    num_inference_steps=num_inference_steps,
    # generator=torch.Generator(device=device).manual_seed(3)
).images[0]

image.save(f"{timestamp}_{num_inference_steps}steps.png")
print(f"infer time : {time.time() - start_time}")
|
||||||
68
utils/new_oss_client.py
Normal file
68
utils/new_oss_client.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import io
|
||||||
|
import logging
|
||||||
|
from io import BytesIO
|
||||||
|
import urllib3
|
||||||
|
from PIL import Image
|
||||||
|
from minio import Minio
|
||||||
|
|
||||||
|
# MinIO/OSS connection settings.
# SECURITY NOTE(review): access key and secret are hardcoded in source and
# committed to the repository — move them to environment variables or a
# secret manager, and rotate these credentials.
MINIO_URL = "www.minio-api.aida.com.hk"
MINIO_ACCESS = "vXKFLSJkYeEq2DrSZvkB"
MINIO_SECRET = "uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR"
MINIO_SECURE = True

# Module-level client shared by the helper functions below.
minio_client = Minio(MINIO_URL, access_key=MINIO_ACCESS, secret_key=MINIO_SECRET, secure=MINIO_SECURE)
|
||||||
|
|
||||||
|
|
||||||
|
# Retry policy that logs every retry attempt.
class CustomRetry(urllib3.Retry):
    def increment(self, method=None, url=None, response=None, error=None, **kwargs):
        """Delegate to urllib3.Retry.increment, logging the retried request."""
        new_retry = super().increment(method, url, response, error, **kwargs)
        # self.total - new_retry.total = number of attempts consumed so far.
        logger.info(f"重试连接: {method} {url},错误: {error},重试次数: {self.total - new_retry.total}")
        return new_retry
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger()
timeout = urllib3.Timeout(connect=1, read=10.0)  # connect timeout 1 s, read timeout 10 s
# NOTE(review): this pool manager is constructed but never passed to the
# Minio client above — confirm whether it should be supplied via
# Minio(..., http_client=http_client) for the retry/timeout policy to apply.
http_client = urllib3.PoolManager(
    num_pools=10,  # connection pool count
    maxsize=10,
    timeout=timeout,
    cert_reqs='CERT_REQUIRED',  # enforce TLS certificate verification
    retries=CustomRetry(
        total=5,
        backoff_factor=0.2,
        status_forcelist=[500, 502, 503, 504],
    ),
)
|
||||||
|
|
||||||
|
|
||||||
|
# Download an object from OSS and decode it as a PIL image.
def oss_get_image(oss_client, bucket, object_name):
    """Fetch *bucket*/*object_name* via *oss_client*.

    Returns a PIL Image, or None if the download or decode fails
    (the error is logged, never raised).
    """
    result = None
    try:
        response = oss_client.get_object(bucket_name=bucket, object_name=object_name)
        result = Image.open(BytesIO(response.read()))
    except Exception as e:
        logger.warning(f" | 获取图片出现异常 ######: {e}")
    return result
|
||||||
|
|
||||||
|
|
||||||
|
def oss_upload_image(oss_client, bucket, object_name, image_bytes):
    """Upload PNG bytes to *bucket*/*object_name* via *oss_client*.

    Returns the client's put result, or None if the upload fails
    (the error is logged, never raised).
    """
    upload_result = None
    try:
        upload_result = oss_client.put_object(
            bucket_name=bucket,
            object_name=object_name,
            data=io.BytesIO(image_bytes),
            length=len(image_bytes),
            content_type='image/png',
        )
    except Exception as e:
        logger.warning(f" | 上传图片出现异常 ######: {e}")
    return upload_result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
    # Smoke test: download one object, display it, and save a local copy.
    url = "fida-test/furniture/sketches/4449a66d-6267-43f7-86a2-1e42bd19ec61.png"
    # BUG FIX: the old call passed data_type=read_type, but oss_get_image
    # accepts no such parameter — it raised TypeError before any request.
    bucket, _, object_name = url.partition('/')
    img = oss_get_image(oss_client=minio_client, bucket=bucket, object_name=object_name)
    if img is None:
        # oss_get_image swallows errors and returns None; guard the demo
        # instead of crashing on img.show().
        print("failed to fetch image:", url)
    else:
        img.show()
        img.save("result.png")
|
||||||
Reference in New Issue
Block a user