一、项目背景
在电商、传统零售等行业中,广告投放是提升销售额的重要手段。企业常常需要分析TV、广播、报纸三类广告渠道的投入与最终销售额之间的关联,以此制定合理的广告投放策略、控制营销成本。
本文基于经典广告数据集,使用 Pandas 完成数据读取与处理,Seaborn+Matplotlib 实现数据可视化,最后借助 Scikit-learn 搭建多元线性回归模型,拟合广告投入与销售额的数学关系,完成销量预测与模型评估。
二、环境与依赖库
本次项目用到的核心 Python 库:
pandas:数据读取、数据结构化处理numpy:数值计算基础库seaborn/matplotlib:数据可视化绘图scikit-learn:划分数据集、构建线性回归模型
环境说明:本文使用 NumPy 1.x 版本,规避高版本 NumPy 带来的库兼容报错,若你遇到
NumPy 2.x模块编译报错,建议降级numpy==1.26.4。
三、数据集介绍
使用经典的 Advertising.csv 广告数据集,数据集共包含 4 列:
TV:电视广告投入金额radio:广播广告投入金额newspaper:报纸广告投入金额sales:对应产生的销售额
数据集样本结构简单、特征明确,非常适合入门多元线性回归实战。
四、完整代码实现与逐行解析
4.1 导入依赖库
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
4.2 读取并查看原始数据
通过pandas读取本地 CSV 文件,并打印前 5 行查看数据格式:
# 读取广告数据集
data = pd.read_csv('Advertising.csv')
# 查看前5条数据
print(data.head())
运行结果示例:
Number TV radio newspaper sales
0 1 230.1 37.8 69.2 22.1
1 2 44.5 39.3 45.1 10.4
2 3 17.2 45.9 69.3 9.3
3 4 151.5 41.3 58.5 18.5
4 5 180.8 10.8 58.4 12.9
4.3 数据可视化:特征与销量相关性分析
使用 seaborn.pairplot 绘制特征 - 标签回归关系图,直观观察三类广告和销售额的线性相关程度:
# 绘制单变量与销售额的回归散点图
sns.pairplot(
data,
x_vars=['TV', 'radio', 'newspaper'], # X轴:三个广告特征
y_vars='sales', # Y轴:目标标签 销售额
height=5, # 子图高度
aspect=0.8, # 子图宽高比
kind='reg' # 绘制拟合回归线
)
plt.show()
代码说明:
x_vars/y_vars:指定横、纵坐标对应的字段;kind='reg':开启回归拟合,自动画出散点 + 拟合直线;- 从图像可以初步判断:TV、广播广告和销售额线性相关性更强,报纸广告相关性偏弱。
4.4 划分特征集与标签集
将数据拆分为特征 X(广告投入)和标签 y(销售额),这是机器学习建模的标准步骤:
# 特征集:三类广告投入
X = data[['TV', 'radio', 'newspaper']]
# 标签集:销售额(预测目标)
y = data['sales']
4.5 划分训练集与测试集
使用 train_test_split 将数据集按照 75% 训练集 + 25% 测试集 划分,random_state 固定随机种子,保证每次运行划分结果一致:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
X, y,
random_state=1,
test_size=0.25 # 测试集占比25%
)
4.6 构建并训练多元线性回归模型
本次任务是连续值预测(回归问题),因此使用 LinearRegression 线性回归模型(注意区分分类算法逻辑回归):
# 导入线性回归模型
from sklearn.linear_model import LinearRegression
# 初始化模型
linreg = LinearRegression()
# 使用训练集训练模型
linreg.fit(X_train, y_train)
4.7 模型预测 & 输出回归方程
模型训练完成后,使用测试集数据进行预测,并打印回归系数和截距,推导出最终的销量计算公式:
# 测试集预测
y_pred = linreg.predict(X_test)
# 输出模型参数
print("="*50)
print("各广告特征系数(TV, 广播, 报纸):", linreg.coef_)
print("回归截距:", linreg.intercept_)
print("="*50)
# 拼接线性回归公式
print("销售额预测公式:")
print(f"销售额 = {linreg.coef_[0]:.3f}*TV + {linreg.coef_[1]:.3f}*广播 + {linreg.coef_[2]:.3f}*报纸 + {linreg.intercept_:.3f}")
参数解读:
- 系数 (coef_):代表单个广告渠道每增加 1 单位投入,销售额的增量;系数越大,该渠道广告效果越好。
- 截距 (intercept_):所有广告投入为 0 时,基础销售额。
五、常见问题避坑
-
NumPy 2.x 版本兼容报错 报错提示
A module that was compiled using NumPy 1.x cannot be run in NumPy 2.x,解决方案:降级 NumPypip install numpy==1.26.4 --force-reinstall -
回归与逻辑回归混淆 销售额是连续数值,属于回归问题,必须使用
LinearRegression;LogisticRegression仅用于二分类 / 多分类场景,误用会直接报错。 -
文件路径问题
Advertising.csv必须和代码文件放在同一文件夹,否则需要填写文件绝对路径。
六、总结
本文完整实现了数据读取 → 可视化分析 → 数据集划分 → 模型训练 → 预测验证全流程的多元线性回归实战。
- 利用可视化完成探索性数据分析,挖掘广告与销量的关联;
- 基于 Scikit-learn 快速搭建多元线性回归模型,得到可落地的预测公式;
- 该案例是机器学习回归算法的入门经典,可延伸拓展:特征筛选、模型调优、新增特征、岭回归 / Lasso 回归防过拟合等。
整套代码简洁易懂,适合 Python 数据分析、机器学习入门学习者练习使用。


被折叠的 条评论
为什么被折叠?



