This commit is contained in:
zhouchengrong
2024-03-20 11:44:15 +08:00
commit 69132570aa
29 changed files with 815 additions and 0 deletions

0
app/__init__.py Normal file
View File

0
app/api/__init__.py Normal file
View File

9
app/api/api_route.py Normal file
View File

@@ -0,0 +1,9 @@
from fastapi import APIRouter
from app.api import api_test
from app.api import api_super_resolution
router = APIRouter()
router.include_router(api_test.router, tags=["test"], prefix="/test")
router.include_router(api_super_resolution.router, tags=["api_super_resolution"], prefix="/api")

View File

@@ -0,0 +1,14 @@
from fastapi import APIRouter
from app.schemas.super_resolution import SuperResolutionModel
from app.service.super_resolution.service import SuperResolution
router = APIRouter()
@router.post("super_resolution")
def super_resolution(request_item: SuperResolutionModel):
service = SuperResolution()
sr_result_url = service.sr_result(request_item.sr_image_url, request_item.sr_xn)
response = {"sr_result_url": sr_result_url}
return {"code": 200, "message": "ok", "data": response}

14
app/api/api_test.py Normal file
View File

@@ -0,0 +1,14 @@
import logging
from fastapi import APIRouter
logger = logging.getLogger()
router = APIRouter()
@router.get("")
def test():
logger.info("test")
return {"message": "ok"}

0
app/core/__init__.py Normal file
View File

53
app/core/config.py Normal file
View File

@@ -0,0 +1,53 @@
import os
from dotenv import load_dotenv
from pydantic import BaseSettings
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
load_dotenv(os.path.join(BASE_DIR, '.env'))
class Settings(BaseSettings):
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
SECRET_KEY = os.getenv('SECRET_KEY', '')
API_PREFIX = ''
BACKEND_CORS_ORIGINS = ['*']
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
SECURITY_ALGORITHM = 'HS256'
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
settings = Settings()
ckpt = 'service/super_resolution_ccsr/weights/real-world_ccsr.ckpt'
config = 'service/super_resolution_ccsr/configs/model/ccsr_stage2.yaml'
steps = 45
sr_scale = 4
repeat_times = 1
tiled = False
tile_size = 512
tile_stride = 256
color_fix_type = "adain"
t_max = 0.6667
t_min = 0.3333
show_lq = False
skip_if_exist = False
seed = 233
device = "cuda"
tile_diffusion = False #
tile_diffusion_size = 512
tile_diffusion_stride = 256
tile_vae = True
vae_decoder_tile_size = 224
vae_encoder_tile_size = 1024
strength = 1
# minio 配置
sr_bucket = "test"
MINIO_IP = "www.minio.aida.com.hk"
MINIO_PORT = 9000
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
MINIO_SECURE = True
# input = 'preprocess_img/input_x2' # 这个值需要被函数参数覆盖
# output = '/path/to/output' # 这个值将被函数参数覆盖
LOGS_PATH = "logs/errors.log"

1
app/logs/debug.log Normal file
View File

@@ -0,0 +1 @@
2024-03-12 13:03:10,034 main.py [line:43] INFO test ok

28
app/logs/errors.log Normal file
View File

@@ -0,0 +1,28 @@
2024-03-20 11:41:28,641 decorator.py [line:11] INFO function【read_image】,runtime【2.3682610988616943】s
2024-03-20 11:41:28,641 decorator.py [line:11] INFO function【read_image】,runtime【2.3682610988616943】s
2024-03-20 11:41:28,978 decorator.py [line:11] INFO function【sr_result】,runtime【2.7045106887817383】s
2024-03-20 11:41:28,978 decorator.py [line:11] INFO function【sr_result】,runtime【2.7045106887817383】s
2024-03-20 11:41:40,123 decorator.py [line:11] INFO function【read_image】,runtime【6.277707099914551】s
2024-03-20 11:41:40,123 decorator.py [line:11] INFO function【read_image】,runtime【6.277707099914551】s
2024-03-20 11:41:40,439 decorator.py [line:11] INFO function【sr_result】,runtime【6.594382047653198】s
2024-03-20 11:41:40,439 decorator.py [line:11] INFO function【sr_result】,runtime【6.594382047653198】s
2024-03-20 11:41:41,338 decorator.py [line:11] INFO function【read_image】,runtime【0.16055655479431152】s
2024-03-20 11:41:41,338 decorator.py [line:11] INFO function【read_image】,runtime【0.16055655479431152】s
2024-03-20 11:41:41,643 decorator.py [line:11] INFO function【sr_result】,runtime【0.46419310569763184】s
2024-03-20 11:41:41,643 decorator.py [line:11] INFO function【sr_result】,runtime【0.46419310569763184】s
2024-03-20 11:41:42,625 decorator.py [line:11] INFO function【read_image】,runtime【0.15813016891479492】s
2024-03-20 11:41:42,625 decorator.py [line:11] INFO function【read_image】,runtime【0.15813016891479492】s
2024-03-20 11:41:42,929 decorator.py [line:11] INFO function【sr_result】,runtime【0.4632871150970459】s
2024-03-20 11:41:42,929 decorator.py [line:11] INFO function【sr_result】,runtime【0.4632871150970459】s
2024-03-20 11:41:48,216 decorator.py [line:11] INFO function【read_image】,runtime【0.1381824016571045】s
2024-03-20 11:41:48,216 decorator.py [line:11] INFO function【read_image】,runtime【0.1381824016571045】s
2024-03-20 11:41:48,537 decorator.py [line:11] INFO function【sr_result】,runtime【0.4588344097137451】s
2024-03-20 11:41:48,537 decorator.py [line:11] INFO function【sr_result】,runtime【0.4588344097137451】s
2024-03-20 11:42:48,128 decorator.py [line:11] INFO function【read_image】,runtime【0.15878772735595703】s
2024-03-20 11:42:48,128 decorator.py [line:11] INFO function【read_image】,runtime【0.15878772735595703】s
2024-03-20 11:42:48,463 decorator.py [line:11] INFO function【sr_result】,runtime【0.49385905265808105】s
2024-03-20 11:42:48,463 decorator.py [line:11] INFO function【sr_result】,runtime【0.49385905265808105】s
2024-03-20 11:43:24,220 decorator.py [line:11] INFO function【read_image】,runtime【0.16216182708740234】s
2024-03-20 11:43:24,220 decorator.py [line:11] INFO function【read_image】,runtime【0.16216182708740234】s
2024-03-20 11:43:24,563 decorator.py [line:11] INFO function【sr_result】,runtime【0.5048878192901611】s
2024-03-20 11:43:24,563 decorator.py [line:11] INFO function【sr_result】,runtime【0.5048878192901611】s

