【Python Web 开发精通】第14讲 | FastAPI异步编程:释放高性能的秘诀
·
环境声明
- Python 版本: Python 3.12+(最低 3.10;使用 asyncio.TaskGroup 的示例需要 3.11+)
- FastAPI 版本: FastAPI 0.115+
- Pydantic 版本: Pydantic 2.0+
- ASGI 服务器: Uvicorn 0.30+
- 开发工具: PyCharm 或 VS Code
- 操作系统: Windows/macOS/Linux(通用)
学习目标
学完本讲,你将能够:
- 理解 Python 异步编程的核心概念
- 掌握 FastAPI 中的异步路径操作
- 使用后台任务处理耗时操作
- 实现异步数据库操作
- 应用性能优化技巧提升应用性能
1. 异步编程基础
1.1 什么是异步编程
比喻理解:
想象你在餐厅点餐:
- 同步模式:你点完餐,站在柜台前等着,直到餐做好才离开(阻塞)
- 异步模式:你点完餐,拿到叫号器,先去座位等,餐好了叫号器通知你(非阻塞)
异步编程允许程序在等待 I/O 操作(如网络请求、数据库查询)时,去执行其他任务,而不是空等。
1.2 核心概念
| 概念 | 说明 | 类比 |
|---|---|---|
| async | 定义异步函数 | 标记这是一个"可以等待"的任务 |
| await | 等待异步操作完成 | 暂时交出控制权,等结果回来再继续 |
| asyncio | Python 异步标准库 | 餐厅经理,协调所有任务 |
| event loop | 事件循环 | 调度中心,决定哪个任务先执行 |
1.3 同步 vs 异步对比
import time
import asyncio
# ========== 同步版本 ==========
def sync_task(name: str, delay: float) -> str:
    """Run a blocking task and return its result string.

    Args:
        name: Label used in the progress messages and in the returned string.
        delay: Seconds to block in ``time.sleep`` (int or float).

    Returns:
        The string ``"<name> 的结果"``.
    """
    print(f"[{name}] 开始执行")
    time.sleep(delay)  # blocks the calling thread for the whole delay
    print(f"[{name}] 执行完成")
    return f"{name} 的结果"
def run_sync():
    """Run three 2-second blocking tasks back to back and print the elapsed time."""
    # perf_counter is monotonic, so the measurement is not affected by
    # wall-clock adjustments the way time.time() is.
    start = time.perf_counter()
    for label in ("任务A", "任务B", "任务C"):
        sync_task(label, 2)
    print(f"同步总耗时: {time.perf_counter() - start:.2f}秒")
    # Expected output (approximately, the sleeps run one after another):
    # 同步总耗时: 6.00秒
# ========== 异步版本 ==========
async def async_task(name: str, delay: float) -> str:
    """Run a non-blocking task and return its result string.

    Args:
        name: Label used in the progress messages and in the returned string.
        delay: Seconds to wait in ``asyncio.sleep`` (int or float).

    Returns:
        The string ``"<name> 的结果"``.
    """
    print(f"[{name}] 开始执行")
    # Awaiting hands control back to the event loop, so other tasks can run
    # while this one waits.
    await asyncio.sleep(delay)
    print(f"[{name}] 执行完成")
    return f"{name} 的结果"
async def run_async():
    """Run three 2-second async tasks concurrently and print the elapsed time."""
    # perf_counter is monotonic, so the measurement is not affected by
    # wall-clock adjustments the way time.time() is.
    start = time.perf_counter()
    # gather() schedules all three coroutines at once; their sleeps overlap.
    await asyncio.gather(
        async_task("任务A", 2),
        async_task("任务B", 2),
        async_task("任务C", 2),
    )
    print(f"异步总耗时: {time.perf_counter() - start:.2f}秒")
    # Expected output (approximately, the sleeps overlap):
    # 异步总耗时: 2.00秒
# 运行对比
if __name__ == "__main__":
print("=== 同步执行 ===")
run_sync()
print("\n=== 异步执行 ===")
asyncio.run(run_async())
2. FastAPI 中的异步路径操作
2.1 定义异步路由
from fastapi import FastAPI
import asyncio
app = FastAPI()
# 同步路由(FastAPI 会自动在线程池中运行)
@app.get("/sync")
def sync_endpoint():
"""同步端点"""
import time
time.sleep(1) # 模拟耗时操作
return {"message": "同步响应"}
# 异步路由(推荐用于 I/O 操作)
@app.get("/async")
async def async_endpoint():
"""异步端点"""
await asyncio.sleep(1) # 模拟异步 I/O
return {"message": "异步响应"}
# 混合使用
@app.get("/mixed/{item_id}")
async def mixed_endpoint(item_id: int):
"""混合端点示例"""
# 异步操作
await asyncio.sleep(0.5)
# CPU 密集型操作(应使用 run_in_threadpool)
from fastapi.concurrency import run_in_threadpool
result = await run_in_threadpool(cpu_intensive_task, item_id)
return {"item_id": item_id, "result": result}
def cpu_intensive_task(n: int) -> int:
    """Return the sum of ``i * i`` for ``i`` in ``range(n * 10000)``.

    Args:
        n: Work multiplier; the loop runs ``n * 10000`` iterations.

    Returns:
        The sum of squares. ``0`` when ``n <= 0``, because the range is empty.
    """
    return sum(i * i for i in range(n * 10000))
