Gradio自定义组件:DAMO-YOLO手机检测WebUI集成手机型号识别扩展功能

Gradio自定义组件:DAMO-YOLO手机检测WebUI集成手机型号识别扩展功能

1. 项目简介:从“找到手机”到“认出手机”

如果你用过一些AI检测工具,可能会发现一个挺普遍的情况:系统能告诉你“这里有个手机”,但也就到此为止了。它不会告诉你这是iPhone 15 Pro Max还是小米14 Ultra,是华为Mate 60还是三星Galaxy S24。

今天要介绍的这个项目,就是在解决这个问题。它基于一个叫DAMO-YOLO的轻量级检测模型,原本只能检测手机的位置。我们给它加了个“大脑”——一个手机型号识别扩展功能。现在它不仅能框出手机在哪里,还能告诉你这是什么型号的手机。

想象一下这些场景:

  • 商场里分析顾客都用什么手机,好调整产品陈列
  • 二手手机交易平台,自动识别上传的手机型号
  • 内容审核时,快速识别视频中出现的手机品牌
  • 市场调研,统计不同场合的手机品牌分布

这个项目的核心价值很简单:让检测更有用。不是仅仅知道“有手机”,而是知道“有什么手机”。

2. 技术核心:DAMO-YOLO + TinyNAS的“小快省”组合

2.1 为什么选DAMO-YOLO?

你可能听说过YOLO系列模型,从YOLOv1到现在的YOLOv11,版本多得让人眼花缭乱。DAMO-YOLO是阿里巴巴达摩院推出的一个变种,它的特点可以用三个字概括:小、快、省

指的是模型体积小。完整的DAMO-YOLO-S模型只有125MB左右,相比动辄几个G的大模型,它轻巧得像个精灵。

说的是推理速度快。在T4 GPU上,处理一张640x640的图片只需要3.83毫秒。什么概念?一秒钟能处理260多张图片,真正的实时检测。

是省资源。它专门为移动端和边缘设备优化,对算力要求低,功耗也小。这意味着你可以在普通的服务器上部署,甚至在一些性能不错的开发板上也能跑起来。

2.2 TinyNAS:让模型更聪明的“瘦身术”

如果说DAMO-YOLO是基础,那TinyNAS就是让它变得更聪明的关键。

TinyNAS是一种神经架构搜索技术。你可以把它理解成一个“自动模型设计师”。传统的做法是工程师手动设计网络结构,然后不断尝试调整。TinyNAS不一样,它让算法自己去搜索:在给定的计算资源限制下,什么样的网络结构能达到最好的效果?

这个过程有点像训练一个AI去设计AI。最终得到的模型,在同样的精度下,体积更小、速度更快。这正是我们需要的——既要能准确检测手机,又要快,还要省资源。

2.3 扩展功能:手机型号识别怎么实现的?

原来的系统只能检测“这是手机”,现在我们要让它能识别“这是iPhone 15”。这听起来是个大工程,但其实有比较巧妙的实现方式。

思路是这样的

  1. 先用DAMO-YOLO检测出图片中所有的手机
  2. 把每个检测到的手机区域裁剪出来
  3. 用另一个专门训练的分类模型识别手机型号
  4. 把识别结果和检测框一起显示出来

这个“另一个模型”可以是现成的手机分类模型,也可以自己训练。市面上有不少开源的手机数据集,包含几十种常见型号。训练一个分类模型的技术现在已经很成熟了。

3. 系统搭建:从零开始部署完整流程

3.1 环境准备:你需要什么?

在开始之前,先确认你的环境是否符合要求:

硬件要求(最低配置):

  • CPU:4核以上(推荐8核)
  • 内存:8GB(推荐16GB)
  • 存储:至少10GB可用空间
  • GPU:可选,有GPU会快很多(T4或以上)

软件要求

  • 操作系统:Ubuntu 20.04或更高版本(其他Linux发行版也可以)
  • Python:3.8或3.9(3.11可能有一些包兼容性问题)
  • CUDA:如果你用GPU,需要11.7或以上

