Gradio自定义组件:DAMO-YOLO手机检测WebUI集成手机型号识别扩展功能
1. 项目简介:从“找到手机”到“认出手机”
如果你用过一些AI检测工具,可能会发现一个挺普遍的情况:系统能告诉你“这里有个手机”,但也就到此为止了。它不会告诉你这是iPhone 15 Pro Max还是小米14 Ultra,是华为Mate 60还是三星Galaxy S24。
今天要介绍的这个项目,就是在解决这个问题。它基于一个叫DAMO-YOLO的轻量级检测模型,原本只能检测手机的位置。我们给它加了个“大脑”——一个手机型号识别扩展功能。现在它不仅能框出手机在哪里,还能告诉你这是什么型号的手机。
想象一下这些场景:
- 商场里分析顾客都用什么手机,好调整产品陈列
- 二手手机交易平台,自动识别上传的手机型号
- 内容审核时,快速识别视频中出现的手机品牌
- 市场调研,统计不同场合的手机品牌分布
这个项目的核心价值很简单:让检测更有用。不是仅仅知道“有手机”,而是知道“有什么手机”。
2. 技术核心:DAMO-YOLO + TinyNAS的“小快省”组合
2.1 为什么选DAMO-YOLO?
你可能听说过YOLO系列模型,从YOLOv1到现在的YOLOv11,版本多得让人眼花缭乱。DAMO-YOLO是阿里巴巴达摩院推出的一个变种,它的特点可以用三个字概括:小、快、省。
小指的是模型体积小。完整的DAMO-YOLO-S模型只有125MB左右,相比动辄几个G的大模型,它轻巧得像个精灵。
快说的是推理速度快。在T4 GPU上,处理一张640x640的图片只需要3.83毫秒。什么概念?一秒钟能处理260多张图片,真正的实时检测。
省是省资源。它专门为移动端和边缘设备优化,对算力要求低,功耗也小。这意味着你可以在普通的服务器上部署,甚至在一些性能不错的开发板上也能跑起来。
2.2 TinyNAS:让模型更聪明的“瘦身术”
如果说DAMO-YOLO是基础,那TinyNAS就是让它变得更聪明的关键。
TinyNAS是一种神经架构搜索技术。你可以把它理解成一个“自动模型设计师”。传统的做法是工程师手动设计网络结构,然后不断尝试调整。TinyNAS不一样,它让算法自己去搜索:在给定的计算资源限制下,什么样的网络结构能达到最好的效果?
这个过程有点像训练一个AI去设计AI。最终得到的模型,在同样的精度下,体积更小、速度更快。这正是我们需要的——既要能准确检测手机,又要快,还要省资源。
2.3 扩展功能:手机型号识别怎么实现的?
原来的系统只能检测“这是手机”,现在我们要让它能识别“这是iPhone 15”。这听起来是个大工程,但其实有比较巧妙的实现方式。
思路是这样的:
- 先用DAMO-YOLO检测出图片中所有的手机
- 把每个检测到的手机区域裁剪出来
- 用另一个专门训练的分类模型识别手机型号
- 把识别结果和检测框一起显示出来
这个“另一个模型”可以是现成的手机分类模型,也可以自己训练。市面上有不少开源的手机数据集,包含几十种常见型号。训练一个分类模型的技术现在已经很成熟了。
3. 系统搭建:从零开始部署完整流程
3.1 环境准备:你需要什么?
在开始之前,先确认你的环境是否符合要求:
硬件要求(最低配置):
- CPU:4核以上(推荐8核)
- 内存:8GB(推荐16GB)
- 存储:至少10GB可用空间
- GPU:可选,有GPU会快很多(T4或以上)
软件要求:
- 操作系统:Ubuntu 20.04或更高版本(其他Linux发行版也可以)
- Python:3.8或3.9(3.11可能有一些包兼容性问题)
- CUDA:如果你用GPU,需要11.7或以上
3.2 一步一步安装部署
第一步:克隆项目代码
# 创建项目目录
mkdir phone-detection-system
cd phone-detection-system
# 克隆代码(这里用示例仓库,实际需要替换为你的仓库)
git clone https://github.com/your-repo/phone-detection.git
cd phone-detection
第二步:创建Python虚拟环境
# 创建虚拟环境
python -m venv venv
# 激活虚拟环境
source venv/bin/activate
# 在Windows上是:
# venv\Scripts\activate
第三步:安装依赖包
创建一个requirements.txt文件,内容如下:
torch>=2.0.0
torchvision>=0.15.0
gradio>=4.0.0
opencv-python>=4.8.0
pillow>=10.0.0
numpy>=1.24.0
modelscope>=1.9.0
# 如果需要手机型号识别,添加:
# timm>=0.9.0 # 图像分类库
# scikit-learn>=1.3.0 # 机器学习工具
然后安装:
pip install -r requirements.txt
如果你用GPU,可能需要单独安装对应版本的PyTorch:
# 例如CUDA 11.8
pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
第四步:下载模型文件
DAMO-YOLO模型可以通过ModelScope下载:
from modelscope import snapshot_download
model_dir = snapshot_download('damo/cv_tinynas_object-detection_damoyolo')
print(f"模型下载到: {model_dir}")
手机型号分类模型需要另外准备。你可以:
- 使用开源的预训练模型
- 自己收集数据训练
- 使用商业API(如果有的话)
第五步:编写主程序
创建一个app.py文件,这是Web界面的核心:
import gradio as gr
import cv2
import torch
import numpy as np
from PIL import Image
import os
# 这里导入你的检测和识别函数
from detection import detect_phones
from recognition import recognize_phone_model
def process_image(input_image):
"""
处理上传的图片:检测手机 + 识别型号
"""
# 将Gradio的图片转换为OpenCV格式
if isinstance(input_image, str):
# 如果是文件路径
image = cv2.imread(input_image)
else:
# 如果是numpy数组(Gradio默认)
image = cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)
# 第一步:检测手机位置
detections = detect_phones(image)
# 第二步:对每个检测到的手机识别型号
results = []
for det in detections:
x1, y1, x2, y2, confidence = det
# 裁剪出手机区域
phone_region = image[y1:y2, x1:x2]
if phone_region.size == 0:
continue
# 识别手机型号
model_name, model_confidence = recognize_phone_model(phone_region)
results.append({
'bbox': [x1, y1, x2, y2],
'detection_confidence': confidence,
'model_name': model_name,
'model_confidence': model_confidence
})
# 在图片上绘制结果
output_image = image.copy()
for result in results:
x1, y1, x2, y2 = result['bbox']
model_name = result['model_name']
det_conf = result['detection_confidence']
model_conf = result['model_confidence']
# 画检测框
cv2.rectangle(output_image, (x1, y1), (x2, y2), (0, 0, 255), 2)
# 准备显示文本
label = f"{model_name} ({det_conf:.1%}, {model_conf:.1%})"
# 计算文本背景框
(text_width, text_height), baseline = cv2.getTextSize(
label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2
)
# 画文本背景
cv2.rectangle(output_image,
(x1, y1 - text_height - 10),
(x1 + text_width, y1),
(0, 0, 255), -1)
# 写文本
cv2.putText(output_image, label,
(x1, y1 - 5),
cv2.FONT_HERSHEY_SIMPLEX, 0.5,
(255, 255, 255), 2)
# 转换回RGB格式供Gradio显示
output_image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
# 准备返回信息
detection_count = len(results)
if detection_count == 0:
result_text = "未检测到手机"
else:
result_text = f"检测到 {detection_count} 个手机\n"
for i, result in enumerate(results, 1):
result_text += f"\n手机 {i}:\n"
result_text += f" 型号: {result['model_name']}\n"
result_text += f" 检测置信度: {result['detection_confidence']:.1%}\n"
result_text += f" 型号置信度: {result['model_confidence']:.1%}"
return output_image, result_text
# 创建Gradio界面
with gr.Blocks(title="手机检测与型号识别系统") as demo:
gr.Markdown("# 📱 手机检测与型号识别系统")
gr.Markdown("上传图片,系统会自动检测手机并识别型号")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(label="上传图片", type="numpy")
upload_button = gr.UploadButton("点击上传图片", file_types=["image"])
gr.Markdown("### 或者使用示例图片")
example_images = gr.Examples(
examples=[
["examples/example1.jpg"],
["examples/example2.jpg"],
["examples/example3.jpg"]
],
inputs=[input_image],
label="快速尝试"
)
detect_button = gr.Button("开始检测", variant="primary")
with gr.Column(scale=1):
output_image = gr.Image(label="检测结果", interactive=False)
result_text = gr.Textbox(label="检测详情", lines=10, interactive=False)
# 绑定事件
upload_button.upload(
fn=lambda file: file.name,
inputs=[upload_button],
outputs=[input_image]
)
detect_button.click(
fn=process_image,
inputs=[input_image],
outputs=[output_image, result_text]
)
# 输入图片变化时自动检测
input_image.change(
fn=process_image,
inputs=[input_image],
outputs=[output_image, result_text]
)
# 启动服务
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)
第六步:创建检测和识别模块
创建detection.py:
import torch
import cv2
import numpy as np
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
class PhoneDetector:
def __init__(self, model_path=None):
"""
初始化手机检测器
"""
# 使用ModelScope的DAMO-YOLO模型
self.detector = pipeline(
Tasks.domain_specific_object_detection,
model='damo/cv_tinynas_object-detection_damoyolo'
)
def detect(self, image):
"""
检测图片中的手机
"""
# 运行检测
result = self.detector(image)
# 解析结果
detections = []
if 'boxes' in result:
boxes = result['boxes']
scores = result['scores']
labels = result['labels']
for box, score, label in zip(boxes, scores, labels):
# 只保留手机类别(假设label 0是手机)
if label == 0 and score > 0.5: # 置信度阈值
x1, y1, x2, y2 = box
detections.append([x1, y1, x2, y2, float(score)])
return detections
# 全局检测器实例
detector = PhoneDetector()
def detect_phones(image):
"""
对外提供的检测接口
"""
return detector.detect(image)
创建recognition.py:
import torch
import torch.nn as nn
import cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
class PhoneModelRecognizer:
def __init__(self, model_path, class_names):
"""
初始化手机型号识别器
参数:
- model_path: 训练好的模型路径
- class_names: 类别名称列表,如 ['iPhone 15', '小米14', '华为Mate 60', ...]
"""
# 这里假设使用ResNet18作为分类模型
self.model = self.load_model(model_path)
self.model.eval()
self.class_names = class_names
self.num_classes = len(class_names)
# 图像预处理
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
def load_model(self, model_path):
"""
加载预训练模型
这里需要根据你的实际模型结构来修改
"""
# 示例:使用ResNet18
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=False)
# 修改最后一层,适应你的类别数
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, self.num_classes)
# 加载训练好的权重
if model_path:
model.load_state_dict(torch.load(model_path))
return model
def recognize(self, image):
"""
识别手机型号
参数:
- image: OpenCV格式的图片(BGR)
返回:
- model_name: 识别出的型号名称
- confidence: 置信度
"""
# 转换颜色空间
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(image_rgb)
# 预处理
input_tensor = self.transform(pil_image)
input_batch = input_tensor.unsqueeze(0)
# 使用GPU如果可用
if torch.cuda.is_available():
input_batch = input_batch.to('cuda')
self.model.to('cuda')
# 推理
with torch.no_grad():
output = self.model(input_batch)
probabilities = torch.nn.functional.softmax(output[0], dim=0)
# 获取最高置信度的类别
confidence, predicted_idx = torch.max(probabilities, 0)
confidence = confidence.item()
predicted_idx = predicted_idx.item()
model_name = self.class_names[predicted_idx]
return model_name, confidence
# 示例:创建识别器(需要你先训练好模型)
# recognizer = PhoneModelRecognizer(
# model_path='models/phone_classifier.pth',
# class_names=['iPhone 15', 'Samsung S24', 'Xiaomi 14', 'Huawei Mate 60']
# )
def recognize_phone_model(image):
"""
对外提供的识别接口
注意:这是一个示例函数,实际使用时需要:
1. 训练自己的手机分类模型
2. 准备类别列表
3. 加载训练好的模型
"""
# 这里返回示例数据,实际应该调用recognizer.recognize(image)
# 为了演示,我们返回一个随机结果
import random
phone_models = [
"iPhone 15 Pro Max",
"Samsung Galaxy S24 Ultra",
"Xiaomi 14 Pro",
"Huawei Mate 60 Pro",
"OPPO Find X7",
"vivo X100 Pro"
]
model_name = random.choice(phone_models)
confidence = random.uniform(0.7, 0.95) # 模拟置信度
return model_name, confidence
第七步:创建启动脚本
创建start.sh:
#!/bin/bash
# 启动手机检测服务
cd /root/phone-detection
# 激活虚拟环境
source venv/bin/activate
# 启动Gradio应用
python app.py \
--server-name 0.0.0.0 \
--server-port 7860 \
--share false \
> logs/app.log 2>&1 &
echo "服务已启动,访问 http://localhost:7860"
echo "查看日志:tail -f logs/app.log"
创建stop.sh:
#!/bin/bash
# 停止手机检测服务
pkill -f "python app.py"
echo "服务已停止"
第八步:使用Supervisor管理服务(可选但推荐)
创建配置文件/etc/supervisor/conf.d/phone-detection.conf:
[program:phone-detection]
command=/root/phone-detection/venv/bin/python /root/phone-detection/app.py
directory=/root/phone-detection
user=root
autostart=true
autorestart=true
startsecs=10
stopwaitsecs=10
stdout_logfile=/root/phone-detection/logs/access.log
stdout_logfile_maxbytes=10MB
stdout_logfile_backups=5
stderr_logfile=/root/phone-detection/logs/error.log
stderr_logfile_maxbytes=10MB
stderr_logfile_backups=5
environment=PYTHONPATH="/root/phone-detection"
然后重新加载Supervisor配置:
sudo supervisorctl reread
sudo supervisorctl update
sudo supervisorctl start phone-detection
3.3 测试你的系统
现在打开浏览器,访问 http://你的服务器IP:7860,你应该能看到这样的界面:
┌─────────────────────────────────────────────────────────────────────┐
│ 📱 手机检测与型号识别系统 │
│ 基于 DAMO-YOLO + 手机型号分类模型 │
├──────────────────────┬──────────────────────────────────────────────┤
│ 📤 上传图片区域 │ 🖼️ 检测结果展示区 │
│ │ │
│ [选择图片按钮] │ [这里显示带标注的图片] │
│ [拖拽上传提示] │ │
│ │ 📊 检测详情 │
│ 示例图片: │ 检测到 2 个手机 │
│ ○ 会议室场景 │ │
│ ○ 商场柜台 │ 手机 1: │
│ ○ 户外广告 │ 型号: iPhone 15 Pro Max │
│ │ 检测置信度: 96.5% │
│ [🔍 开始检测按钮] │ 型号置信度: 88.2% │
│ │ │
│ │ 手机 2: │
│ │ 型号: 华为 Mate 60 Pro │
│ │ 检测置信度: 94.1% │
│ │ 型号置信度: 91.5% │
└──────────────────────┴──────────────────────────────────────────────┘
上传一张包含手机的图片,系统会自动:
- 用红色框标出所有手机位置
- 在每个框上方显示手机型号
- 在右侧显示详细的检测结果
4. 扩展功能开发:让系统更强大
基础功能有了,但你可能想要更多。下面介绍几个实用的扩展功能。
4.1 批量处理功能
一次上传多张图片,批量处理:
def batch_process(image_files):
"""
批量处理多张图片
"""
all_results = []
for image_file in image_files:
# 处理单张图片
result_image, result_text = process_image(image_file)
all_results.append({
'filename': os.path.basename(image_file),
'image': result_image,
'result': result_text,
'detections': parse_detection_result(result_text)
})
# 生成汇总报告
summary = generate_summary(all_results)
return all_results, summary
def generate_summary(results):
"""
生成批量处理汇总报告
"""
total_images = len(results)
total_phones = sum(len(r['detections']) for r in results)
# 统计各型号出现次数
model_counts = {}
for result in results:
for det in result['detections']:
model = det.get('model', '未知')
model_counts[model] = model_counts.get(model, 0) + 1
# 生成报告文本
report = f"批量处理完成!\n\n"
report += f"共处理 {total_images} 张图片\n"
report += f"检测到 {total_phones} 个手机\n\n"
report += "型号分布:\n"
for model, count in sorted(model_counts.items(), key=lambda x: x[1], reverse=True):
percentage = count / total_phones * 100
report += f" {model}: {count}个 ({percentage:.1f}%)\n"
return report
在Gradio界面中添加批量上传组件:
with gr.Tab("单张检测"):
# 原来的单张检测界面
with gr.Tab("批量处理"):
batch_input = gr.File(
label="上传多张图片",
file_types=["image"],
file_count="multiple"
)
batch_button = gr.Button("批量处理", variant="primary")
batch_output = gr.Gallery(label="处理结果")
batch_report = gr.Textbox(label="汇总报告", lines=10)
batch_button.click(
fn=batch_process,
inputs=[batch_input],
outputs=[batch_output, batch_report]
)
4.2 导出检测结果
用户可能想把结果保存下来:
import json
import csv
from datetime import datetime
def export_results(results, format='json'):
"""
导出检测结果
支持格式:json, csv, txt
"""
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
if format == 'json':
filename = f"detection_results_{timestamp}.json"
with open(filename, 'w', encoding='utf-8') as f:
json.dump(results, f, ensure_ascii=False, indent=2)
elif format == 'csv':
filename = f"detection_results_{timestamp}.csv"
with open(filename, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerow(['图片名称', '手机序号', '型号', '检测置信度', '型号置信度',
'左上角X', '左上角Y', '右下角X', '右下角Y'])
for result in results:
for i, det in enumerate(result['detections'], 1):
writer.writerow([
result['filename'],
i,
det.get('model', '未知'),
f"{det.get('detection_confidence', 0):.1%}",
f"{det.get('model_confidence', 0):.1%}",
det['bbox'][0], det['bbox'][1],
det['bbox'][2], det['bbox'][3]
])
elif format == 'txt':
filename = f"detection_results_{timestamp}.txt"
with open(filename, 'w', encoding='utf-8') as f:
for result in results:
f.write(f"图片: {result['filename']}\n")
f.write(f"检测到 {len(result['detections'])} 个手机\n")
for i, det in enumerate(result['detections'], 1):
f.write(f"\n手机 {i}:\n")
f.write(f" 型号: {det.get('model', '未知')}\n")
f.write(f" 检测置信度: {det.get('detection_confidence', 0):.1%}\n")
f.write(f" 型号置信度: {det.get('model_confidence', 0):.1%}\n")
f.write(f" 位置: ({det['bbox'][0]}, {det['bbox'][1]}) "
f"到 ({det['bbox'][2]}, {det['bbox'][3]})\n")
f.write("\n" + "="*50 + "\n")
return filename
在界面中添加导出按钮:
with gr.Row():
export_json = gr.Button("导出JSON")
export_csv = gr.Button("导出CSV")
export_txt = gr.Button("导出TXT")
export_json.click(
fn=lambda: export_results(current_results, 'json'),
outputs=gr.File(label="下载文件")
)
4.3 实时视频流处理
虽然当前版本主要处理图片,但扩展到视频流也不难:
import cv2
import threading
import queue
import time
class VideoProcessor:
def __init__(self, camera_url=0):
"""
初始化视频处理器
camera_url: 摄像头URL或索引,0表示默认摄像头
"""
self.camera_url = camera_url
self.cap = None
self.processing = False
self.frame_queue = queue.Queue(maxsize=10)
self.result_queue = queue.Queue(maxsize=10)
def start_capture(self):
"""开始捕获视频"""
self.cap = cv2.VideoCapture(self.camera_url)
if not self.cap.isOpened():
raise ValueError("无法打开摄像头")
self.processing = True
# 启动捕获线程
capture_thread = threading.Thread(target=self._capture_frames)
capture_thread.daemon = True
capture_thread.start()
# 启动处理线程
process_thread = threading.Thread(target=self._process_frames)
process_thread.daemon = True
process_thread.start()
def _capture_frames(self):
"""捕获帧的线程函数"""
while self.processing:
ret, frame = self.cap.read()
if not ret:
break
# 限制帧率,避免队列积压
if self.frame_queue.qsize() < 5:
self.frame_queue.put(frame)
else:
time.sleep(0.01)
def _process_frames(self):
"""处理帧的线程函数"""
while self.processing:
try:
frame = self.frame_queue.get(timeout=1)
# 处理帧(检测+识别)
processed_frame, results = process_video_frame(frame)
# 将结果放入队列
if self.result_queue.qsize() < 5:
self.result_queue.put((processed_frame, results))
except queue.Empty:
continue
def get_frame(self):
"""获取处理后的帧"""
try:
return self.result_queue.get(timeout=0.1)
except queue.Empty:
return None, None
def stop(self):
"""停止处理"""
self.processing = False
if self.cap:
self.cap.release()
def process_video_frame(frame):
"""
处理视频帧
为了实时性,这里可以做些优化:
1. 降低检测频率(比如每3帧检测一次)
2. 使用更小的输入尺寸
3. 缓存识别结果
"""
# 这里可以添加帧率控制逻辑
# 例如:每3帧进行一次完整检测,中间帧使用跟踪
# 实际的处理逻辑
detections = detect_phones(frame)
# 绘制结果
output_frame = frame.copy()
for det in detections:
x1, y1, x2, y2, confidence = det
# 裁剪手机区域进行识别
phone_region = frame[y1:y2, x1:x2]
if phone_region.size > 0:
model_name, model_conf = recognize_phone_model(phone_region)
# 绘制框和标签
cv2.rectangle(output_frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
label = f"{model_name} ({confidence:.1%})"
cv2.putText(output_frame, label, (x1, y1-10),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
return output_frame, detections
在Gradio中添加视频处理界面:
with gr.Tab("实时视频"):
video_output = gr.Image(label="实时检测", streaming=True)
start_button = gr.Button("开始摄像头", variant="primary")
stop_button = gr.Button("停止摄像头", variant="secondary")
# 视频处理器实例
video_processor = gr.State(None)
def start_video():
"""启动视频处理"""
processor = VideoProcessor(camera_url=0)
processor.start_capture()
return processor
def update_video_frame(processor):
"""更新视频帧"""
if processor:
frame, _ = processor.get_frame()
if frame is not None:
# 转换颜色空间
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
return frame_rgb
return None
def stop_video(processor):
"""停止视频处理"""
if processor:
processor.stop()
return None
start_button.click(
fn=start_video,
outputs=video_processor
).then(
fn=lambda: None,
outputs=video_output
)
# 设置视频流
video_output.stream(
fn=update_video_frame,
inputs=[video_processor],
outputs=[video_output],
every=0.03 # 约30fps
)
stop_button.click(
fn=stop_video,
inputs=[video_processor],
outputs=[video_processor]
)
5. 性能优化与实用技巧
5.1 加速推理的几种方法
使用GPU加速:
import torch
# 检查GPU是否可用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"使用设备: {device}")
# 将模型移到GPU
model.to(device)
# 在推理时,确保输入数据也在GPU上
input_tensor = input_tensor.to(device)
使用半精度浮点数:
# 如果GPU支持,使用半精度可以大幅减少显存占用并加速
model.half() # 转换为半精度
input_tensor = input_tensor.half()
批处理:
def batch_detect(images):
"""
批量检测,比单张检测更高效
"""
# 将多张图片堆叠成一个批次
batch_tensor = torch.stack(images)
# 批量推理
with torch.no_grad():
outputs = model(batch_tensor)
return outputs
使用TensorRT加速(高级优化):
# 将PyTorch模型转换为ONNX
torch.onnx.export(model, dummy_input, "model.onnx")
# 然后使用TensorRT优化ONNX模型
# 这需要安装TensorRT并编写转换脚本
5.2 提高识别准确率
数据增强:
from torchvision import transforms
# 训练时使用数据增强
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
transforms.RandomRotation(10),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 推理时使用简单的预处理
infer_transform = transforms.Compose([
transforms.Resize((256, 256)),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
模型集成:
class EnsembleModel:
"""多个模型的集成,提高准确率"""
def __init__(self, model_paths):
self.models = []
for path in model_paths:
model = load_model(path)
model.eval()
self.models.append(model)
def predict(self, image):
predictions = []
confidences = []
for model in self.models:
pred, conf = model.predict(image)
predictions.append(pred)
confidences.append(conf)
# 投票或加权平均
# 这里使用简单投票
from collections import Counter
most_common = Counter(predictions).most_common(1)[0][0]
# 计算平均置信度
avg_confidence = sum(confidences) / len(confidences)
return most_common, avg_confidence
后处理优化:
def post_process(detections, model_predictions, min_confidence=0.5):
"""
后处理:过滤低置信度结果,合并重叠框等
"""
results = []
for det, model_pred in zip(detections, model_predictions):
bbox, det_conf = det
model_name, model_conf = model_pred
# 综合置信度
combined_conf = det_conf * 0.6 + model_conf * 0.4
if combined_conf >= min_confidence:
results.append({
'bbox': bbox,
'model': model_name,
'detection_confidence': det_conf,
'model_confidence': model_conf,
'combined_confidence': combined_conf
})
# 非极大值抑制,去除重叠框
results = non_max_suppression(results, iou_threshold=0.5)
return results
def non_max_suppression(results, iou_threshold=0.5):
"""
非极大值抑制
"""
if not results:
return []
# 按置信度排序
results.sort(key=lambda x: x['combined_confidence'], reverse=True)
keep = []
while results:
# 取置信度最高的
best = results.pop(0)
keep.append(best)
# 移除与best重叠度高的
results = [r for r in results if
calculate_iou(best['bbox'], r['bbox']) < iou_threshold]
return keep
5.3 实际部署建议
资源监控:
import psutil
import time
class ResourceMonitor:
"""监控系统资源使用情况"""
def __init__(self):
self.start_time = time.time()
self.frame_count = 0
def get_status(self):
"""获取当前状态"""
# CPU使用率
cpu_percent = psutil.cpu_percent(interval=1)
# 内存使用
memory = psutil.virtual_memory()
# GPU信息(如果可用)
gpu_info = self.get_gpu_info()
# 帧率
elapsed = time.time() - self.start_time
fps = self.frame_count / elapsed if elapsed > 0 else 0
return {
'cpu_percent': cpu_percent,
'memory_percent': memory.percent,
'memory_used_gb': memory.used / 1024**3,
'gpu_info': gpu_info,
'fps': fps,
'uptime': elapsed,
'total_frames': self.frame_count
}
def increment_frame(self):
"""增加帧计数"""
self.frame_count += 1
def get_gpu_info(self):
"""获取GPU信息"""
try:
import pynvml
pynvml.nvmlInit()
gpu_count = pynvml.nvmlDeviceGetCount()
gpus = []
for i in range(gpu_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
util = pynvml.nvmlDeviceGetUtilizationRates(handle)
memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
gpus.append({
'index': i,
'name': pynvml.nvmlDeviceGetName(handle),
'gpu_util': util.gpu,
'memory_used_mb': memory.used / 1024**2,
'memory_total_mb': memory.total / 1024**2,
'memory_percent': memory.used / memory.total * 100
})
pynvml.nvmlShutdown()
return gpus
except ImportError:
return None
添加监控界面:
with gr.Accordion("系统监控", open=False):
monitor_output = gr.JSON(label="系统状态", every=2) # 每2秒更新一次
def update_monitor():
status = resource_monitor.get_status()
return status
# 自动更新监控信息
monitor_output.change(
fn=update_monitor,
outputs=monitor_output,
every=2
)
6. 总结
6.1 项目回顾
我们从头构建了一个完整的手机检测与型号识别系统,核心功能包括:
- 基础检测功能:基于DAMO-YOLO的实时手机检测,准确率88.8%,速度3.83ms/张
- 型号识别扩展:在检测基础上增加手机型号识别,让系统更有实用价值
- 友好Web界面:使用Gradio构建,支持图片上传、批量处理、实时视频
- 完整部署方案:从环境搭建到服务管理,提供了一站式解决方案
6.2 关键收获
技术层面:
- DAMO-YOLO确实做到了“小、快、省”,特别适合边缘部署
- Gradio让AI应用开发变得简单,快速构建Web界面
- 模型组合(检测+分类)是解决复杂问题的有效思路
工程层面:
- 良好的代码结构让系统易于维护和扩展
- 资源监控和性能优化是生产部署的关键
- 错误处理和日志记录不能忽视
6.3 下一步建议
如果你想让这个系统更强大,可以考虑:
- 扩充型号库:收集更多手机型号数据,训练更全面的分类模型
- 优化性能:尝试TensorRT加速、模型量化、蒸馏等技术
- 增加功能:比如手机状态识别(是否在充电、屏幕是否亮着)
- 云端部署:将系统部署到云端,提供API服务
- 移动端适配:开发手机App,随时随地使用
这个项目的价值在于它展示了一个完整的AI应用开发流程:从模型选择、功能扩展、界面开发到部署运维。你可以基于这个框架,轻松扩展到其他物体检测和识别任务。
获取更多AI镜像
想探索更多AI镜像和应用场景?访问 CSDN星图镜像广场,提供丰富的预置镜像,覆盖大模型推理、图像生成、视频生成、模型微调等多个领域,支持一键部署。

666


被折叠的 条评论
为什么被折叠?