1
app/logs/info.log Normal file
View File

@@ -0,0 +1 @@
2024-03-12 13:03:10,034 main.py [line:43] INFO test ok

38
app/main.py Normal file
View File

@@ -0,0 +1,38 @@
import logging.config
import uvicorn
from fastapi import FastAPI
from app.api.api_route import router
from app.core.config import settings
from logging_env import LOGGER_CONFIG_DICT
logging.config.dictConfig(LOGGER_CONFIG_DICT)
from starlette.middleware.cors import CORSMiddleware
def get_application() -> FastAPI:
application = FastAPI(
title=settings.PROJECT_NAME, docs_url="/docs", redoc_url='/re-docs',
openapi_url=f"{settings.API_PREFIX}/openapi.json",
description='''
Base frame with FastAPI
- Super Resolution API
'''
)
application.add_middleware(
CORSMiddleware,
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
application.include_router(router=router, prefix=settings.API_PREFIX)
return application
app = get_application()
if __name__ == '__main__':
uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class SuperResolutionModel(BaseModel):
sr_image_url: str
sr_xn: int

View File

@@ -0,0 +1,67 @@
import io
import logging
from io import BytesIO
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
from minio import Minio
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT
from app.service.utils.decorator import RunTime
from app.service.utils.generate_uuid import generate_uuid
logger = logging.getLogger()
class SuperResolution:
def __init__(self):
self.triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
self.minio_client = Minio(
f"{MINIO_IP}:{MINIO_PORT}",
access_key=MINIO_ACCESS,
secret_key=MINIO_SECRET,
secure=MINIO_SECURE)
@RunTime
def read_image(self, image_url):
image_data = self.minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
return img
@RunTime
def sr_result(self, image_url, sr_xn):
sample = self.read_image(image_url)
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
httpclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample, binary_data=True)
results = self.triton_client.infer(model_name="super_resolution", inputs=inputs)
sr_output = torch.from_numpy(results.as_numpy(f"output"))
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
output_url = self.upload_img_sr(output, generate_uuid())
return output_url
def upload_img_sr(self, image, object_name):
try:
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
image_url = f"test/{self.minio_client.put_object(f'test', f'{object_name}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
return image_url
except Exception as e:
logger.warning(f"upload_png_mask runtime exception : {e}")
if __name__ == '__main__':
service = SuperResolution()
result_url = service.sr_result("test/128_image/11.png", 4)
print(result_url)

View File

@@ -0,0 +1,28 @@
import time
import cv2
import numpy as np
import torch
import tritonclient.http as httpclient
from PIL import Image
triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
sample = cv2.imread("comic2.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
inputs = [
httpclient.InferInput("input", sample.shape, datatype="FP32")
]
inputs[0].set_data_from_numpy(sample, binary_data=True)
start_time = time.time()
results = triton_client.infer(model_name="super_resolution", inputs=inputs)
print(time.time() - start_time)
sr_output = torch.from_numpy(results.as_numpy(f"output"))
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8)
# cv2.imshow("", output)
# cv2.waitKey(0)
cv2.imwrite("comic3.png", output)

View File

@@ -0,0 +1,14 @@
import time
import logging
def RunTime(func):
def wrapper(*args, **kwargs):
t1 = time.time()
res = func(*args, **kwargs)
t2 = time.time()
if t2 - t1 > 0.05:
logging.info(f"function{func.__name__}】,runtime{str(t2 - t1)}】s")
return res
return wrapper

View File

@@ -0,0 +1,10 @@
import threading
import uuid
id_lock = threading.Lock()
def generate_uuid():
with id_lock:
unique_id = str(uuid.uuid1())
return unique_id