commit bf06c7c120d5cf10f5af20227e5cdea89b0ab352
Author: zhouchengrong
Date:   Mon Mar 11 10:29:58 2024 +0800

    Initial commit

diff --git a/app/__init__.py b/app/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/api/__init__.py b/app/api/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/api/api_outfit_matcher.py b/app/api/api_outfit_matcher.py
new file mode 100644
index 0000000..c9b4ca0
--- /dev/null
+++ b/app/api/api_outfit_matcher.py
@@ -0,0 +1,16 @@
+import logging
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+
+from app.service.outfit_matcher_hon.service import OutfitMatcherHon
+
+logger = logging.getLogger()
+router = APIRouter()
+
+class Item(BaseModel):
+    # Request body: the outfits payload forwarded to OutfitMatcherHon (see service.preprocess)
+    outfits: list
+
+@router.post("")
+def outfit_matcher_hon(item: Item):
+    service = OutfitMatcherHon(item.outfits)
+    logger.info("test")
+    return {"message": "ok"}
diff --git a/app/api/api_route.py b/app/api/api_route.py
new file mode 100644
index 0000000..02cf9ca
--- /dev/null
+++ b/app/api/api_route.py
@@ -0,0 +1,7 @@
+from fastapi import APIRouter
+
+from app.api import api_test
+
+router = APIRouter()
+
+router.include_router(api_test.router, tags=["test"], prefix="/test")
diff --git a/app/api/api_test.py b/app/api/api_test.py
new file mode 100644
index 0000000..b47dca5
--- /dev/null
+++ b/app/api/api_test.py
@@ -0,0 +1,12 @@
+import logging
+
+from fastapi import APIRouter
+
+logger = logging.getLogger()
+router = APIRouter()
+
+
+@router.get("")
+def test():
+    logger.info("test")
+    return {"message": "ok"}
diff --git a/app/core/__init__.py b/app/core/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/core/config.py b/app/core/config.py
new file mode 100644
index 0000000..ea20a1c
--- /dev/null
+++ b/app/core/config.py
@@ -0,0 +1,20 @@
+import os
+from dotenv import load_dotenv
+from pydantic import BaseSettings
+
+BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
+load_dotenv(os.path.join(BASE_DIR, '.env'))
+
+
+class Settings(BaseSettings):
+    PROJECT_NAME: str = os.getenv('PROJECT_NAME', 'FASTAPI BASE')
+    SECRET_KEY: str = os.getenv('SECRET_KEY', '')
+    API_PREFIX: str = ''
+    BACKEND_CORS_ORIGINS: list = ['*']
+    DATABASE_URL: str = os.getenv('SQL_DATABASE_URL', '')
+    ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 60 * 24 * 7  # Token expires after 7 days
+    SECURITY_ALGORITHM: str = 'HS256'
+    LOGGING_CONFIG_FILE: str = os.path.join(BASE_DIR, 'logging_env.py')
+
+
+settings = Settings()
diff --git a/app/logs/debug.log b/app/logs/debug.log
new file mode 100644
index 0000000..ae3fa54
--- /dev/null
+++ b/app/logs/debug.log
@@ -0,0 +1,2 @@
+2024-03-11 10:10:54,038 api_test.py [line:11] INFO test
+2024-03-11 10:10:55,431 api_test.py [line:11] INFO test
diff --git a/app/logs/errors.log b/app/logs/errors.log
new file mode 100644
index 0000000..e69de29
diff --git a/app/logs/info.log b/app/logs/info.log
new file mode 100644
index 0000000..ae3fa54
--- /dev/null
+++ b/app/logs/info.log
@@ -0,0 +1,2 @@
+2024-03-11 10:10:54,038 api_test.py [line:11] INFO test
+2024-03-11 10:10:55,431 api_test.py [line:11] INFO test
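With API_PREFIX left empty in app/core/config.py, the only route registered through api_route.py is the test endpoint from api_test.py, mounted at /test. A minimal sketch of exercising it once the service is running (assuming the uvicorn settings in app/main.py below, i.e. port 8000 on the local machine; the host and URL here are illustrative):

    import requests

    # Assumes the app is served locally on port 8000, as configured in app/main.py
    resp = requests.get("http://localhost:8000/test")
    print(resp.status_code)  # 200
    print(resp.json())       # expected: {"message": "ok"}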
docs_url="/docs", redoc_url='/re-docs', + openapi_url=f"{settings.API_PREFIX}/openapi.json", + description=''' + Base frame with FastAPI micro framework + Postgresql + - Login/Register with JWT + - Permission + - CRUD User + - Unit testing with Pytest + - Dockerize + ''' + ) + application.add_middleware( + CORSMiddleware, + allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + application.include_router(router=router, prefix=settings.API_PREFIX) + return application + + +app = get_application() +if __name__ == '__main__': + uvicorn.run(app, host="0.0.0.0", port=8000) diff --git a/app/schemas/sche_outfit_matcher_hon.py b/app/schemas/sche_outfit_matcher_hon.py new file mode 100644 index 0000000..1e88aff --- /dev/null +++ b/app/schemas/sche_outfit_matcher_hon.py @@ -0,0 +1 @@ +class \ No newline at end of file diff --git a/app/service/outfit_matcher_hon/__init__.py b/app/service/outfit_matcher_hon/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/service/outfit_matcher_hon/service.py b/app/service/outfit_matcher_hon/service.py new file mode 100644 index 0000000..85ffffc --- /dev/null +++ b/app/service/outfit_matcher_hon/service.py @@ -0,0 +1,132 @@ +import torch +import torch.nn.functional as F +import tritonclient.http as httpclient +import requests +import cv2 +import numpy as np +from PIL import Image +from foco import extract_main_colors + + +class OutfitMatcherHon: + def __init__(self, outfits): + self.outfits = outfits + self.tritonclient = httpclient.InferenceServerClient(url="localhost:8000") + + @staticmethod + def imnormalize(img, mean, std, to_rgb=True): + """Normalize an image with mean and std. + + Args: + img (ndarray): Image to be normalized. + mean (ndarray): The mean to be used for normalize. + std (ndarray): The std to be used for normalize. + to_rgb (bool): Whether to convert to rgb. + + Returns: + ndarray: The normalized image. 
+ """ + img = img.copy().astype(np.float32) + assert img.dtype != np.uint8 + mean = np.float64(mean.reshape(1, -1)) + stdinv = 1 / np.float64(std.reshape(1, -1)) + if to_rgb: + cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img) # inplace + cv2.subtract(img, mean, img) # inplace + cv2.multiply(img, stdinv, img) # inplace + return img + + @staticmethod + def load_image(img_path): + if 'http' in img_path: + file = requests.get(img_path) + image = cv2.imdecode(np.fromstring(file.content, np.uint8), 1) + image = Image.fromarray(image.astype('uint8'), 'RGB') + else: + image = Image.open(img_path).convert('RGB') + return np.array(image) + + @staticmethod + def resize_image(img): + """ + Args: + img: ndarray (height, width, channel) + """ + resized_img = cv2.resize(img, (224, 224), dst=None, interpolation=1) + return resized_img + + @staticmethod + def pad_array(input_value): + """pad List of Array into same batch size + + Args: + input_value: List of numpy arrary need to be padded + + Returns: + Tensor: [batch_dim, max_dim, original_tensor_size] + """ + max_dim = max([len(x) for x in input_value]) + mask = np.zeros((len(input_value), max_dim), dtype=np.float32) + + # Pad each array + padded_arrays = [] + for i, array in enumerate(input_value): + # Compute padding amount along the pad dimension + pad_dim = max_dim - array.shape[0] + consistent_shape = array.shape[1:] + pad_widths = [(0, pad_dim)] + [(0, 0)] * len(consistent_shape) + padded_array = np.pad(array, pad_widths, mode='constant', constant_values=0) + padded_arrays.append(padded_array) + + mask[i, array.shape[0]:] = float("-inf") + + # Stack the padded arrays and change the dimension + batched_arrays = np.stack(padded_arrays, axis=0) + return batched_arrays, mask + + def preprocess(self): + outfit_images = [] + outfit_colors = [] + for outfit in self.outfits: + images = [] + colors = [] + for item in outfit["items"]: + image = self.load_image(item["image_path"]) + image = self.resize_image(image) + normalized_image = self.imnormalize(image, + mean=np.array([208.32996145, 201.28227452, 198.47047691], dtype=np.float32), + std=np.array([75.48939648, 80.47423057, 82.21144189], dtype=np.float32)) + images.append(normalized_image.transpose(2, 0, 1)) + color = extract_main_colors(image) + colors.append(color) + images = np.stack(images, axis=0) + outfit_images.append(images) # List[(items, 3, 224, 224)] + colors = np.stack(colors, axis=0) + outfit_colors.append(colors) + outfit_images, mask = self.pad_array(outfit_images) + outfit_colors, _ = self.pad_array(outfit_colors) + return outfit_images, outfit_colors, mask + + def get_result(self, outfits): + # start = time.time() + image, color, mask = self.preprocess() + # print(start - time.time()) + # transformed_img = image.astype(np.float32) + # 输入集 + inputs = [ + httpclient.InferInput("input__0", image.shape, datatype="FP32"), + httpclient.InferInput("input__1", color.shape, datatype="FP32"), + httpclient.InferInput("input__2", mask.shape, datatype="FP32"), + ] + inputs[0].set_data_from_numpy(image.astype(np.float32), binary_data=True) + inputs[1].set_data_from_numpy(color.astype(np.float32), binary_data=True) + inputs[2].set_data_from_numpy(mask.astype(np.float32), binary_data=True) + # 输出集 + outputs = [ + httpclient.InferRequestedOutput("output__0", binary_data=True), + ] + results = self.tritonclient.infer(model_name="outfit_matcher_hon", inputs=inputs, outputs=outputs) + # 推理 + # 取结果 + inference_output1 = torch.from_numpy(results.as_numpy("output__0")) + return inference_output1 # Shape (N, 1) diff 
diff --git a/logging_env.py b/logging_env.py
new file mode 100644
index 0000000..c6327a7
--- /dev/null
+++ b/logging_env.py
@@ -0,0 +1,49 @@
+LOGGER_CONFIG_DICT = {
+    "version": 1,
+    "disable_existing_loggers": False,
+    "formatters": {
+        "simple": {"format": "%(asctime)s %(filename)s [line:%(lineno)d] %(levelname)s %(message)s"}
+    },
+    "handlers": {
+        "console": {
+            "class": "logging.StreamHandler",
+            "level": "DEBUG",
+            "formatter": "simple",
+            "stream": "ext://sys.stdout",
+        },
+        "info_file_handler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "INFO",
+            "formatter": "simple",
+            "filename": "logs/info.log",
+            "maxBytes": 10485760,
+            "backupCount": 50,
+            "encoding": "utf8",
+        },
+        "error_file_handler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "ERROR",
+            "formatter": "simple",
+            "filename": "logs/errors.log",
+            "maxBytes": 10485760,
+            "backupCount": 20,
+            "encoding": "utf8",
+        },
+        "debug_file_handler": {
+            "class": "logging.handlers.RotatingFileHandler",
+            "level": "DEBUG",
+            "formatter": "simple",
+            "filename": "logs/debug.log",
+            "maxBytes": 10485760,
+            "backupCount": 50,
+            "encoding": "utf8",
+        },
+    },
+    "loggers": {
+        "my_module": {"level": "INFO", "handlers": ["console"], "propagate": False}
+    },
+    "root": {
+        "level": "INFO",
+        "handlers": ["error_file_handler", "info_file_handler", "debug_file_handler", "console"],
+    },
+}
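Taken together with app/main.py, this dictionary is applied with logging.config.dictConfig at import time, so module-level loggers inherit the root handlers. A minimal sketch of how log lines like the ones committed under app/logs/ are produced, assuming the relative logs/ paths in the handlers resolve from the current working directory:

    import logging
    import logging.config

    from logging_env import LOGGER_CONFIG_DICT

    logging.config.dictConfig(LOGGER_CONFIG_DICT)
    logger = logging.getLogger()  # root logger, as used in api_test.py
    logger.info("test")
    # Written to the console, logs/info.log and logs/debug.log in the "simple" format, e.g.:
    # 2024-03-11 10:10:54,038 api_test.py [line:11] INFO test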