Flask部署机器学习模型的生产级实践指南

1. 这不是“写个API”那么简单:为什么用 Flask 部署模型是多数团队的真实起点

你手头刚跑通一个在本地 Jupyter 里准确率 92.3% 的图像分类模型,老板下周一就要在客户演示系统里调用它——不是让你发个 notebook 链接,而是要“能被前端按钮点一下就返回结果”。这时候,你搜到的第一批教程标题几乎全是“How to deploy ML model with Flask”。但很快你会发现,照着敲完 app.py flask run ,本地能返回 JSON,一上服务器就报错 ModuleNotFoundError: No module named 'torch' ;或者压测时并发 5 个请求,CPU 直冲 100%,响应时间从 200ms 涨到 8 秒;更别提模型加载慢、内存泄漏、日志没埋点、错误信息全是 Internal Server Error ……这些根本不是 Flask 的锅,而是部署这件事本身,远比“写个路由”复杂得多。

Flask 在这里扮演的角色,本质上是一个 轻量级胶水层 :它不负责训练、不优化推理速度、不管理 GPU 资源、不处理自动扩缩容。它的价值恰恰在于“不做多余的事”——把模型预测这个确定性动作,封装成标准 HTTP 接口,让前端、移动端、其他后端服务能像调用天气 API 一样调用你的模型。我过去三年带过 17 个落地项目,从金融风控的 XGBoost 模型到医疗影像的 ResNet50 分割模型,90% 的 MVP 阶段都首选 Flask,原因很实在:它没有 Django 那套 ORM 和 admin 后台的冗余开销,也没有 FastAPI 那样对异步和 Pydantic 的强依赖(尤其当团队里有只熟悉 Python 基础的算法同学时)。它就像一把瑞士军刀里的小剪刀——功能单一,但够快、够稳、够透明,所有问题都能一眼看到根上。这篇文章不讲“Flask 基础语法”,而是聚焦你真正卡住的地方:如何让模型在 Flask 里 不崩、不慢、不糊弄、不裸奔 。适合已经跑通模型、正对着 requirements.txt 发愁的工程师,也适合想理解“模型上线到底难在哪”的算法同事。接下来每一部分,都是我在生产环境里亲手踩坑、反复验证过的路径。

2. 整体架构设计:为什么拒绝“单文件暴走”,而选择分层解耦

2.1 核心矛盾:模型的“重”与 Web 框架的“轻”天然冲突

初学者最容易犯的错误,就是把模型加载、预处理、预测逻辑全塞进 app.py 一个文件里。比如这样:

from flask import Flask, request, jsonify
import joblib
import numpy as np

# ❌ 危险:模型在每次请求时都重新加载!
def predict():
    model = joblib.load('model.pkl')  # 每次请求都磁盘 IO + 反序列化
    data = request.json['features']
    result = model.predict([data])
    return jsonify({'prediction': int(result[0])})

这在本地测试时毫无压力,但上线后立刻暴露三个致命问题:
第一, 启动延迟高 :Flask 启动时并不加载模型,第一个请求进来才触发 joblib.load() ,用户首屏等待超时;
第二, 内存爆炸 :每次请求都新建一个模型实例,10 个并发就是 10 份模型副本,4GB 模型直接吃光 40GB 内存;
第三, 状态不可控 :模型参数、预处理对象(如 StandardScaler )散落在函数里,无法统一管理版本或热更新。

真正的解法,是把整个流程拆成四个明确职责的层,用 Python 包结构固化下来:

ml_service/
├── app/                    # Flask 应用核心(路由、HTTP 层)
│   ├── __init__.py         # 创建 Flask 实例、注册蓝图
│   ├── main.py             # 主路由定义(/predict, /health)
│   └── errors.py           # 全局错误处理器
├── models/                 # 模型生命周期管理(核心!)
│   ├── __init__.py         # 定义 ModelManager 单例
│   ├── base.py             # 抽象基类(load, predict, validate)
│   ├── sklearn_model.py    # 具体实现:XGBoost/LightGBM 加载逻辑
│   └── torch_model.py      # 具体实现:PyTorch 模型 + GPU 设备管理
├── preprocessing/          # 输入输出标准化(非业务逻辑)
│   ├── __init__.py         # 提供统一的 transform() 和 postprocess()
│   └── image_transform.py  # 图像缩放、归一化、Tensor 转换
└── config.py               # 所有可配置项集中管理(路径、超时、设备)

