1
This commit is contained in:
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
9
app/api/api_route.py
Normal file
9
app/api/api_route.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.api import api_test
|
||||
from app.api import api_super_resolution
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
router.include_router(api_test.router, tags=["test"], prefix="/test")
|
||||
router.include_router(api_super_resolution.router, tags=["api_super_resolution"], prefix="/api")
|
||||
14
app/api/api_super_resolution.py
Normal file
14
app/api/api_super_resolution.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.schemas.super_resolution import SuperResolutionModel
|
||||
from app.service.super_resolution.service import SuperResolution
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("super_resolution")
|
||||
def super_resolution(request_item: SuperResolutionModel):
|
||||
service = SuperResolution()
|
||||
sr_result_url = service.sr_result(request_item.sr_image_url, request_item.sr_xn)
|
||||
response = {"sr_result_url": sr_result_url}
|
||||
return {"code": 200, "message": "ok", "data": response}
|
||||
14
app/api/api_test.py
Normal file
14
app/api/api_test.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import logging
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
logger = logging.getLogger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
|
||||
|
||||
@router.get("")
|
||||
def test():
|
||||
logger.info("test")
|
||||
return {"message": "ok"}
|
||||
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
53
app/core/config.py
Normal file
53
app/core/config.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseSettings
|
||||
|
||||
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
|
||||
load_dotenv(os.path.join(BASE_DIR, '.env'))
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
PROJECT_NAME = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
|
||||
SECRET_KEY = os.getenv('SECRET_KEY', '')
|
||||
API_PREFIX = ''
|
||||
BACKEND_CORS_ORIGINS = ['*']
|
||||
DATABASE_URL = os.getenv('SQL_DATABASE_URL', '')
|
||||
ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7 # Token expired after 7 days
|
||||
SECURITY_ALGORITHM = 'HS256'
|
||||
LOGGING_CONFIG_FILE = os.path.join(BASE_DIR, 'logging_env.py')
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
||||
ckpt = 'service/super_resolution_ccsr/weights/real-world_ccsr.ckpt'
|
||||
config = 'service/super_resolution_ccsr/configs/model/ccsr_stage2.yaml'
|
||||
steps = 45
|
||||
sr_scale = 4
|
||||
repeat_times = 1
|
||||
tiled = False
|
||||
tile_size = 512
|
||||
tile_stride = 256
|
||||
color_fix_type = "adain"
|
||||
t_max = 0.6667
|
||||
t_min = 0.3333
|
||||
show_lq = False
|
||||
skip_if_exist = False
|
||||
seed = 233
|
||||
device = "cuda"
|
||||
tile_diffusion = False #
|
||||
tile_diffusion_size = 512
|
||||
tile_diffusion_stride = 256
|
||||
tile_vae = True
|
||||
vae_decoder_tile_size = 224
|
||||
vae_encoder_tile_size = 1024
|
||||
strength = 1
|
||||
# minio 配置
|
||||
sr_bucket = "test"
|
||||
MINIO_IP = "www.minio.aida.com.hk"
|
||||
MINIO_PORT = 9000
|
||||
MINIO_ACCESS = 'vXKFLSJkYeEq2DrSZvkB'
|
||||
MINIO_SECRET = 'uKTZT3x7C43WvPN9QTc99DiRkwddWZrG9Uh3JVlR'
|
||||
MINIO_SECURE = True
|
||||
# input = 'preprocess_img/input_x2' # 这个值需要被函数参数覆盖
|
||||
# output = '/path/to/output' # 这个值将被函数参数覆盖
|
||||
LOGS_PATH = "logs/errors.log"
|
||||
1
app/logs/debug.log
Normal file
1
app/logs/debug.log
Normal file
@@ -0,0 +1 @@
|
||||
2024-03-12 13:03:10,034 main.py [line:43] INFO test ok
|
||||
28
app/logs/errors.log
Normal file
28
app/logs/errors.log
Normal file
@@ -0,0 +1,28 @@
|
||||
2024-03-20 11:41:28,641 decorator.py [line:11] INFO function:【read_image】,runtime:【2.3682610988616943】s
|
||||
2024-03-20 11:41:28,641 decorator.py [line:11] INFO function:【read_image】,runtime:【2.3682610988616943】s
|
||||
2024-03-20 11:41:28,978 decorator.py [line:11] INFO function:【sr_result】,runtime:【2.7045106887817383】s
|
||||
2024-03-20 11:41:28,978 decorator.py [line:11] INFO function:【sr_result】,runtime:【2.7045106887817383】s
|
||||
2024-03-20 11:41:40,123 decorator.py [line:11] INFO function:【read_image】,runtime:【6.277707099914551】s
|
||||
2024-03-20 11:41:40,123 decorator.py [line:11] INFO function:【read_image】,runtime:【6.277707099914551】s
|
||||
2024-03-20 11:41:40,439 decorator.py [line:11] INFO function:【sr_result】,runtime:【6.594382047653198】s
|
||||
2024-03-20 11:41:40,439 decorator.py [line:11] INFO function:【sr_result】,runtime:【6.594382047653198】s
|
||||
2024-03-20 11:41:41,338 decorator.py [line:11] INFO function:【read_image】,runtime:【0.16055655479431152】s
|
||||
2024-03-20 11:41:41,338 decorator.py [line:11] INFO function:【read_image】,runtime:【0.16055655479431152】s
|
||||
2024-03-20 11:41:41,643 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.46419310569763184】s
|
||||
2024-03-20 11:41:41,643 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.46419310569763184】s
|
||||
2024-03-20 11:41:42,625 decorator.py [line:11] INFO function:【read_image】,runtime:【0.15813016891479492】s
|
||||
2024-03-20 11:41:42,625 decorator.py [line:11] INFO function:【read_image】,runtime:【0.15813016891479492】s
|
||||
2024-03-20 11:41:42,929 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.4632871150970459】s
|
||||
2024-03-20 11:41:42,929 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.4632871150970459】s
|
||||
2024-03-20 11:41:48,216 decorator.py [line:11] INFO function:【read_image】,runtime:【0.1381824016571045】s
|
||||
2024-03-20 11:41:48,216 decorator.py [line:11] INFO function:【read_image】,runtime:【0.1381824016571045】s
|
||||
2024-03-20 11:41:48,537 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.4588344097137451】s
|
||||
2024-03-20 11:41:48,537 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.4588344097137451】s
|
||||
2024-03-20 11:42:48,128 decorator.py [line:11] INFO function:【read_image】,runtime:【0.15878772735595703】s
|
||||
2024-03-20 11:42:48,128 decorator.py [line:11] INFO function:【read_image】,runtime:【0.15878772735595703】s
|
||||
2024-03-20 11:42:48,463 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.49385905265808105】s
|
||||
2024-03-20 11:42:48,463 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.49385905265808105】s
|
||||
2024-03-20 11:43:24,220 decorator.py [line:11] INFO function:【read_image】,runtime:【0.16216182708740234】s
|
||||
2024-03-20 11:43:24,220 decorator.py [line:11] INFO function:【read_image】,runtime:【0.16216182708740234】s
|
||||
2024-03-20 11:43:24,563 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.5048878192901611】s
|
||||
2024-03-20 11:43:24,563 decorator.py [line:11] INFO function:【sr_result】,runtime:【0.5048878192901611】s
|
||||
1
app/logs/info.log
Normal file
1
app/logs/info.log
Normal file
@@ -0,0 +1 @@
|
||||
2024-03-12 13:03:10,034 main.py [line:43] INFO test ok
|
||||
38
app/main.py
Normal file
38
app/main.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import logging.config
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from app.api.api_route import router
|
||||
from app.core.config import settings
|
||||
from logging_env import LOGGER_CONFIG_DICT
|
||||
|
||||
logging.config.dictConfig(LOGGER_CONFIG_DICT)
|
||||
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(
|
||||
title=settings.PROJECT_NAME, docs_url="/docs", redoc_url='/re-docs',
|
||||
openapi_url=f"{settings.API_PREFIX}/openapi.json",
|
||||
description='''
|
||||
Base frame with FastAPI
|
||||
- Super Resolution API
|
||||
|
||||
'''
|
||||
)
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
application.include_router(router=router, prefix=settings.API_PREFIX)
|
||||
return application
|
||||
|
||||
|
||||
app = get_application()
|
||||
if __name__ == '__main__':
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
6
app/schemas/super_resolution.py
Normal file
6
app/schemas/super_resolution.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SuperResolutionModel(BaseModel):
|
||||
sr_image_url: str
|
||||
sr_xn: int
|
||||
67
app/service/super_resolution/service.py
Normal file
67
app/service/super_resolution/service.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import io
|
||||
import logging
|
||||
from io import BytesIO
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
from minio import Minio
|
||||
|
||||
from app.core.config import MINIO_IP, MINIO_ACCESS, MINIO_SECRET, MINIO_SECURE, MINIO_PORT
|
||||
|
||||
from app.service.utils.decorator import RunTime
|
||||
from app.service.utils.generate_uuid import generate_uuid
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class SuperResolution:
|
||||
def __init__(self):
|
||||
self.triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
|
||||
self.minio_client = Minio(
|
||||
f"{MINIO_IP}:{MINIO_PORT}",
|
||||
access_key=MINIO_ACCESS,
|
||||
secret_key=MINIO_SECRET,
|
||||
secure=MINIO_SECURE)
|
||||
|
||||
@RunTime
|
||||
def read_image(self, image_url):
|
||||
image_data = self.minio_client.get_object(image_url.split("/", 1)[0], image_url.split("/", 1)[1])
|
||||
img = np.frombuffer(image_data.data, np.uint8) # 转成8位无符号整型
|
||||
img = cv2.imdecode(img, cv2.IMREAD_COLOR).astype(np.float32) / 255. # 解码
|
||||
return img
|
||||
|
||||
@RunTime
|
||||
def sr_result(self, image_url, sr_xn):
|
||||
sample = self.read_image(image_url)
|
||||
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
||||
inputs = [
|
||||
httpclient.InferInput("input", sample.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(sample, binary_data=True)
|
||||
results = self.triton_client.infer(model_name="super_resolution", inputs=inputs)
|
||||
|
||||
sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
||||
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
if output.ndim == 3:
|
||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||
output = (output * 255.0).round().astype(np.uint8)
|
||||
output_url = self.upload_img_sr(output, generate_uuid())
|
||||
return output_url
|
||||
|
||||
def upload_img_sr(self, image, object_name):
|
||||
try:
|
||||
image_bytes = cv2.imencode('.jpg', image)[1].tobytes()
|
||||
image_url = f"test/{self.minio_client.put_object(f'test', f'{object_name}.jpg', io.BytesIO(image_bytes), len(image_bytes), content_type='image/png').object_name}"
|
||||
|
||||
return image_url
|
||||
except Exception as e:
|
||||
logger.warning(f"upload_png_mask runtime exception : {e}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
service = SuperResolution()
|
||||
result_url = service.sr_result("test/128_image/11.png", 4)
|
||||
print(result_url)
|
||||
28
app/service/super_resolution/test.py
Normal file
28
app/service/super_resolution/test.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import tritonclient.http as httpclient
|
||||
from PIL import Image
|
||||
|
||||
triton_client = httpclient.InferenceServerClient(url=f"10.1.1.150:7000")
|
||||
|
||||
sample = cv2.imread("comic2.png", cv2.IMREAD_COLOR).astype(np.float32) / 255.
|
||||
sample = np.transpose(sample if sample.shape[2] == 1 else sample[:, :, [2, 1, 0]], (2, 0, 1))
|
||||
sample = torch.from_numpy(sample).float().unsqueeze(0).numpy()
|
||||
inputs = [
|
||||
httpclient.InferInput("input", sample.shape, datatype="FP32")
|
||||
]
|
||||
inputs[0].set_data_from_numpy(sample, binary_data=True)
|
||||
start_time = time.time()
|
||||
results = triton_client.infer(model_name="super_resolution", inputs=inputs)
|
||||
print(time.time() - start_time)
|
||||
sr_output = torch.from_numpy(results.as_numpy(f"output"))
|
||||
output = sr_output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
|
||||
if output.ndim == 3:
|
||||
output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0)) # CHW-RGB to HCW-BGR
|
||||
output = (output * 255.0).round().astype(np.uint8)
|
||||
# cv2.imshow("", output)
|
||||
# cv2.waitKey(0)
|
||||
cv2.imwrite("comic3.png", output)
|
||||
14
app/service/utils/decorator.py
Normal file
14
app/service/utils/decorator.py
Normal file
@@ -0,0 +1,14 @@
|
||||
import time
|
||||
import logging
|
||||
|
||||
|
||||
def RunTime(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
t1 = time.time()
|
||||
res = func(*args, **kwargs)
|
||||
t2 = time.time()
|
||||
if t2 - t1 > 0.05:
|
||||
logging.info(f"function:【{func.__name__}】,runtime:【{str(t2 - t1)}】s")
|
||||
return res
|
||||
|
||||
return wrapper
|
||||
10
app/service/utils/generate_uuid.py
Normal file
10
app/service/utils/generate_uuid.py
Normal file
@@ -0,0 +1,10 @@
|
||||
import threading
|
||||
import uuid
|
||||
|
||||
id_lock = threading.Lock()
|
||||
|
||||
|
||||
def generate_uuid():
|
||||
with id_lock:
|
||||
unique_id = str(uuid.uuid1())
|
||||
return unique_id
|
||||
Reference in New Issue
Block a user