2.2 何时使用 async/await
| 场景 | 推荐方式 | 原因 |
|---|---|---|
| 数据库查询 | async | I/O 操作,使用异步数据库驱动 |
| HTTP 请求 | async | 网络 I/O,使用 aiohttp/httpx |
| 文件读写 | async | I/O 操作,使用 aiofiles |
| 纯 CPU 计算 | sync + run_in_threadpool | 避免阻塞事件循环 |
| 简单响应 | sync 或 async 均可 | 无 I/O 操作,差别不大 |
from fastapi import FastAPI
import httpx
import aiofiles
app = FastAPI()
# 异步 HTTP 请求
@app.get("/fetch-data")
async def fetch_data():
"""异步获取外部数据"""
async with httpx.AsyncClient() as client:
response = await client.get("https://api.github.com/users/python")
return response.json()
# 异步文件操作
@app.get("/read-file")
async def read_file():
"""异步读取文件"""
async with aiofiles.open("data.txt", mode="r") as f:
content = await f.read()
return {"content": content}
# CPU 密集型操作(正确使用方式)
from fastapi.concurrency import run_in_threadpool
@app.get("/process/{data}")
async def process_data(data: str):
"""处理数据"""
# 错误:直接在 async 函数中执行 CPU 密集型操作
# result = heavy_computation(data) # 会阻塞事件循环!
# 正确:在线程池中执行
result = await run_in_threadpool(heavy_computation, data)
return {"result": result}
def heavy_computation(data: str) -> str:
"""CPU 密集型计算"""
# 模拟复杂计算
result = data
for _ in range(1000000):
result = hash(result)
return str(result)
2.3 并发请求处理
from fastapi import FastAPI
import httpx
import asyncio
app = FastAPI()
@app.get("/concurrent-fetch")
async def concurrent_fetch():
"""并发获取多个 API 数据"""
urls = [
"https://api.github.com/users/python",
"https://api.github.com/users/google",
"https://api.github.com/users/microsoft"
]
# 方法1:使用 asyncio.gather(推荐)
async with httpx.AsyncClient() as client:
tasks = [client.get(url) for url in urls]
responses = await asyncio.gather(*tasks)
results = [r.json() for r in responses]
return {"results": results}
@app.get("/concurrent-with-limit")
async def concurrent_with_limit():
    """Fetch 100 URLs with at most 10 requests in flight at any time.

    Returns:
        ``{"count": <number of responses>}``.

    Raises:
        Whatever ``client.get`` / ``response.json`` raises for a failing URL;
        ``asyncio.gather`` propagates the first such exception.
    """
    urls = [f"https://api.example.com/data/{i}" for i in range(100)]
    # The semaphore caps how many requests run concurrently.
    semaphore = asyncio.Semaphore(10)

    # One client for the whole batch, so connections are pooled and reused
    # instead of building a fresh client (and connection pool) per URL.
    async with httpx.AsyncClient() as client:

        async def fetch_with_limit(url: str):
            async with semaphore:
                response = await client.get(url)
                return response.json()

        results = await asyncio.gather(*(fetch_with_limit(url) for url in urls))
    return {"count": len(results)}