提示:这种结构不是为了“显得专业”,而是为了解决真实运维问题。比如 models/__init__.py 里用 @lru_cache 或模块级变量实现模型单例,确保整个 Flask 进程只加载一次; config.py 里把模型路径设为环境变量 MODEL_PATH ,上线时不用改代码,只改 Docker 环境变量即可。

2.2 为什么坚持“同步阻塞”而非盲目上异步

看到 FastAPI 的 async/await 示例,很多人会想:“我的模型预测是不是也能异步?”答案是否定的——除非你的模型本身支持异步 I/O(比如调用外部 API),否则强行加 async def predict() 只会让事情更糟。Python 的 GIL(全局解释器锁)决定了:纯 CPU 计算(如 NumPy 矩阵运算、PyTorch forward)无法通过 asyncio 并行加速。我实测过一个 ResNet50 推理接口:

  • 同步 Flask(4 workers):平均延迟 320ms,并发 20 QPS
  • FastAPI + async(4 workers):平均延迟 315ms,并发 19 QPS(因 asyncio event loop 额外开销,QPS 反而略降)

真正提升吞吐量的,是 worker 进程数 模型推理优化 ,而不是协程。Flask 的 gunicorn --workers 4 --worker-class sync 是更务实的选择。async 的价值,在于当你需要同时做三件事:1)读取用户上传的图片(I/O)、2)调用模型预测(CPU)、3)把结果写入数据库(I/O)——这时 FastAPI 的 async 才有意义。但对绝大多数“输入→模型→输出”单链路场景,Flask 的简洁性胜过所有花哨特性。

2.3 容器化不是可选项,而是安全底线

有人问:“能不能直接在服务器上 pip install 然后 python app.py?”可以,但等于把生产环境当开发机用。我见过最惨的案例:某团队在 Ubuntu 20.04 服务器上装了 torch==1.12.1+cu113 ,三个月后因为安全补丁升级了 CUDA 驱动,模型直接报 libcudnn.so not found ,服务中断 6 小时。Docker 的价值,是把“模型能跑”这个状态固化下来。一个最小可行的 Dockerfile 必须包含三要素:

FROM nvidia/cuda:11.3.1-cudnn8-runtime-ubuntu20.04  # 显式指定 CUDA 版本,避免驱动冲突

# 复制依赖前先复制 requirements.txt,利用 Docker layer cache 加速构建
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# 复制模型文件(注意:大模型文件放最后,避免频繁重建镜像层)
COPY models/production/model.pth /app/models/
COPY models/production/scaler.pkl /app/models/

# 启动命令必须指定 GPU 设备(即使只用 CPU,也要显式声明 --device=cpu)
CMD ["gunicorn", "--bind", "0.0.0.0:5000", "--workers", "4", "app:create_app()"]

关键细节:基础镜像必须匹配你训练时的 CUDA 版本(查 torch.version.cuda ); --device=cpu 参数不是摆设——它强制 PyTorch 使用 CPU 后端,避免容器内无 GPU 时自动 fallback 到 CPU 导致性能骤降却无提示。

3. 核心细节解析:模型加载、预处理、错误防御的实战要点

3.1 模型加载:单例模式 + 延迟初始化,一个都不能少

模型加载必须解决两个问题: 何时加载 (时机)和 加载几次 (次数)。正确做法是:在 Flask 应用工厂函数中,用 @app.before_first_request (Flask <2.3)或 app.app_context() (Flask >=2.3)触发加载,并用模块级变量缓存:

# models/__init__.py
import logging
from typing import Optional
from .base import BaseModel
from .sklearn_model import SklearnModel
from .torch_model import TorchModel

_logger = logging.getLogger(__name__)
_model_instance: Optional[BaseModel] = None

