From 9b940c9cf8466eb1e4e69e37de0f1aeb67eecd4c Mon Sep 17 00:00:00 2001 From: zhouchengrong Date: Mon, 23 Sep 2024 10:47:34 +0800 Subject: [PATCH] =?UTF-8?q?feat=20=20api=20=E8=B0=83=E7=94=A8=E6=AC=A1?= =?UTF-8?q?=E6=95=B0=E8=AE=B0=E5=BD=95=20fix?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/core/record_api_count.py | 44 ++++++++++++++++++++++++++++++++++++ app/main.py | 2 ++ 2 files changed, 46 insertions(+) create mode 100644 app/core/record_api_count.py diff --git a/app/core/record_api_count.py b/app/core/record_api_count.py new file mode 100644 index 0000000..c93a642 --- /dev/null +++ b/app/core/record_api_count.py @@ -0,0 +1,44 @@ +from fastapi import Request +from sqlalchemy import Column, Integer, String +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +# 创建数据库引擎 +DATABASE_URL = "sqlite:///./api_count.db" +engine = create_engine(DATABASE_URL) + +# 创建数据库会话 +SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) + +# 创建数据库模型基类 +Base = declarative_base() + + +# 定义存储调用次数的数据库模型 +class CallCount(Base): + __tablename__ = "call_count" + id = Column(Integer, primary_key=True, index=True) + service_name = Column(String, nullable=False) + call_count = Column(Integer, default=0) + + +# 创建数据库表(如果不存在) +Base.metadata.create_all(bind=engine) + + +# 定义中间件函数,用于记录接口调用次数 + +def count_api_calls(request: Request, call_next): + db = SessionLocal() + service_name = request.url.path + call_record = db.query(CallCount).filter_by(service_name=service_name).first() + if call_record is None: + call_record = CallCount(service_name=service_name, call_count=1) + db.add(call_record) + else: + call_record.call_count += 1 + db.commit() + db.refresh(call_record) + response = call_next(request) + return response diff --git a/app/main.py b/app/main.py index b085d7d..95c666a 100644 --- a/app/main.py +++ b/app/main.py @@ -8,6 +8,7 @@ from fastapi import FastAPI from app.api.api_route import router from app.core.config import settings +from app.core.record_api_count import count_api_calls from app.schemas.response_template import ResponseModel from logging_env import LOGGER_CONFIG_DICT @@ -34,6 +35,7 @@ def get_application() -> FastAPI: allow_methods=["*"], allow_headers=["*"], ) + application.middleware("http")(count_api_calls) application.include_router(router=router, prefix=settings.API_PREFIX) return application