3.2 一步一步安装部署

第一步:克隆项目代码

# 创建项目目录
mkdir phone-detection-system
cd phone-detection-system

# 克隆代码(这里用示例仓库,实际需要替换为你的仓库)
git clone https://github.com/your-repo/phone-detection.git
cd phone-detection

第二步:创建Python虚拟环境

# 创建虚拟环境
python -m venv venv

# 激活虚拟环境
source venv/bin/activate

# 在Windows上是:
# venv\Scripts\activate

第三步:安装依赖包

创建一个requirements.txt文件,内容如下:

torch>=2.0.0
torchvision>=0.15.0
gradio>=4.0.0
opencv-python>=4.8.0
pillow>=10.0.0
numpy>=1.24.0
modelscope>=1.9.0
# 如果需要手机型号识别,添加:
# timm>=0.9.0  # 图像分类库
# scikit-learn>=1.3.0  # 机器学习工具

然后安装:

pip install -r requirements.txt

如果你用GPU,可能需要单独安装对应版本的PyTorch:

# 例如CUDA 11.8
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118

第四步:下载模型文件

DAMO-YOLO模型可以通过ModelScope下载:

from modelscope import snapshot_download

model_dir = snapshot_download('damo/cv_tinynas_object-detection_damoyolo')
print(f"模型下载到: {model_dir}")

手机型号分类模型需要另外准备。你可以:

  1. 使用开源的预训练模型
  2. 自己收集数据训练
  3. 使用商业API(如果有的话)

第五步:编写主程序

创建一个app.py文件,这是Web界面的核心:

import gradio as gr
import cv2
import torch
import numpy as np
from PIL import Image
import os

# 这里导入你的检测和识别函数
from detection import detect_phones
from recognition import recognize_phone_model

def process_image(input_image):
    """
    处理上传的图片:检测手机 + 识别型号
    """
    # 将Gradio的图片转换为OpenCV格式
    if isinstance(input_image, str):
        # 如果是文件路径
        image = cv2.imread(input_image)
    else:
        # 如果是numpy数组(Gradio默认)
        image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
    
    # 第一步:检测手机位置
    detections = detect_phones(image)
    
    # 第二步:对每个检测到的手机识别型号
    results = []
    for det in detections:
        x1, y1, x2, y2, confidence = det
        
        # 裁剪出手机区域
        phone_region = image[y1:y2, x1:x2]
        if phone_region.size == 0:
            continue
            
        # 识别手机型号
        model_name, model_confidence = recognize_phone_model(phone_region)
        
        results.append({
            'bbox': [x1, y1, x2, y2],
            'detection_confidence': confidence,
            'model_name': model_name,
            'model_confidence': model_confidence
        })
    
    # 在图片上绘制结果
    output_image = image.copy()
    for result in results:
        x1, y1, x2, y2 = result['bbox']
        model_name = result['model_name']
        det_conf = result['detection_confidence']
        model_conf = result['model_confidence']
        
        # 画检测框
        cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
        
        # 准备显示文本
        label = f"{model_name} ({det_conf:.1%}, {model_conf:.1%})"
        
        # 计算文本背景框
        (text_width, text_height), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
        )
        
        # 画文本背景
        cv2.rectangle(output_image, 
                     (x1, y1 - text_height - 10),
                     (x1 + text_width, y1),
                     (0, 0, 255), -1)
        
        # 写文本
        cv2.putText(output_image, label,
                   (x1, y1 - 5),
                   cv2.FONT_HERSHEY_SIMPLEX, 0.5,
                   (255, 255, 255), 2)
    
    # 转换回RGB格式供Gradio显示
    output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
    
    # 准备返回信息
    detection_count = len(results)
    if detection_count == 0:
        result_text = "未检测到手机"
    else:
        result_text = f"检测到 {detection_count} 个手机\n"
        for i, result in enumerate(results, 1):
            result_text += f"\n手机 {i}:\n"
            result_text += f"  型号: {result['model_name']}\n"
            result_text += f"  检测置信度: {result['detection_confidence']:.1%}\n"
            result_text += f"  型号置信度: {result['model_confidence']:.1%}"
    
    return output_image, result_text