def get_model() -> BaseModel:
    global _model_instance
    if _model_instance is None:
        try:
            # 从 config.py 读取模型路径和类型
            from config import MODEL_TYPE, MODEL_PATH
            if MODEL_TYPE == "sklearn":
                _model_instance = SklearnModel.load(MODEL_PATH)
            elif MODEL_TYPE == "torch":
                _model_instance = TorchModel.load(MODEL_PATH)
            _logger.info(f"Model loaded successfully: {MODEL_TYPE} from {MODEL_PATH}")
        except Exception as e:
            _logger.critical(f"Failed to load model: {e}")
            raise RuntimeError(f"Model loading failed: {e}")
    return _model_instance

为什么不用 @lru_cache ?因为 lru_cache 缓存的是函数返回值,而模型对象是可变的(如 PyTorch 的 .eval() 状态),且无法序列化。模块级变量是唯一能跨请求共享状态的安全方式。更重要的是, 加载失败必须抛出异常并终止进程 ——如果静默失败,后续所有请求都会返回 500,但你根本不知道是模型问题还是代码问题。

3.2 预处理:输入校验比模型预测更重要

90% 的线上报错,根源不在模型,而在输入数据。一个典型的崩溃场景:前端传来的 JSON 里 {"image": "data:image/png;base64,..."} ,后端直接 base64.b64decode() ,结果字符串里混入了空格或换行符, binascii.Error: Incorrect padding 直接 500。正确的预处理链必须包含三层防御:

  1. Schema 校验 :用 pydantic 定义严格输入结构,拒绝非法字段
  2. 内容校验 :对 base64 字符串做长度检查、字符过滤、padding 补全
  3. 业务校验 :图像尺寸是否在模型接受范围内(如 ResNet 要求 224x224)
# preprocessing/image_transform.py
from pydantic import BaseModel, validator
from typing import Optional
import base64
import numpy as np
from PIL import Image

class PredictionRequest(BaseModel):
    image: str  # base64 encoded string
    
    @validator('image')
    def validate_base64(cls, v):
        # 移除 data URL 前缀(如 data:image/png;base64,)
        if v.startswith('data:'):
            v = v.split(',', 1)[-1]
        # 移除空格和换行符
        v = v.replace(' ', '').replace('\n', '').replace('\r', '')
        # 补全 padding(base64 长度必须是 4 的倍数)
        missing_padding = len(v) % 4
        if missing_padding:
            v += '=' * (4 - missing_padding)
        try:
            base64.b64decode(v, validate=True)
        except Exception as e:
            raise ValueError(f"Invalid base64 string: {e}")
        return v

def decode_image(base64_str: str) -> np.ndarray:
    """Convert base64 string to RGB numpy array, resize to 224x224"""
    try:
        img_bytes = base64.b64decode(base64_str)
        img = Image.open(io.BytesIO(img_bytes)).convert('RGB')
        img = img.resize((224, 224), Image.BILINEAR)
        return np.array(img)  # shape: (224, 224, 3)
    except Exception as e:
        raise ValueError(f"Image decode failed: {e}")

注意: pydantic validator 在请求解析阶段就执行,非法输入根本不会进入预测函数。这比在 predict() 里写 if not isinstance(data, str): raise ValueError 更早拦截错误,减少无效计算。

3.3 错误防御:给每个环节配“保险丝”,而不是等熔断

生产环境最怕的不是报错,而是 静默失败 。比如模型预测返回 NaN ,但代码没检查就直接 jsonify() ,前端收到 {"prediction": null} 却以为是正常结果。必须在每个关键节点插入“保险丝”:

  • 模型输出校验 :检查预测结果是否为 np.nan inf ,或超出业务合理范围(如概率值不在 [0,1])
  • HTTP 响应包装 :所有成功响应必须包含 status_code=200 request_id ,便于日志追踪
  • 全局异常处理器 :捕获未预见异常,返回结构化错误,绝不暴露堆栈
# app/errors.py
from flask import jsonify
import uuid

def register_error_handlers(app):
    @app.errorhandler(400)
    def bad_request(error):
        return jsonify({
            'error': 'Bad Request',
            'message': str(error.description),
            'request_id': str(uuid.uuid4())
        }), 400

    @app.errorhandler(500)
    def internal_error(error):
        # 记录完整堆栈到日志,但响应体只返回通用信息
        app.logger.exception("Unhandled exception occurred")
        return jsonify({
            'error': 'Internal Server Error',
            'request_id': str(uuid.uuid4())
        }), 500

    # 捕获模型预测中的业务异常(如 ValueError)
    @app.errorhandler(ValueError)
    def handle_value_error(error):
        return jsonify({
            'error': 'Invalid Input',
            'message': str(error),
            'request_id': str(uuid.uuid4())
        }), 400