# 使用任务组(Python 3.11+)
@app.get("/task-group")
async def task_group_example():
"""使用 TaskGroup 管理任务"""
results = []
async with asyncio.TaskGroup() as tg:
task1 = tg.create_task(fetch_data("url1"))
task2 = tg.create_task(fetch_data("url2"))
task3 = tg.create_task(fetch_data("url3"))
results = [task1.result(), task2.result(), task3.result()]
return {"results": results}
async def fetch_data(url: str):
"""获取数据"""
async with httpx.AsyncClient() as client:
response = await client.get(url)
return response.json()
3. 后台任务
3.1 基础后台任务
from fastapi import FastAPI, BackgroundTasks
import asyncio
app = FastAPI()
# 定义后台任务函数
async def send_email(email: str, message: str):
"""异步发送邮件"""
await asyncio.sleep(2) # 模拟发送邮件耗时
print(f"邮件已发送至 {email}: {message}")
async def write_log(message: str):
"""异步写入日志"""
await asyncio.sleep(0.5)
print(f"日志记录: {message}")
# 使用后台任务
@app.post("/send-notification/{email}")
async def send_notification(
email: str,
background_tasks: BackgroundTasks
):
"""发送通知(后台执行)"""
# 立即返回响应
background_tasks.add_task(send_email, email, "感谢您的注册!")
background_tasks.add_task(write_log, f"发送邮件给 {email}")
return {"message": "通知将在后台发送", "email": email}
# 同步后台任务
@app.post("/process-data")
async def process_data(
data: str,
background_tasks: BackgroundTasks
):
"""处理数据(后台执行)"""
def sync_process(data: str):
"""同步处理函数"""
import time
time.sleep(3)
print(f"数据处理完成: {data}")
background_tasks.add_task(sync_process, data)
return {"message": "数据处理已在后台启动"}
3.2 任务队列与状态追踪
from fastapi import FastAPI, BackgroundTasks, HTTPException
from pydantic import BaseModel
from enum import Enum
from datetime import datetime
import asyncio
import uuid
app = FastAPI()
# 任务状态枚举
class TaskStatus(str, Enum):
pending = "pending"
running = "running"
completed = "completed"
failed = "failed"
# 任务模型
class TaskInfo(BaseModel):
task_id: str
status: TaskStatus
created_at: datetime
completed_at: datetime | None = None
result: dict | None = None
error: str | None = None
# 内存中的任务存储(生产环境应使用 Redis 等)
tasks_db: dict[str, TaskInfo] = {}
async def long_running_task(task_id: str, data: dict):
"""长时间运行的任务"""
task = tasks_db[task_id]
task.status = TaskStatus.running
try:
# 模拟多阶段处理
for i in range(5):
await asyncio.sleep(2) # 模拟工作
print(f"任务 {task_id} 进度: {(i+1)*20}%")
task.status = TaskStatus.completed
task.completed_at = datetime.now()
task.result = {"processed": True, "data": data}
except Exception as e:
task.status = TaskStatus.failed
task.error = str(e)
@app.post("/tasks", response_model=TaskInfo)
async def create_task(
data: dict,
background_tasks: BackgroundTasks
):
"""创建后台任务"""
task_id = str(uuid.uuid4())
task_info = TaskInfo(
task_id=task_id,
status=TaskStatus.pending,
created_at=datetime.now()
)
tasks_db[task_id] = task_info
# 添加到后台任务
background_tasks.add_task(long_running_task, task_id, data)
return task_info
@app.get("/tasks/{task_id}", response_model=TaskInfo)
async def get_task_status(task_id: str):
"""获取任务状态"""
if task_id not in tasks_db:
raise HTTPException(status_code=404, detail="任务不存在")
return tasks_db[task_id]
@app.get("/tasks")
async def list_tasks():
"""列出所有任务"""
return list(tasks_db.values())
3.3 使用 Celery 处理复杂后台任务
# celery_app.py
from celery import Celery
# 配置 Celery
celery_app = Celery(
"worker",
broker="redis://localhost:6379/0",
backend="redis://localhost:6379/0"
)
@celery_app.task
def process_image_task(image_path: str, operations: list):
    """Apply resize/rotate operations to an image and save the result.

    Args:
        image_path: Path of the image to open.
        operations: List of dicts. ``{"type": "resize", "width": w, "height": h}``
            and ``{"type": "rotate", "angle": a}`` are applied in order; any
            other ``type`` is skipped.

    Returns:
        ``{"output_path": <saved file>, "operations": <len(operations)>}``.
    """
    from pathlib import Path

    from PIL import Image

    img = Image.open(image_path)
    for op in operations:
        if op["type"] == "resize":
            img = img.resize((op["width"], op["height"]))
        elif op["type"] == "rotate":
            img = img.rotate(op["angle"])
    # Prefix only the file name. Prefixing the whole path would turn
    # "images/a.png" into "processed_images/a.png", a directory that does not
    # exist. For a bare file name the result is unchanged.
    source = Path(image_path)
    output_path = str(source.with_name(f"processed_{source.name}"))
    img.save(output_path)
    return {"output_path": output_path, "operations": len(operations)}
@celery_app.task(bind=True)
def send_bulk_emails(self, email_list: list, template: str):
"""批量发送邮件"""
total = len(email_list)
for i, email in enumerate(email_list):
# 发送邮件逻辑
print(f"发送邮件给 {email}")
# 更新进度
self.update_state(
state='PROGRESS',
meta={'current': i + 1, 'total': total, 'percent': int((i + 1) / total * 100)}
)
return {"sent": total}
# main.py
from fastapi import FastAPI
from celery_app import process_image_task, send_bulk_emails
from celery.result import AsyncResult
app = FastAPI()
@app.post("/process-image")
async def create_image_task(image_path: str, operations: list):
"""创建图片处理任务"""
task = process_image_task.delay(image_path, operations)
return {"task_id": task.id, "status": "submitted"}
@app.get("/task-status/{task_id}")
async def get_celery_task_status(task_id: str):
"""获取 Celery 任务状态"""
task_result = AsyncResult(task_id)
return {
"task_id": task_id,
"status": task_result.status,
"result": task_result.result if task_result.ready() else None
}
@app.post("/send-bulk-emails")
async def create_bulk_email_task(emails: list[str], template: str):
"""创建批量邮件任务"""
task = send_bulk_emails.delay(emails, template)
return {"task_id": task.id, "total_emails": len(emails)}
4. 异步数据库操作
4.1 SQLAlchemy 2.0 异步操作
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy import Column, Integer, String, select, update, delete
from pydantic import BaseModel
from typing import List
app = FastAPI()
# 异步数据库配置
DATABASE_URL = "sqlite+aiosqlite:///./async_app.db"
engine = create_async_engine(DATABASE_URL, echo=True)
AsyncSessionLocal = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
Base = declarative_base()
# 模型定义
class User(Base):
__tablename__ = "users"
id = Column(Integer, primary_key=True, index=True)
username = Column(String(50), unique=True, index=True)
email = Column(String(100), unique=True)
full_name = Column(String(100))
# Pydantic 模型
class UserCreate(BaseModel):
username: str
email: str
full_name: str
class UserResponse(BaseModel):
id: int
username: str
email: str
full_name: str
model_config = {"from_attributes": True}
# 数据库依赖
async def get_db():
"""获取异步数据库会话"""
async with AsyncSessionLocal() as session:
try:
yield session
finally:
await session.close()
# 创建表(应用启动时)
@app.on_event("startup")
async def startup():
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# CRUD 操作
@app.post("/users", response_model=UserResponse)
async def create_user(user: UserCreate, db: AsyncSession = Depends(get_db)):
    """Create a user row and return it.

    Args:
        user: Validated payload with the new user's fields.
        db: Async session supplied by the ``get_db`` dependency.

    Returns:
        The persisted ``User`` (serialized through ``UserResponse``).

    Raises:
        HTTPException: 409 when the commit violates a database constraint
            (``username`` and ``email`` are declared unique on the model).
    """
    from sqlalchemy.exc import IntegrityError

    db_user = User(**user.model_dump())
    db.add(db_user)
    try:
        await db.commit()
    except IntegrityError as exc:
        # Roll back so the session is usable again, and report a conflict
        # instead of letting the driver error surface as a 500.
        await db.rollback()
        raise HTTPException(status_code=409, detail="用户名或邮箱已存在") from exc
    # Reload the row to pick up the database-generated primary key.
    await db.refresh(db_user)
    return db_user