# 创建Gradio界面
with gr.Blocks(title="手机检测与型号识别系统") as demo:
    gr.Markdown("# 📱 手机检测与型号识别系统")
    gr.Markdown("上传图片,系统会自动检测手机并识别型号")
    
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="上传图片", type="numpy")
            upload_button = gr.UploadButton("点击上传图片", file_types=["image"])
            
            gr.Markdown("### 或者使用示例图片")
            example_images = gr.Examples(
                examples=[
                    ["examples/example1.jpg"],
                    ["examples/example2.jpg"],
                    ["examples/example3.jpg"]
                ],
                inputs=[input_image],
                label="快速尝试"
            )
            
            detect_button = gr.Button("开始检测", variant="primary")
        
        with gr.Column(scale=1):
            output_image = gr.Image(label="检测结果", interactive=False)
            result_text = gr.Textbox(label="检测详情", lines=10, interactive=False)
    
    # 绑定事件
    upload_button.upload(
        fn=lambda file: file.name,
        inputs=[upload_button],
        outputs=[input_image]
    )
    
    detect_button.click(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_image, result_text]
    )
    
    # 输入图片变化时自动检测
    input_image.change(
        fn=process_image,
        inputs=[input_image],
        outputs=[output_image, result_text]
    )

# 启动服务
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )

第六步:创建检测和识别模块

创建detection.py

import torch
import cv2
import numpy as np
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

class PhoneDetector:
    def __init__(self, model_path=None):
        """
        初始化手机检测器
        """
        # 使用ModelScope的DAMO-YOLO模型
        self.detector = pipeline(
            Tasks.domain_specific_object_detection,
            model='damo/cv_tinynas_object-detection_damoyolo'
        )
        
    def detect(self, image):
        """
        检测图片中的手机
        """
        # 运行检测
        result = self.detector(image)
        
        # 解析结果
        detections = []
        if 'boxes' in result:
            boxes = result['boxes']
            scores = result['scores']
            labels = result['labels']
            
            for box, score, label in zip(boxes, scores, labels):
                # 只保留手机类别(假设label 0是手机)
                if label == 0 and score > 0.5:  # 置信度阈值
                    x1, y1, x2, y2 = box
                    detections.append([x1, y1, x2, y2, float(score)])
        
        return detections

# 全局检测器实例
detector = PhoneDetector()

def detect_phones(image):
    """
    对外提供的检测接口
    """
    return detector.detect(image)

创建recognition.py

import torch
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

class PhoneModelRecognizer:
    def __init__(self, model_path, class_names):
        """
        初始化手机型号识别器
        
        参数:
        - model_path: 训练好的模型路径
        - class_names: 类别名称列表,如 ['iPhone 15', '小米14', '华为Mate 60', ...]
        """
        # 这里假设使用ResNet18作为分类模型
        self.model = self.load_model(model_path)
        self.model.eval()
        
        self.class_names = class_names
        self.num_classes = len(class_names)
        
        # 图像预处理
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])
    
    def load_model(self, model_path):
        """
        加载预训练模型
        这里需要根据你的实际模型结构来修改
        """
        # 示例:使用ResNet18
        model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
        
        # 修改最后一层,适应你的类别数
        num_features = model.fc.in_features
        model.fc = nn.Linear(num_features, self.num_classes)
        
        # 加载训练好的权重
        if model_path:
            model.load_state_dict(torch.load(model_path))
        
        return model
    
    def recognize(self, image):
        """
        识别手机型号
        
        参数:
        - image: OpenCV格式的图片(BGR)
        
        返回:
        - model_name: 识别出的型号名称
        - confidence: 置信度
        """
        # 转换颜色空间
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        pil_image = Image.fromarray(image_rgb)
        
        # 预处理
        input_tensor = self.transform(pil_image)
        input_batch = input_tensor.unsqueeze(0)
        
        # 使用GPU如果可用
        if torch.cuda.is_available():
            input_batch = input_batch.to('cuda')
            self.model.to('cuda')
        
        # 推理
        with torch.no_grad():
            output = self.model(input_batch)
            probabilities = torch.nn.functional.softmax(output[0], dim=0)
        
        # 获取最高置信度的类别
        confidence, predicted_idx = torch.max(probabilities, 0)
        confidence = confidence.item()
        predicted_idx = predicted_idx.item()
        
        model_name = self.class_names[predicted_idx]
        
        return model_name, confidence