实操心得:我在 TorchModel.predict() 方法末尾强制添加:

if np.isnan(output).any() or np.isinf(output).any():
    raise ValueError("Model output contains NaN or Inf")

这招救了我们两次——一次是模型训练时标签泄露导致 logits 异常,另一次是预处理时图像全黑(像素值全 0)触发了 ReLU 死区。没有这行检查,问题会潜伏数周才被业务方发现。

4. 实操过程:从本地调试到生产部署的完整流水线

4.1 本地开发:用 pytest + fixtures 模拟真实请求链

不要用 curl 或 Postman 手动测每个接口。为预测路由写单元测试,能提前暴露 70% 的集成问题:

# tests/test_predict.py
import pytest
from app import create_app
from models import get_model
from preprocessing.image_transform import encode_image

@pytest.fixture
def client():
    app = create_app()
    app.config['TESTING'] = True
    with app.test_client() as client:
        yield client

def test_predict_valid_image(client):
    # 准备一张 224x224 的测试图(用 PIL 生成)
    from PIL import Image
    import io
    test_img = Image.new('RGB', (224, 224), color='red')
    img_buffer = io.BytesIO()
    test_img.save(img_buffer, format='PNG')
    img_buffer.seek(0)
    base64_img = encode_image(img_buffer.getvalue())  # 自定义编码函数
    
    response = client.post('/predict', 
                          json={'image': base64_img},
                          content_type='application/json')
    
    assert response.status_code == 200
    data = response.get_json()
    assert 'prediction' in data
    assert isinstance(data['prediction'], (int, float))

def test_predict_invalid_base64(client):
    response = client.post('/predict', 
                          json={'image': 'invalid!!!'},
                          content_type='application/json')
    assert response.status_code == 400
    assert 'Invalid base64' in response.get_json()['message']

关键技巧: pytest fixture 机制让每个测试用独立的 Flask 应用实例,避免状态污染;用 PIL.Image.new() 生成测试图,不依赖外部文件,保证测试可重现。运行 pytest tests/ -v ,所有测试通过才是本地开发完成的标志。

4.2 生产配置:Gunicorn + Nginx 的黄金组合

Flask 自带的 flask run 仅用于开发。生产必须用 Gunicorn(WSGI 服务器)+ Nginx(反向代理):

组件 职责 关键配置
Gunicorn 管理 Flask worker 进程,处理并发请求 --workers 4 (CPU 核数 x2)、 --timeout 120 (防长耗时请求)、 --max-requests 1000 (定期重启 worker 防内存泄漏)
Nginx 接收外部 HTTP 请求,转发给 Gunicorn,处理静态文件、SSL 终止、限流 proxy_read_timeout 120 (必须 ≥ Gunicorn timeout)、 client_max_body_size 10M (支持大图上传)

一个最小可用的 gunicorn.conf.py

# gunicorn.conf.py
import multiprocessing

bind = "0.0.0.0:8000"
bind_address = "0.0.0.0:8000"
workers = multiprocessing.cpu_count() * 2 + 1
worker_class = "sync"
timeout = 120
max_requests = 1000
max_requests_jitter = 100
keepalive = 5
preload = True  # 关键!预加载应用,确保模型在 fork 前加载

注意 preload = True :它让 Gunicorn 在 fork worker 进程前先导入应用,此时模型已加载到主进程内存,fork 后的 worker 通过 copy-on-write 共享模型内存,极大节省 RAM。没有这一行,4 个 worker 会各自加载 4 份模型。

Nginx 配置片段( /etc/nginx/sites-available/ml-service ):

upstream ml_backend {
    server 127.0.0.1:8000;
}