@app.get("/users", response_model=List[UserResponse])
async def list_users(
skip: int = 0,
limit: int = 100,
db: AsyncSession = Depends(get_db)
):
"""获取用户列表"""
result = await db.execute(select(User).offset(skip).limit(limit))
users = result.scalars().all()
return users
@app.get("/users/{user_id}", response_model=UserResponse)
async def get_user(user_id: int, db: AsyncSession = Depends(get_db)):
"""获取单个用户"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise HTTPException(status_code=404, detail="用户不存在")
return user
@app.put("/users/{user_id}", response_model=UserResponse)
async def update_user(
user_id: int,
user_update: UserCreate,
db: AsyncSession = Depends(get_db)
):
"""更新用户"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise HTTPException(status_code=404, detail="用户不存在")
for key, value in user_update.model_dump().items():
setattr(user, key, value)
await db.commit()
await db.refresh(user)
return user
@app.delete("/users/{user_id}")
async def delete_user(user_id: int, db: AsyncSession = Depends(get_db)):
"""删除用户"""
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
raise HTTPException(status_code=404, detail="用户不存在")
await db.delete(user)
await db.commit()
return {"message": "用户删除成功"}
4.2 事务管理
from fastapi import FastAPI, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from contextlib import asynccontextmanager
app = FastAPI()
@asynccontextmanager
async def get_transaction(db: AsyncSession):
"""事务上下文管理器"""
async with db.begin():
yield db
@app.post("/transfer")
async def transfer_money(
    from_account: int,
    to_account: int,
    amount: float,
    db: AsyncSession = Depends(get_db)
):
    """Move ``amount`` between two accounts inside a single transaction.

    Args:
        from_account: Id of the account to debit.
        to_account: Id of the account to credit.
        amount: Amount to move; must be a finite number greater than zero.
        db: Async session supplied by the ``get_db`` dependency.

    Returns:
        ``{"message": "转账成功", "amount": amount}`` after the commit.

    Raises:
        HTTPException: 400 for an invalid amount, a transfer to the same
            account, or a missing/underfunded source account; 404 when the
            target account does not exist; 500 for any other failure.
    """
    # NOTE(review): Account, Transaction and get_db are not defined in this
    # snippet; presumably they come from the application's models module.
    import math

    # Reject zero, negative, NaN and infinite amounts before touching any
    # balance: a negative value would move money in the opposite direction.
    if not math.isfinite(amount) or amount <= 0:
        raise HTTPException(status_code=400, detail="转账金额必须大于 0")
    if from_account == to_account:
        raise HTTPException(status_code=400, detail="不能向同一账户转账")

    try:
        # db.begin() commits on normal exit and rolls back if the block raises.
        async with db.begin():
            # Load the source account.
            result = await db.execute(
                select(Account).where(Account.id == from_account)
            )
            from_acc = result.scalar_one_or_none()
            if not from_acc or from_acc.balance < amount:
                raise HTTPException(status_code=400, detail="余额不足")

            # Load the target account.
            result = await db.execute(
                select(Account).where(Account.id == to_account)
            )
            to_acc = result.scalar_one_or_none()
            if not to_acc:
                raise HTTPException(status_code=404, detail="目标账户不存在")

            # NOTE(review): nothing here locks the rows, so two concurrent
            # transfers could both pass the balance check. Confirm whether
            # row locking is needed for the target database.
            from_acc.balance -= amount
            to_acc.balance += amount

            # Record the transfer in the same transaction as the balance change.
            transaction = Transaction(
                from_account=from_account,
                to_account=to_account,
                amount=amount
            )
            db.add(transaction)
    except HTTPException:
        # The 400/404 responses above are deliberate. Re-raise them unchanged
        # rather than letting the generic handler rewrite them as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"转账失败: {str(e)}") from e
    return {"message": "转账成功", "amount": amount}