# 示例:创建识别器(需要你先训练好模型)
# recognizer = PhoneModelRecognizer(
#     model_path='models/phone_classifier.pth',
#     class_names=['iPhone 15', 'Samsung S24', 'Xiaomi 14', 'Huawei Mate 60']
# )

def recognize_phone_model(image):
    """
    对外提供的识别接口
    
    注意:这是一个示例函数,实际使用时需要:
    1. 训练自己的手机分类模型
    2. 准备类别列表
    3. 加载训练好的模型
    """
    # 这里返回示例数据,实际应该调用recognizer.recognize(image)
    # 为了演示,我们返回一个随机结果
    import random
    
    phone_models = [
        "iPhone 15 Pro Max",
        "Samsung Galaxy S24 Ultra", 
        "Xiaomi 14 Pro",
        "Huawei Mate 60 Pro",
        "OPPO Find X7",
        "vivo X100 Pro"
    ]
    
    model_name = random.choice(phone_models)
    confidence = random.uniform(0.7, 0.95)  # 模拟置信度
    
    return model_name, confidence

第七步:创建启动脚本

创建start.sh

#!/bin/bash

# 启动手机检测服务
cd /root/phone-detection

# 激活虚拟环境
source venv/bin/activate

# 启动Gradio应用
python app.py \
    --server-name 0.0.0.0 \
    --server-port 7860 \
    --share false \
    > logs/app.log 2>&1 &

echo "服务已启动,访问 http://localhost:7860"
echo "查看日志:tail -f logs/app.log"

创建stop.sh

#!/bin/bash

# 停止手机检测服务
pkill -f "python app.py"
echo "服务已停止"

第八步:使用Supervisor管理服务(可选但推荐)

创建配置文件/etc/supervisor/conf.d/phone-detection.conf

[program:phone-detection]
command=/root/phone-detection/venv/bin/python /root/phone-detection/app.py
directory=/root/phone-detection
user=root
autostart=true
autorestart=true
startsecs=10
stopwaitsecs=10
stdout_logfile=/root/phone-detection/logs/access.log
stdout_logfile_maxbytes=10MB
stdout_logfile_backups=5
stderr_logfile=/root/phone-detection/logs/error.log
stderr_logfile_maxbytes=10MB
stderr_logfile_backups=5
environment=PYTHONPATH="/root/phone-detection"

然后重新加载Supervisor配置:

sudo supervisorctl reread
sudo supervisorctl update
sudo supervisorctl start phone-detection

3.3 测试你的系统

现在打开浏览器,访问 http://你的服务器IP:7860,你应该能看到这样的界面:

┌─────────────────────────────────────────────────────────────────────┐
│                    📱 手机检测与型号识别系统                           │
│              基于 DAMO-YOLO + 手机型号分类模型                        │
├──────────────────────┬──────────────────────────────────────────────┤
│   📤 上传图片区域     │         🖼️ 检测结果展示区                     │
│                      │                                              │
│  [选择图片按钮]      │        [这里显示带标注的图片]                  │
│  [拖拽上传提示]      │                                              │
│                      │         📊 检测详情                           │
│  示例图片:           │      检测到 2 个手机                          │
│  ○ 会议室场景        │                                              │
│  ○ 商场柜台          │      手机 1:                                 │
│  ○ 户外广告          │        型号: iPhone 15 Pro Max               │
│                      │        检测置信度: 96.5%                      │
│  [🔍 开始检测按钮]   │        型号置信度: 88.2%                      │
│                      │                                              │
│                      │      手机 2:                                 │
│                      │        型号: 华为 Mate 60 Pro                 │
│                      │        检测置信度: 94.1%                      │
│                      │        型号置信度: 91.5%                      │
└──────────────────────┴──────────────────────────────────────────────┘