server {
    listen 443 ssl;
    server_name api.yourdomain.com;

    ssl_certificate /etc/letsencrypt/live/yourdomain.com/fullchain.pem;
    ssl_certificate_key /etc/letsencrypt/live/yourdomain.com/privkey.pem;

    location /predict {
        proxy_pass http://ml_backend;
        proxy_set_header Host $host;
        proxy_set_header X-Real-IP $remote_addr;
        proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
        proxy_read_timeout 120;  # 必须匹配 Gunicorn timeout
        client_max_body_size 10M;
    }

    location /health {
        proxy_pass http://ml_backend;
        proxy_set_header Host $host;
    }
}

4.3 健康检查与监控:让服务“会说话”

生产服务必须提供 /health 接口,且它要检查 所有依赖项 ,不只是 Flask 是否存活:

# app/main.py
from flask import Blueprint, jsonify
from models import get_model
from config import MODEL_PATH

health_bp = Blueprint('health', __name__)

@health_bp.route('/health')
def health_check():
    try:
        # 1. 检查模型是否加载成功
        model = get_model()
        # 2. 检查模型文件是否存在(防止文件被误删)
        import os
        if not os.path.exists(MODEL_PATH):
            return jsonify({'status': 'error', 'reason': 'Model file missing'}), 503
        # 3. 可选:执行一次轻量预测(如用全 0 输入)
        # dummy_input = np.zeros((1, 3, 224, 224))
        # _ = model.predict(dummy_input)
        return jsonify({'status': 'ok', 'model': model.__class__.__name__})
    except Exception as e:
        return jsonify({'status': 'error', 'reason': str(e)}), 503

Kubernetes 或云平台的健康检查探针,必须调用此接口。我见过太多服务因为 /health 只返回 {"status":"ok"} 而被误判为健康,实际模型已损坏。真正的健康,是 端到端可预测

5. 常见问题与排查技巧实录:那些文档里不会写的坑

5.1 内存泄漏:模型越跑越慢,直到 OOM

现象 :服务运行 24 小时后, ps aux --sort=-%mem | head 显示 Python 进程内存从 1GB 涨到 8GB, gunicorn worker 频繁重启。
根因 :PyTorch 默认启用 torch.autograd.grad 的计算图缓存,即使在 model.eval() 模式下,若未显式禁用,中间 Tensor 会持续累积。
解决方案 :在预测函数开头强制关闭梯度,并清空 CUDA 缓存(如果用 GPU):

def predict(self, input_data: np.ndarray) -> np.ndarray:
    import torch
    with torch.no_grad():  # 关键!禁用梯度计算
        if torch.cuda.is_available():
            input_tensor = torch.from_numpy(input_data).float().cuda()
            output = self.model(input_tensor)
            torch.cuda.empty_cache()  # 清理 GPU 显存
        else:
            input_tensor = torch.from_numpy(input_data).float()
            output = self.model(input_tensor)
    return output.cpu().numpy()

实测数据:关闭 torch.no_grad() 后,1000 次预测内存增长 1.2GB;开启后,内存稳定在 1.05GB(仅模型权重占用)。

5.2 GPU 利用率低:明明有 A100,却只跑出 10% 利用率

现象 nvidia-smi 显示 GPU-Util 长期低于 20%,但 CPU 使用率 100%,响应延迟高。
根因 :数据加载(Data Loading)成为瓶颈。PyTorch 的 DataLoader 默认 num_workers=0 ,即在主线程中同步读取和预处理图像,CPU 算力全耗在 IO 上,GPU 一直空闲。
解决方案 :在 TorchModel.load() 中配置多进程数据加载,但需注意 num_workers > 0 时必须使用 spawn 启动方式(Linux/macOS 默认),且模型不能是 lambda 或闭包:

# models/torch_model.py
def load(cls, path: str):
    model = torch.load(path, map_location='cpu')
    model.eval()
    # 关键:预热模型,触发 CUDA 初始化
    if torch.cuda.is_available():
        model = model.cuda()
        # 预热:用 dummy input 触发 CUDA kernel 编译
        dummy = torch.randn(1, 3, 224, 224).cuda()
        _ = model(dummy)
    return cls(model)