# 嵌套事务示例
@app.post("/batch-operation")
async def batch_operation(
operations: list[dict],
db: AsyncSession = Depends(get_db)
):
"""批量操作(部分失败处理)"""
results = []
async with db.begin():
for i, op in enumerate(operations):
try:
# 每个操作在子事务中执行
async with db.begin_nested():
await process_operation(db, op)
results.append({"index": i, "status": "success"})
except Exception as e:
results.append({"index": i, "status": "failed", "error": str(e)})
# 子事务回滚,继续下一个
continue
return {"results": results}
async def process_operation(db: AsyncSession, operation: dict):
"""处理单个操作"""
# 具体操作逻辑
pass
4.3 数据库连接池优化
from fastapi import FastAPI
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import NullPool
app = FastAPI()
# 优化的异步引擎配置
engine = create_async_engine(
"postgresql+asyncpg://user:pass@localhost/db",
# 连接池配置
pool_size=20, # 连接池大小
max_overflow=10, # 最大溢出连接
pool_pre_ping=True, # 连接前 ping 测试
pool_recycle=3600, # 连接回收时间(秒)
pool_timeout=30, # 获取连接超时时间
# 性能优化
echo=False, # 关闭 SQL 日志(生产环境)
future=True # 使用 SQLAlchemy 2.0 风格
)
session_maker = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False, # 提交后不过期对象
autoflush=False # 手动控制 flush
)
# 连接池监控
@app.get("/db-pool-status")
async def get_pool_status():
    """Report the connection-pool counters of the module-level ``engine``.

    Returns:
        A dict with the pool's ``size``, ``checked_in``, ``checked_out`` and
        ``overflow`` values.
    """
    pool = engine.pool
    return {
        # Configured pool size (pool_size), not the number of open connections.
        "size": pool.size(),
        # Connections currently idle inside the pool.
        "checked_in": pool.checkedin(),
        # Connections currently handed out to callers.
        "checked_out": pool.checkedout(),
        # Overflow counter relative to pool_size.
        # NOTE(review): this is expected to be negative until the pool has
        # opened pool_size connections; confirm against the SQLAlchemy version.
        "overflow": pool.overflow()
    }
5. 性能优化技巧
5.1 响应缓存
from fastapi import FastAPI, Request, Response
# FastAPI has no ``fastapi.middleware.caching`` module, so importing
# ``CacheMiddleware`` from it fails with ImportError at startup. Nothing in
# this example uses it; the caching below is implemented by hand.
from functools import wraps
import asyncio  # used by expensive_query (asyncio.sleep)
import hashlib
import json
import time
from typing import Callable
app = FastAPI()
# 简单的内存缓存
class SimpleCache:
    """Minimal in-memory key/value cache with a per-entry time-to-live.

    Entries live in a plain dict on the instance. A miss and an expired entry
    both read as ``None``, so a stored ``None`` cannot be told apart from a
    miss.
    """

    def __init__(self):
        # key -> {"value": ..., "expires": <epoch seconds>}
        self._cache = {}

    def get(self, key: str):
        """Return the cached value, or ``None`` if it is missing or expired."""
        item = self._cache.get(key)
        if item is None:
            return None
        if item["expires"] > time.time():
            return item["value"]
        # Drop the expired entry so stale values do not accumulate forever.
        del self._cache[key]
        return None

    def set(self, key: str, value: object, ttl: int = 300):
        """Store ``value`` under ``key`` for ``ttl`` seconds."""
        self._cache[key] = {
            "value": value,
            "expires": time.time() + ttl
        }

    def delete(self, key: str):
        """Remove ``key`` if present; a missing key is ignored."""
        self._cache.pop(key, None)
cache = SimpleCache()
def cached(ttl: int = 300, key_prefix: str = ""):
    """Build a decorator that caches an async function's result in ``cache``.

    Args:
        ttl: Seconds a cached result stays valid.
        key_prefix: Text prepended to every cache key.

    Returns:
        A decorator for ``async def`` functions. A result of ``None`` is not
        served from the cache, because ``cache.get`` returning ``None`` is
        treated as a miss.
    """
    def decorator(func: Callable):
        @wraps(func)
        async def wrapper(*args, **kwargs):
            # Key on the full argument text rather than on hash(...): two
            # different argument sets can share a hash, which would return
            # another call's cached result. Sorting kwargs makes the key
            # independent of keyword order.
            cache_key = (
                f"{key_prefix}:{func.__name__}:"
                f"{args!r}:{sorted(kwargs.items())!r}"
            )
            # Serve from the cache when possible.
            cached_value = cache.get(cache_key)
            if cached_value is not None:
                return cached_value
            # Miss: compute, store, return.
            result = await func(*args, **kwargs)
            cache.set(cache_key, result, ttl)
            return result
        return wrapper
    return decorator