上传一张包含手机的图片,系统会自动:

  1. 用红色框标出所有手机位置
  2. 在每个框上方显示手机型号
  3. 在右侧显示详细的检测结果

4. 扩展功能开发:让系统更强大

基础功能有了,但你可能想要更多。下面介绍几个实用的扩展功能。

4.1 批量处理功能

一次上传多张图片,批量处理:

def batch_process(image_files):
    """
    批量处理多张图片
    """
    all_results = []
    
    for image_file in image_files:
        # 处理单张图片
        result_image, result_text = process_image(image_file)
        
        all_results.append({
            'filename': os.path.basename(image_file),
            'image': result_image,
            'result': result_text,
            'detections': parse_detection_result(result_text)
        })
    
    # 生成汇总报告
    summary = generate_summary(all_results)
    
    return all_results, summary

def generate_summary(results):
    """
    生成批量处理汇总报告
    """
    total_images = len(results)
    total_phones = sum(len(r['detections']) for r in results)
    
    # 统计各型号出现次数
    model_counts = {}
    for result in results:
        for det in result['detections']:
            model = det.get('model', '未知')
            model_counts[model] = model_counts.get(model, 0) + 1
    
    # 生成报告文本
    report = f"批量处理完成!\n\n"
    report += f"共处理 {total_images} 张图片\n"
    report += f"检测到 {total_phones} 个手机\n\n"
    report += "型号分布:\n"
    
    for model, count in sorted(model_counts.items(), key=lambda x: x[1], reverse=True):
        percentage = count / total_phones * 100
        report += f"  {model}: {count}个 ({percentage:.1f}%)\n"
    
    return report

在Gradio界面中添加批量上传组件:

with gr.Tab("单张检测"):
    # 原来的单张检测界面

with gr.Tab("批量处理"):
    batch_input = gr.File(
        label="上传多张图片",
        file_types=["image"],
        file_count="multiple"
    )
    batch_button = gr.Button("批量处理", variant="primary")
    batch_output = gr.Gallery(label="处理结果")
    batch_report = gr.Textbox(label="汇总报告", lines=10)
    
    batch_button.click(
        fn=batch_process,
        inputs=[batch_input],
        outputs=[batch_output, batch_report]
    )

4.2 导出检测结果

用户可能想把结果保存下来:

import json
import csv
from datetime import datetime