# 在 predict() 中使用 DataLoader(仅当批量预测时)
def predict_batch(self, batch_images: List[np.ndarray]):
    from torch.utils.data import DataLoader, TensorDataset
    # 转 tensor 并送 GPU
    tensors = [torch.from_numpy(img).float().permute(2,0,1) for img in batch_images]
    dataset = TensorDataset(torch.stack(tensors))
    loader = DataLoader(dataset, batch_size=8, num_workers=4, pin_memory=True)  # pin_memory 加速 CPU→GPU 传输
    results = []
    with torch.no_grad():
        for batch in loader:
            batch = batch[0].cuda()
            out = self.model(batch)
            results.append(out.cpu().numpy())
    return np.concatenate(results)

5.3 日志混乱:找不到谁在哪个时间点干了什么

现象 :线上报错, journalctl -u gunicorn 里一堆 Internal Server Error ,但没有 request_id,无法关联前端日志。
根因 :Flask 默认日志不包含请求上下文,且 Gunicorn 的 accesslog errorlog 分离,难以串联。
解决方案 :用 structlog 替代原生日志,注入 request_id,并统一输出格式:

# app/__init__.py
import structlog
import logging
from flask import request, g
import uuid

# 配置 structlog
structlog.configure(
    processors=[
        structlog.stdlib.filter_by_level,
        structlog.stdlib.add_logger_name,
        structlog.stdlib.add_log_level,
        structlog.stdlib.PositionalArgumentsFormatter(),
        structlog.processors.TimeStamper(fmt="iso"),
        structlog.processors.StackInfoRenderer(),
        structlog.processors.format_exc_info,
        structlog.processors.UnicodeDecoder(),
        structlog.processors.JSONRenderer()  # 输出 JSON,方便 ELK 收集
    ],
    context_class=dict,
    logger_factory=structlog.stdlib.LoggerFactory(),
    wrapper_class=structlog.stdlib.BoundLogger,
    cache_logger_on_first_use=True,
)

# 请求开始时生成 request_id
@app.before_request
def before_request():
    g.request_id = str(uuid.uuid4())

# 在日志中自动注入 request_id
@app.after_request
def after_request(response):
    structlog.get_logger().info(
        "request_finished",
        method=request.method,
        path=request.path,
        status_code=response.status_code,
        request_id=g.request_id,
        duration_ms=int((time.time() - g.start_time) * 1000) if hasattr(g, 'start_time') else 0
    )
    return response

然后在 gunicorn.conf.py 中关闭默认日志,全部走 structlog:

accesslog = "-"  # stdout
errorlog = "-"   # stdout
capture_output = True

这样每条日志都是结构化 JSON,含 request_id path duration_ms ,ELK 或 Loki 一搜 request_id 就能串起整个请求链。

5.4 模型热更新:不想重启服务,怎么换新模型?

需求 :模型迭代频繁,每次更新都要 systemctl restart gunicorn ,导致 30 秒服务不可用。
安全方案 :用文件系统信号触发重载,而非进程重启。原理是监听模型文件的修改时间,当检测到变更时,优雅地重新加载模型:

# models/__init__.py
import os
import time
from threading import Thread

_model_instance = None
_last_modified = 0
_MODEL_FILE = None

def _watch_model_file():
    global _last_modified, _MODEL_FILE
    while True:
        if _MODEL_FILE and os.path.exists(_MODEL_FILE):
            current_mod = os.path.getmtime(_MODEL_FILE)
            if current_mod != _last_modified:
                _last_modified = current_mod
                _logger.info(f"Model file changed, reloading...")
                # 重新加载模型(线程安全:用新实例替换旧实例)
                new_model = _load_model(_MODEL_FILE)
                global _model_instance
                _model_instance = new_model
        time.sleep(5)  # 每 5 秒检查一次

# 在应用启动后启动监控线程
def start_model_watcher(model_path: str):
    global _MODEL_FILE, _last_modified
    _MODEL_FILE = model_path
    _last_modified = os.path.getmtime(model_path) if os.path.exists(model_path) else 0
    watcher = Thread(target=_watch_model_file, daemon=True)
    watcher.start()

调用方式:在 create_app() 最后加 start_model_watcher(config.MODEL_PATH) 。运维只需 touch model.pth cp new_model.pth model.pth ,服务会在 5 秒内自动加载新模型,零停机。注意:此方案要求模型加载是幂等的(多次加载同一文件无副作用),且预测函数是线程安全的(我们用 get_model() 获取实例,天然线程安全)。