# 使用缓存
@app.get("/expensive-query")
@cached(ttl=60, key_prefix="query")
async def expensive_query():
"""耗时查询(带缓存)"""
await asyncio.sleep(2) # 模拟耗时操作
return {"data": "expensive result", "timestamp": time.time()}
# Redis 缓存示例
import redis.asyncio as redis
redis_client = redis.Redis(host='localhost', port=6379, db=0)
async def get_cached_or_fetch(key: str, fetch_func: Callable, ttl: int = 300):
"""从缓存获取或执行获取函数"""
# 尝试从 Redis 获取
cached = await redis_client.get(key)
if cached:
return json.loads(cached)
# 执行获取
result = await fetch_func()
# 存入 Redis
await redis_client.setex(key, ttl, json.dumps(result))
return result
@app.get("/users/{user_id}/profile")
async def get_user_profile(user_id: int):
"""获取用户资料(带 Redis 缓存)"""
cache_key = f"user_profile:{user_id}"
async def fetch_profile():
# 从数据库获取
return {"user_id": user_id, "name": "John", "data": "..."}
return await get_cached_or_fetch(cache_key, fetch_profile, ttl=600)
5.2 请求限流
from fastapi import FastAPI, Request, HTTPException
from fastapi.middleware.trustedhost import TrustedHostMiddleware
import time
from collections import defaultdict
app = FastAPI()
# 简单的速率限制器
class RateLimiter:
    """In-memory sliding-window rate limiter (one-minute window).

    ``requests`` maps a key (for example a client IP) to the timestamps of its
    accepted requests within the last minute.
    """

    WINDOW_SECONDS = 60

    def __init__(self, requests_per_minute: int = 60):
        self.requests_per_minute = requests_per_minute
        self.requests = defaultdict(list)
        # Time of the last sweep over all keys (see _purge_idle).
        self._last_purge = time.time()

    def _purge_idle(self, cutoff: float) -> None:
        """Drop keys whose newest timestamp is at or before ``cutoff``."""
        idle = [
            key for key, stamps in self.requests.items()
            if not stamps or stamps[-1] <= cutoff
        ]
        for key in idle:
            del self.requests[key]

    def is_allowed(self, key: str) -> bool:
        """Record a request for ``key`` and say whether it is within the limit.

        Returns:
            ``True`` (and records the request) when fewer than
            ``requests_per_minute`` requests were accepted for ``key`` in the
            last 60 seconds; ``False`` otherwise.
        """
        now = time.time()
        cutoff = now - self.WINDOW_SECONDS
        # Without a periodic sweep, a key that is never seen again would keep
        # its entry forever and memory would grow with every distinct client.
        if now - self._last_purge >= self.WINDOW_SECONDS:
            self._purge_idle(cutoff)
            self._last_purge = now
        # .get() avoids creating an empty defaultdict entry for unseen keys.
        recent = [stamp for stamp in self.requests.get(key, ()) if stamp > cutoff]
        allowed = len(recent) < self.requests_per_minute
        if allowed:
            recent.append(now)
        if recent:
            self.requests[key] = recent
        else:
            self.requests.pop(key, None)
        return allowed
rate_limiter = RateLimiter(requests_per_minute=100)
@app.middleware("http")
async def rate_limit_middleware(request: Request, call_next):
    """Reject requests from clients that exceed the module-level rate limit.

    Args:
        request: Incoming request; the client IP is used as the limiter key.
        call_next: Callable that forwards the request to the rest of the app.

    Returns:
        A 429 JSON response when the limit is exceeded, otherwise the
        downstream response.
    """
    from fastapi.responses import JSONResponse

    # request.client can be None, so fall back to a shared key instead of
    # failing with AttributeError.
    client_ip = request.client.host if request.client else "unknown"
    if not rate_limiter.is_allowed(client_ip):
        # Return the response directly. An HTTPException raised inside an
        # HTTP middleware is not handled by FastAPI's exception handlers, so
        # the client would receive a 500 rather than the intended 429.
        return JSONResponse(
            status_code=429,
            content={"detail": "请求过于频繁,请稍后再试"},
        )
    return await call_next(request)
# 基于 Redis 的分布式限流
import redis.asyncio as redis
redis_client = redis.Redis(host='localhost', port=6379, db=0)
async def check_rate_limit(key: str, limit: int, window: int = 60) -> bool:
    """Sliding-window rate-limit check backed by a Redis sorted set.

    Every call records the current request in the set, including calls that
    end up over the limit.

    Args:
        key: Redis key identifying the caller being limited.
        limit: Maximum number of requests allowed within the window.
        window: Window length in seconds.

    Returns:
        ``True`` when the number of requests already in the window, counted
        before this one is added, is below ``limit``.
    """
    import uuid

    now = time.time()
    window_start = now - window
    # Sorted-set members are unique. Using only the timestamp as the member
    # would make two requests with the same timestamp overwrite each other
    # and be counted once.
    member = f"{now}:{uuid.uuid4().hex}"
    pipe = redis_client.pipeline()
    # Drop entries that fell out of the window.
    pipe.zremrangebyscore(key, 0, window_start)
    # Count what is left in the window (results[1] below).
    pipe.zcard(key)
    # Record the current request, scored by its timestamp.
    pipe.zadd(key, {member: now})
    # Let idle keys expire on their own.
    pipe.expire(key, window)
    results = await pipe.execute()
    current_count = results[1]
    return current_count < limit