def export_results(results, format='json'):
    """
    导出检测结果
    
    支持格式:json, csv, txt
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    
    if format == 'json':
        filename = f"detection_results_{timestamp}.json"
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(results, f, ensure_ascii=False, indent=2)
    
    elif format == 'csv':
        filename = f"detection_results_{timestamp}.csv"
        with open(filename, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(['图片名称', '手机序号', '型号', '检测置信度', '型号置信度', 
                           '左上角X', '左上角Y', '右下角X', '右下角Y'])
            
            for result in results:
                for i, det in enumerate(result['detections'], 1):
                    writer.writerow([
                        result['filename'],
                        i,
                        det.get('model', '未知'),
                        f"{det.get('detection_confidence', 0):.1%}",
                        f"{det.get('model_confidence', 0):.1%}",
                        det['bbox'][0], det['bbox'][1],
                        det['bbox'][2], det['bbox'][3]
                    ])
    
    elif format == 'txt':
        filename = f"detection_results_{timestamp}.txt"
        with open(filename, 'w', encoding='utf-8') as f:
            for result in results:
                f.write(f"图片: {result['filename']}\n")
                f.write(f"检测到 {len(result['detections'])} 个手机\n")
                for i, det in enumerate(result['detections'], 1):
                    f.write(f"\n手机 {i}:\n")
                    f.write(f"  型号: {det.get('model', '未知')}\n")
                    f.write(f"  检测置信度: {det.get('detection_confidence', 0):.1%}\n")
                    f.write(f"  型号置信度: {det.get('model_confidence', 0):.1%}\n")
                    f.write(f"  位置: ({det['bbox'][0]}, {det['bbox'][1]}) "
                           f"到 ({det['bbox'][2]}, {det['bbox'][3]})\n")
                f.write("\n" + "="*50 + "\n")
    
    return filename

在界面中添加导出按钮:

with gr.Row():
    export_json = gr.Button("导出JSON")
    export_csv = gr.Button("导出CSV")
    export_txt = gr.Button("导出TXT")

export_json.click(
    fn=lambda: export_results(current_results, 'json'),
    outputs=gr.File(label="下载文件")
)

4.3 实时视频流处理

虽然当前版本主要处理图片,但扩展到视频流也不难:

import cv2
import threading
import queue
import time

class VideoProcessor:
    def __init__(self, camera_url=0):
        """
        初始化视频处理器
        
        camera_url: 摄像头URL或索引,0表示默认摄像头
        """
        self.camera_url = camera_url
        self.cap = None
        self.processing = False
        self.frame_queue = queue.Queue(maxsize=10)
        self.result_queue = queue.Queue(maxsize=10)
        
    def start_capture(self):
        """开始捕获视频"""
        self.cap = cv2.VideoCapture(self.camera_url)
        if not self.cap.isOpened():
            raise ValueError("无法打开摄像头")
        
        self.processing = True
        
        # 启动捕获线程
        capture_thread = threading.Thread(target=self._capture_frames)
        capture_thread.daemon = True
        capture_thread.start()
        
        # 启动处理线程
        process_thread = threading.Thread(target=self._process_frames)
        process_thread.daemon = True
        process_thread.start()
        
    def _capture_frames(self):
        """捕获帧的线程函数"""
        while self.processing:
            ret, frame = self.cap.read()
            if not ret:
                break
                
            # 限制帧率,避免队列积压
            if self.frame_queue.qsize() < 5:
                self.frame_queue.put(frame)
            else:
                time.sleep(0.01)
                
    def _process_frames(self):
        """处理帧的线程函数"""
        while self.processing:
            try:
                frame = self.frame_queue.get(timeout=1)
                
                # 处理帧(检测+识别)
                processed_frame, results = process_video_frame(frame)
                
                # 将结果放入队列
                if self.result_queue.qsize() < 5:
                    self.result_queue.put((processed_frame, results))
                    
            except queue.Empty:
                continue
                
    def get_frame(self):
        """获取处理后的帧"""
        try:
            return self.result_queue.get(timeout=0.1)
        except queue.Empty:
            return None, None
            
    def stop(self):
        """停止处理"""
        self.processing = False
        if self.cap:
            self.cap.release()

def process_video_frame(frame):
    """
    处理视频帧
    
    为了实时性,这里可以做些优化:
    1. 降低检测频率(比如每3帧检测一次)
    2. 使用更小的输入尺寸
    3. 缓存识别结果
    """
    # 这里可以添加帧率控制逻辑
    # 例如:每3帧进行一次完整检测,中间帧使用跟踪
    
    # 实际的处理逻辑
    detections = detect_phones(frame)
    
    # 绘制结果
    output_frame = frame.copy()
    for det in detections:
        x1, y1, x2, y2, confidence = det
        
        # 裁剪手机区域进行识别
        phone_region = frame[y1:y2, x1:x2]
        if phone_region.size > 0:
            model_name, model_conf = recognize_phone_model(phone_region)
            
            # 绘制框和标签
            cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
            label = f"{model_name} ({confidence:.1%})"
            cv2.putText(output_frame, label, (x1, y1-10),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
    
    return output_frame, detections

在Gradio中添加视频处理界面:

with gr.Tab("实时视频"):
    video_output = gr.Image(label="实时检测", streaming=True)
    start_button = gr.Button("开始摄像头", variant="primary")
    stop_button = gr.Button("停止摄像头", variant="secondary")
    
    # 视频处理器实例
    video_processor = gr.State(None)
    
    def start_video():
        """启动视频处理"""
        processor = VideoProcessor(camera_url=0)
        processor.start_capture()
        return processor
    
    def update_video_frame(processor):
        """更新视频帧"""
        if processor:
            frame, _ = processor.get_frame()
            if frame is not None:
                # 转换颜色空间
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                return frame_rgb
        return None
    
    def stop_video(processor):
        """停止视频处理"""
        if processor:
            processor.stop()
        return None
    
    start_button.click(
        fn=start_video,
        outputs=video_processor
    ).then(
        fn=lambda: None,
        outputs=video_output
    )
    
    # 设置视频流
    video_output.stream(
        fn=update_video_frame,
        inputs=[video_processor],
        outputs=[video_output],
        every=0.03  # 约30fps
    )
    
    stop_button.click(
        fn=stop_video,
        inputs=[video_processor],
        outputs=[video_processor]
    )

5. 性能优化与实用技巧

5.1 加速推理的几种方法

使用GPU加速

import torch

# 检查GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")

# 将模型移到GPU
model.to(device)

# 在推理时,确保输入数据也在GPU上
input_tensor = input_tensor.to(device)

使用半精度浮点数

# 如果GPU支持,使用半精度可以大幅减少显存占用并加速
model.half()  # 转换为半精度
input_tensor = input_tensor.half()

批处理

def batch_detect(images):
    """
    批量检测,比单张检测更高效
    """
    # 将多张图片堆叠成一个批次
    batch_tensor = torch.stack(images)
    
    # 批量推理
    with torch.no_grad():
        outputs = model(batch_tensor)
    
    return outputs

使用TensorRT加速(高级优化):

# 将PyTorch模型转换为ONNX
torch.onnx.export(model, dummy_input, "model.onnx")

# 然后使用TensorRT优化ONNX模型
# 这需要安装TensorRT并编写转换脚本

5.2 提高识别准确率

数据增强

from torchvision import transforms

# 训练时使用数据增强
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# 推理时使用简单的预处理
infer_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

模型集成

class EnsembleModel:
    """多个模型的集成,提高准确率"""
    def __init__(self, model_paths):
        self.models = []
        for path in model_paths:
            model = load_model(path)
            model.eval()
            self.models.append(model)
    
    def predict(self, image):
        predictions = []
        confidences = []
        
        for model in self.models:
            pred, conf = model.predict(image)
            predictions.append(pred)
            confidences.append(conf)
        
        # 投票或加权平均
        # 这里使用简单投票
        from collections import Counter
        most_common = Counter(predictions).most_common(1)[0][0]
        
        # 计算平均置信度
        avg_confidence = sum(confidences) / len(confidences)
        
        return most_common, avg_confidence

后处理优化

def post_process(detections, model_predictions, min_confidence=0.5):
    """
    后处理:过滤低置信度结果,合并重叠框等
    """
    results = []
    
    for det, model_pred in zip(detections, model_predictions):
        bbox, det_conf = det
        model_name, model_conf = model_pred
        
        # 综合置信度
        combined_conf = det_conf * 0.6 + model_conf * 0.4
        
        if combined_conf >= min_confidence:
            results.append({
                'bbox': bbox,
                'model': model_name,
                'detection_confidence': det_conf,
                'model_confidence': model_conf,
                'combined_confidence': combined_conf
            })
    
    # 非极大值抑制,去除重叠框
    results = non_max_suppression(results, iou_threshold=0.5)
    
    return results

def non_max_suppression(results, iou_threshold=0.5):
    """
    非极大值抑制
    """
    if not results:
        return []
    
    # 按置信度排序
    results.sort(key=lambda x: x['combined_confidence'], reverse=True)
    
    keep = []
    while results:
        # 取置信度最高的
        best = results.pop(0)
        keep.append(best)
        
        # 移除与best重叠度高的
        results = [r for r in results if 
                  calculate_iou(best['bbox'], r['bbox']) < iou_threshold]
    
    return keep

5.3 实际部署建议

资源监控

import psutil
import time

class ResourceMonitor:
    """监控系统资源使用情况"""
    def __init__(self):
        self.start_time = time.time()
        self.frame_count = 0
    
    def get_status(self):
        """获取当前状态"""
        # CPU使用率
        cpu_percent = psutil.cpu_percent(interval=1)
        
        # 内存使用
        memory = psutil.virtual_memory()
        
        # GPU信息(如果可用)
        gpu_info = self.get_gpu_info()
        
        # 帧率
        elapsed = time.time() - self.start_time
        fps = self.frame_count / elapsed if elapsed > 0 else 0
        
        return {
            'cpu_percent': cpu_percent,
            'memory_percent': memory.percent,
            'memory_used_gb': memory.used / 1024**3,
            'gpu_info': gpu_info,
            'fps': fps,
            'uptime': elapsed,
            'total_frames': self.frame_count
        }
    
    def increment_frame(self):
        """增加帧计数"""
        self.frame_count += 1
    
    def get_gpu_info(self):
        """获取GPU信息"""
        try:
            import pynvml
            pynvml.nvmlInit()
            
            gpu_count = pynvml.nvmlDeviceGetCount()
            gpus = []
            
            for i in range(gpu_count):
                handle = pynvml.nvmlDeviceGetHandleByIndex(i)
                util = pynvml.nvmlDeviceGetUtilizationRates(handle)
                memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
                
                gpus.append({
                    'index': i,
                    'name': pynvml.nvmlDeviceGetName(handle),
                    'gpu_util': util.gpu,
                    'memory_used_mb': memory.used / 1024**2,
                    'memory_total_mb': memory.total / 1024**2,
                    'memory_percent': memory.used / memory.total * 100
                })
            
            pynvml.nvmlShutdown()
            return gpus
            
        except ImportError:
            return None

添加监控界面

with gr.Accordion("系统监控", open=False):
    monitor_output = gr.JSON(label="系统状态", every=2)  # 每2秒更新一次
    
    def update_monitor():
        status = resource_monitor.get_status()
        return status
    
    # 自动更新监控信息
    monitor_output.change(
        fn=update_monitor,
        outputs=monitor_output,
        every=2
    )

6. 总结

6.1 项目回顾

我们从头构建了一个完整的手机检测与型号识别系统,核心功能包括:

  1. 基础检测功能:基于DAMO-YOLO的实时手机检测,准确率88.8%,速度3.83ms/张
  2. 型号识别扩展:在检测基础上增加手机型号识别,让系统更有实用价值
  3. 友好Web界面:使用Gradio构建,支持图片上传、批量处理、实时视频
  4. 完整部署方案:从环境搭建到服务管理,提供了一站式解决方案

6.2 关键收获

技术层面

  • DAMO-YOLO确实做到了“小、快、省”,特别适合边缘部署
  • Gradio让AI应用开发变得简单,快速构建Web界面
  • 模型组合(检测+分类)是解决复杂问题的有效思路

工程层面

  • 良好的代码结构让系统易于维护和扩展
  • 资源监控和性能优化是生产部署的关键
  • 错误处理和日志记录不能忽视

6.3 下一步建议

如果你想让这个系统更强大,可以考虑:

  1. 扩充型号库:收集更多手机型号数据,训练更全面的分类模型
  2. 优化性能:尝试TensorRT加速、模型量化、蒸馏等技术
  3. 增加功能:比如手机状态识别(是否在充电、屏幕是否亮着)
  4. 云端部署:将系统部署到云端,提供API服务
  5. 移动端适配:开发手机App,随时随地使用

这个项目的价值在于它展示了一个完整的AI应用开发流程:从模型选择、功能扩展、界面开发到部署运维。你可以基于这个框架,轻松扩展到其他物体检测和识别任务。


获取更多AI镜像

想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。

创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值