6. 性能压测与容量规划:别让“能跑”变成“不敢用”

6.1 用 Locust 模拟真实流量,而非 ab 工具

ab (Apache Bench)只能测单 URL,无法模拟用户行为链(如上传图→等待→获取结果)。Locust 支持编写 Python 脚本定义用户行为:

# locustfile.py
from locust import HttpUser, task, between
import base64
import numpy as np
from PIL import Image
import io

class MLUser(HttpUser):
    wait_time = between(1, 3)  # 用户思考时间 1-3 秒
    
    @task
    def predict_image(self):
        # 生成随机红图(模拟真实图像大小)
        img = Image.new('RGB', (224, 224), color='red')
        buffer = io.BytesIO()
        img.save(buffer, format='PNG')
        base64_img = base64.b64encode(buffer.getvalue()).decode('utf-8')
        
        self.client.post("/predict", 
                        json={"image": base64_img},
                        name="/predict (224x224)")

# 运行:locust -f locustfile.py --host http://localhost:5000

压测时关注三个核心指标:

  • P95 延迟 :95% 的请求响应时间 ≤ 500ms(业务可接受阈值)
  • 错误率 :必须 < 0.1%,任何 5xx 都要立即排查
  • 资源饱和度 htop 看 CPU 是否持续 >80%, nvidia-smi 看 GPU-Util 是否 >90%

我通常按阶梯加压:10 → 50 → 100 → 200 用户,每档跑 5 分钟,观察指标拐点。当 P95 从 300ms 涨到 1200ms 时,说明已达当前配置瓶颈。

6.2 容量公式:如何计算你需要多少台机器

别猜,用公式算。核心是 QPS = (CPU 核数 × 单核处理能力) ÷ 单请求 CPU 时间
实测你的模型:在 1 核 CPU 上,单次预测耗时 400ms,则单核理论 QPS = 1000ms / 400ms = 2.5 QPS。
若目标承载 100 QPS,则需 CPU 核数 = 100 / 2.5 = 40 核。
一台 16 核服务器,最多支撑 40 QPS,所以 100 QPS 需要 3 台(40×2=80 < 100,40×3=120 > 100)。

GPU 场景同理:测单卡 A100 的最大并发(如 32),目标 1000 QPS,则需 1000/32 ≈ 32 张卡。

注意:这是理论值,必须预留 30% 余量应对流量峰值,所以最终采购量 = 计算值 × 1.3。我们曾按 100 QPS 采购,结果双十一流量峰值达 180 QPS,幸亏预留了余量。

6.3 成本优化:什么时候该换框架?

Flask 不是万能的。当你的场景出现以下任一情况,就该评估迁移:

  • QPS > 500 且持续增长 :Flask + Gunicorn 的进程模型扩展成本高,考虑 FastAPI + Uvicorn(异步 + 更高吞吐)
  • 需要实时流式响应 (如语音转文字逐字返回):Flask 不支持 Server-Sent Events,必须换 Starlette 或 Quart
  • 微服务化程度高 ,需服务发现、熔断、链路追踪:Flask 生态弱,Spring Cloud 或 Go Microservice 更合适

但记住: 迁移成本 = 开发成本 + 测试成本 + 运维学习成本 。我们有个项目 QPS 从 200 涨到 600,团队花了 3 周迁到 FastAPI,QPS 提升到 800——但 ROI 很低,因为 600 QPS 时 Gunicorn 用 8 workers 也能扛住,只是 CPU 利用率高些。真正的瓶颈,往往在模型本身,而不是框架。


我在实际使用中发现,最浪费时间的从来不是写代码,而是 在错误的日志里找错误的原因 。比如有一次, gunicorn worker 频繁重启,日志只显示 Worker timeout ,查了两天才发现是模型加载时 torch.load() 读取了一个 2GB 的 .pth 文件,而 --timeout 120 设置得太短——把超时调到 300 秒,问题消失。所以,永远先看日志的 上下文 ,而不是第一行错误。另外,别迷信“最新版”,我们线上稳定运行的是 Flask 2.0.3 + Gunicorn 20.1.0,不是最新的 2.3.x,因为新版某个 commit 引入了 threading.local 内存泄漏,官方 issue 到现在还没合入。生产环境,稳字当头。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值