@app.get("/api/limited")
async def limited_endpoint(request: Request):
"""受速率限制的端点"""
client_id = request.headers.get("X-Client-ID", request.client.host)
if not await check_rate_limit(f"rate_limit:{client_id}", limit=10, window=60):
raise HTTPException(status_code=429, detail="API 调用次数超限")
return {"message": "请求成功"}
5.3 连接优化
from fastapi import FastAPI
import httpx
from contextlib import asynccontextmanager
app = FastAPI()
# 应用级别的 HTTP 客户端(复用连接)
http_client: httpx.AsyncClient | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared HTTP client on startup and close it on shutdown.

    Args:
        app: The application instance (not used in the body).

    Yields:
        Nothing. The client is published through the module-level
        ``http_client`` variable while the application runs.
    """
    global http_client
    # One client for the whole application, so connections are pooled and
    # reused across requests.
    http_client = httpx.AsyncClient(
        limits=httpx.Limits(
            max_connections=100,
            max_keepalive_connections=20
        ),
        timeout=httpx.Timeout(30.0)
    )
    try:
        yield
    finally:
        # Close the pool even if an exception is raised into the generator,
        # and clear the global so a closed client is never reused.
        await http_client.aclose()
        http_client = None
app = FastAPI(lifespan=lifespan)
@app.get("/external-api")
async def call_external_api():
"""调用外部 API(复用连接)"""
response = await http_client.get("https://api.example.com/data")
return response.json()
# 数据库连接优化
from sqlalchemy.ext.asyncio import create_async_engine
engine = create_async_engine(
"postgresql+asyncpg://user:pass@localhost/db",
pool_size=20,
max_overflow=0, # 不允许溢出,强制限制连接数
pool_pre_ping=True,
pool_recycle=300,
echo=False
)
# 批量操作优化
@app.post("/batch-insert")
async def batch_insert(items: list[dict]):
    """Insert many rows into ``Item`` in one transaction.

    Args:
        items: One dict of column values per row.
            NOTE(review): presumably all dicts share the same keys; confirm
            against the callers.

    Returns:
        ``{"inserted": <number of rows submitted>}``.
    """
    # NOTE(review): ``Item`` is not defined in this snippet; presumably it is
    # the application's ORM model.
    from sqlalchemy.dialects.postgresql import insert

    # Nothing to insert: skip the round trip and avoid building an INSERT
    # from an empty value list.
    if not items:
        return {"inserted": 0}
    # Pass the rows as execute() parameters (executemany form) instead of
    # embedding every row in one multi-VALUES statement, which can exceed the
    # driver's bind-parameter limit for large batches.
    async with engine.begin() as conn:
        await conn.execute(insert(Item), items)
    return {"inserted": len(items)}
5.4 Gzip 压缩与响应优化
import orjson
from fastapi import Depends, HTTPException
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import ORJSONResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
app = FastAPI(default_response_class=ORJSONResponse)
# Gzip 压缩(大于 1000 字节的响应)
app.add_middleware(GZipMiddleware, minimum_size=1000)
# CORS 配置
app.add_middleware(
CORSMiddleware,
allow_origins=["https://example.com"],
allow_credentials=True,
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
max_age=3600
)
# 使用 orjson 加速 JSON 序列化
@app.get("/large-data")
async def get_large_data():
"""返回大量数据"""
data = [{"id": i, "data": "x" * 100} for i in range(10000)]
return data
# 流式响应(大文件下载)
from fastapi.responses import StreamingResponse
import io
@app.get("/download-large-file")
async def download_large_file():
"""流式下载大文件"""
def file_generator():
for i in range(1000):
yield f"Line {i}: {'x' * 100}\n"
return StreamingResponse(
file_generator(),
media_type="text/plain",
headers={"Content-Disposition": "attachment; filename=large.txt"}
)
# 分页优化
@app.get("/items")
async def get_items(
    cursor: str | None = None,
    limit: int = 20,
    db: AsyncSession = Depends(get_db),
):
    """Return one page of items using cursor (keyset) pagination.

    Args:
        cursor: Opaque cursor from a previous page's ``next_cursor``; omit it
            for the first page.
        limit: Page size; must be at least 1.
        db: Async session supplied by the ``get_db`` dependency.

    Returns:
        ``{"items": [...], "next_cursor": str | None, "has_more": bool}``.

    Raises:
        HTTPException: 400 when ``limit`` is below 1 or ``cursor`` cannot be
            decoded.
    """
    # NOTE(review): ``Item`` and ``get_db`` are not defined in this snippet;
    # presumably they come from the application's database module.
    # With limit <= 0 the slice below is empty while has_more can still be
    # true, and items[-1] would raise IndexError.
    if limit < 1:
        raise HTTPException(status_code=400, detail="limit 必须大于 0")

    # Fetch one extra row to learn whether another page exists.
    query = select(Item).order_by(Item.id).limit(limit + 1)
    if cursor:
        try:
            last_id = decode_cursor(cursor)
        except ValueError as exc:
            # A malformed cursor is a client error, not a server fault.
            raise HTTPException(status_code=400, detail="无效的游标") from exc
        query = query.where(Item.id > last_id)

    result = await db.execute(query)
    items = result.scalars().all()

    # The extra row only signals "more data"; it is not part of this page.
    has_more = len(items) > limit
    items = items[:limit]

    # The cursor encodes the id of the last item actually returned.
    next_cursor = encode_cursor(items[-1].id) if has_more else None
    return {
        "items": items,
        "next_cursor": next_cursor,
        "has_more": has_more
    }
def encode_cursor(id: int) -> str:
    """Encode an integer id as an opaque base64 cursor string.

    Args:
        id: The id to encode (the parameter name is part of the public
            signature, so it keeps shadowing the builtin).

    Returns:
        The base64 text of the id's decimal representation.
    """
    import base64

    return base64.b64encode(str(id).encode()).decode()
def decode_cursor(cursor: str) -> int:
    """Decode a base64 cursor string back into the integer id it carries.

    Args:
        cursor: Cursor text holding the base64 of a decimal integer.

    Returns:
        The decoded integer id.

    Raises:
        ValueError: When the cursor is not valid base64 or does not decode to
            an integer (``binascii.Error`` and ``UnicodeDecodeError`` are both
            ``ValueError`` subclasses).
    """
    import base64

    # validate=True rejects characters outside the base64 alphabet. Without
    # it they are silently discarded and a corrupted cursor is accepted.
    raw = base64.b64decode(cursor.encode(), validate=True)
    return int(raw.decode())
6. 性能监控与调试
6.1 请求计时中间件
from fastapi import FastAPI, Request
import time
import logging
app = FastAPI()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
@app.middleware("http")
async def timing_middleware(request: Request, call_next):
    """Measure each request's processing time, log it, and expose it in a header.

    Args:
        request: Incoming request.
        call_next: Callable that forwards the request to the rest of the app.

    Returns:
        The downstream response with an added ``X-Process-Time`` header
        (seconds, as a string).
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the system clock is adjusted.
    start_time = time.perf_counter()
    response = await call_next(request)
    process_time = time.perf_counter() - start_time

    # Flag requests slower than one second.
    if process_time > 1.0:
        # %-style arguments defer string formatting until the record is
        # actually emitted.
        logger.warning(
            "慢请求: %s %s 耗时: %.3fs",
            request.method, request.url.path, process_time,
        )

    response.headers["X-Process-Time"] = str(process_time)
    logger.info(
        "%s %s - %s - %.3fs",
        request.method, request.url.path, response.status_code, process_time,
    )
    return response
# 性能指标端点
@app.get("/metrics")
async def get_metrics():
"""获取应用性能指标"""
import psutil
return {
"cpu_percent": psutil.cpu_percent(),
"memory_percent": psutil.virtual_memory().percent,
"disk_usage": psutil.disk_usage('/').percent,
"connections": len(psutil.net_connections())
}
6.2 OpenTelemetry 集成
from fastapi import FastAPI
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
app = FastAPI()
# 配置 OpenTelemetry
resource = Resource.create({"service.name": "my-fastapi-app"})
provider = TracerProvider(resource=resource)
processor = BatchSpanProcessor(OTLPSpanExporter())
provider.add_span_processor(processor)
trace.set_tracer_provider(provider)
# 自动 instrument FastAPI
FastAPIInstrumentor.instrument_app(app)
@app.get("/traced-endpoint")
async def traced_endpoint():
    """Endpoint that records a parent span and a nested child span.

    Returns:
        ``{"message": "完成"}`` once both simulated steps have finished.
    """
    # asyncio is not imported at the top of this example, so import it here.
    import asyncio

    tracer = trace.get_tracer(__name__)
    with tracer.start_as_current_span("process_request") as span:
        span.set_attribute("custom.attribute", "value")
        # Simulated work inside the parent span.
        await asyncio.sleep(0.1)
        # Started while "process_request" is current, so this is its child.
        with tracer.start_as_current_span("database_query"):
            # Simulated database query.
            await asyncio.sleep(0.05)
        return {"message": "完成"}
避坑小贴士
1. 不要在异步函数中使用同步 I/O
# 错误:会阻塞事件循环
@app.get("/bad")
async def bad_endpoint():
import requests # 同步库
response = requests.get("https://api.example.com") # 阻塞!
return response.json()
# 正确:使用异步 HTTP 客户端
@app.get("/good")
async def good_endpoint():
import httpx # 异步库
async with httpx.AsyncClient() as client:
response = await client.get("https://api.example.com")
return response.json()
2. 注意 asyncio.gather 的错误处理
# 问题:一个任务失败会导致所有结果丢失
tasks = [fetch_data(i) for i in range(10)]
results = await asyncio.gather(*tasks) # 任一失败都会抛出异常
# 解决方案1:return_exceptions=True
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(results):
if isinstance(result, Exception):
print(f"任务 {i} 失败: {result}")
else:
print(f"任务 {i} 成功: {result}")
# 解决方案2:使用 TaskGroup(Python 3.11+)
async with asyncio.TaskGroup() as tg:
tasks = [tg.create_task(fetch_data(i)) for i in range(10)]
results = [task.result() for task in tasks]
3. 小心内存泄漏
# 问题:全局列表无限增长
results_cache = []
@app.get("/cache")
async def cache_result():
result = await expensive_operation()
results_cache.append(result) # 内存泄漏!
return result
# 解决方案:使用有界缓存
from collections import deque
results_cache = deque(maxlen=1000) # 最多保存 1000 条
# 或使用 TTL 缓存
from cachetools import TTLCache
results_cache = TTLCache(maxsize=1000, ttl=3600)
4. 数据库连接未正确关闭
# 错误:连接可能泄露
@app.get("/bad-db")
async def bad_db_access():
db = await create_db_connection()
result = await db.query("SELECT * FROM users")
# 如果这里抛出异常,连接不会关闭!
await db.close()
return result
# 正确:使用上下文管理器
@app.get("/good-db")
async def good_db_access():
async with create_db_connection() as db:
result = await db.query("SELECT * FROM users")
return result # 自动关闭
# 或使用依赖注入(FastAPI 自动处理)
async def get_db():
db = SessionLocal()
try:
yield db
finally:
await db.close()
课后练习
练习 1:异步 HTTP 客户端
实现一个异步端点,并发请求 3 个不同的 API,并合并结果返回。
练习 2:后台任务系统
实现一个文件处理系统:
- 上传文件后,在后台异步处理
- 提供接口查询处理进度
- 处理完成后发送通知
练习 3:数据库优化
实现用户搜索功能:
- 使用异步数据库查询
- 添加 Redis 缓存
- 实现游标分页
练习 4:性能测试
使用 wrk 或 locust 对你的 FastAPI 应用进行压力测试:
- 测试同步 vs 异步端点的性能差异
- 测试不同并发连接数下的表现
- 找出性能瓶颈并优化
下一篇预告
第15讲:FastAPI安全与认证
在下一讲中,我们将学习:
- JWT 认证与授权
- OAuth2 集成
- API 安全防护
- HTTPS 配置与部署
安全是 Web 应用的重中之重,敬请期待!
参考资源
更多推荐
所有评论(0)