《Hands-On Machine Learning with Scikit-Learn, Keras TensorFlow》第一章读书笔记

第一部分:理论篇

在这里插入图片描述

1. 什么是机器学习

核心定义

机器学习是让计算机从数据中学习的科学,而无需显式编程。

经典定义

  • Arthur Samuel (1959):

“让计算机无需明确编程就具备学习能力”

  • Tom Mitchell 的工程定义:

“如果一个程序通过经验 E 在某任务 T 上获得性能改善 P,则称其学习”


2. 为什么使用机器学习

传统编程方法(以垃圾邮件检测为例):

  • 研究垃圾邮件特征
  • 编写规则检测这些特征
  • 测试并重复改进规则

缺点:

❗ 规则复杂度指数级增长

❗ 维护成本高昂

❗ 无法适应新型垃圾邮件

机器学习方法:

✅ 自动从样本中学习特征模式

✅ 动态适应新型垃圾邮件

✅ 维护和扩展成本显著降低


3. 机器学习的应用类型

适用场景:

  1. 传统方法难以制定明确规则的问题
  2. 传统解决方案过于复杂的问题
  3. 需适应环境变化的动态问题
  4. 从复杂数据中提取洞察

典型应用:

  • 图像分类
  • 语音识别
  • 垃圾邮件检测
  • 客户细分
  • 预测分析
  • 异常检测

4. 机器学习系统的类型

按训练监督划分

类型特征典型任务示例
监督学习训练数据含标签分类/回归垃圾邮件分类
无监督学习训练数据无标签聚类/降维客户分群
半监督学习部分数据有标签混合任务照片分类服务
强化学习基于奖惩机制动态决策游戏AI/机器人行走

按学习方式划分:

  • 批量学习:全量数据训练 ➔ 离线更新
  • 在线学习:流式数据训练 ➔ 实时适应

按工作方式划分:

  • 基于实例学习:记忆样本 ➔ 相似度匹配预测
  • 基于模型学习:构建模型 ➔ 参数化预测

5. 机器学习的主要挑战

数据挑战

  1. 数据量不足
  2. 样本代表性缺失
  3. 低质量数据(噪声/错误)
  4. 无关特征干扰

算法挑战

  • 过拟合:模型过度记忆训练数据细节
  • 欠拟合:模型未能捕捉数据基本规律

6. 测试和验证

核心原则

  • 数据划分:训练集(70-80%)/验证集(10-15%)/测试集(10-15%)
  • 交叉验证:K折交叉验证提升评估可靠性
  • 超参数调优:必须使用独立验证集

7. 实践建议

✅ 构建端到端流程:数据收集→预处理→建模→评估→部署

✅ 先尝试简单模型(如线性回归)建立基线

✅ 数据质量 > 算法复杂度

✅ 选择与业务目标匹配的评估指标(如F1分数/ROC-AUC)

严防数据泄露:确保训练数据不包含测试集信息


第二部分:习题篇

在这里插入图片描述

基础概念题

Q1. 如何定义机器学习

机器学习是关于构建能够从数据中学习的系统。学习的含义是:在某个任务上,根据某种性能度量标准,不断提升表现。


Q2. 机器学习在哪些场景中表现突出?

✅ 四类典型场景:

  1. 无明确算法解决方案的问题
  2. 替代人工调优的复杂规则系统
  3. 需动态适应环境变化的系统
  4. 从数据中挖掘隐含规律(数据挖掘)

Q3. 什么是带标签的训练集?

包含每个样本的期望输出(即标签)的训练数据集。例如:邮件数据集中的每封邮件被标注为“垃圾邮件”或“非垃圾邮件”。


学习类型辨析

Q4. 监督学习的典型任务是什么?

两类核心任务:

  • 回归(预测连续值,如房价预测)
  • 分类(预测离散类别,如图像识别)

Q5. 无监督学习的常见任务有哪些?

四类典型应用:

  1. 聚类(客户分群)
  2. 可视化(高维数据降维展示)
  3. 降维(特征压缩)
  4. 关联规则学习(购物篮分析)

Q6. 如何选择机器人行走地形的学习算法?

强化学习为首选方案:

  • 通过奖惩机制学习动态决策
  • 天然适配环境交互场景

⚠️ 注:监督学习需预定义行为标签,半监督学习需部分标注,二者均不直接适配该场景。


Q7. 客户分群应使用哪种算法?

分情况讨论:

  • 无预定义群体 → 聚类算法(无监督学习,如K-Means)
  • 已知群体类别 → 分类算法(监督学习,需标注数据)

系统特性与挑战

Q8. 在线学习系统的核心特征是什么?

⚡ 关键特性:

  • 增量学习:持续处理数据流
  • 实时适应:动态更新模型参数
  • 内存高效:无需存储历史全量数据

典型场景:金融实时风控、新闻推荐系统


Q9. 什么是外部存储(Out-of-Core)学习?

定义:

处理超出内存容量的超大规模数据时,将数据分批次加载到内存中进行增量训练的技术。

实现方式:

  • 数据分块(Mini-Batch)
  • 结合在线学习策略

Q10. 基于实例学习的原理是什么?

工作机制:

  1. 存储全部训练样本
  2. 新数据输入时,计算与存储样本的相似度(如欧氏距离)
  3. 根据最相似样本的标签进行预测

典型算法:K-近邻(KNN)


模型与参数解析

Q11. 模型参数 vs 超参数的区别?

特性模型参数超参数
定义模型内部的权重(如线性回归斜率)控制训练过程的配置参数
学习方式通过训练数据自动优化人工设定或自动调优(如网格搜索)
示例神经网络权重、SVM支持向量学习率、正则化系数、树的最大深度

Q12. 基于模型的学习如何运作?

三阶段过程:

  1. 目标:寻找最优模型参数,最小化损失函数(含正则化项)
  2. 训练策略:梯度下降、随机优化等
  3. 预测方式:将新数据输入参数化模型函数(如 )

实战挑战与解决方案

Q13. 机器学习的四大核心挑战?

⚠️ 关键瓶颈:

  1. 数据不足 → 模型难以捕捉规律
  2. 数据质量差 → 噪声干扰模型学习
  3. 样本不具代表性 → 泛化能力低下
  4. 特征信息量不足 → 模型性能天花板受限

Q14. 过拟合的识别与解决方案

问题识别

  • 训练集准确率高 ➔ 测试集准确率显著下降

三类解决方案

  1. 数据增强:收集更多数据/数据扩增(如图像旋转、添加噪声)
  2. 模型简化:减少网络层数、降低多项式次数、增加正则化
  3. 数据清洗:剔除异常样本、修复标签错误

Q15. 测试集的核心作用与误用风险

核心作用

  • 作为"未知数据"的代理,评估模型泛化性能

误用风险

  • 若用测试集调参 → 模型间接学习测试集分布 → 泛化误差评估虚高
  • 后果:线上部署后性能显著低于预期

Q16. 验证集 vs 训练-开发集的用途

数据集类型核心用途典型使用场景
验证集模型选择与超参数调优比较不同算法/参数组合效果
训练-开发集检测训练数据与验证数据的分布偏移数据分布不一致时的诊断工具

Q17. 数据泄露的预防策略

关键原则:

  • 严格隔离训练集、验证集、测试集
  • 避免在预处理阶段使用全量数据(如归一化需仅基于训练集统计量)
  • 时序数据需按时间顺序划分(禁止随机划分)

Q18. 如何诊断训练数据与验证数据的分布偏移?

训练-开发集(Train-Dev Set)的使用方法:

  1. 数据划分:
  • 训练集(70%) → 模型训练
  • 训练-开发集(10%) → 从训练集划分,用于检测过拟合
  • 验证集(20%) → 保持独立分布
  1. 诊断逻辑:
  • 若模型在 训练集 表现好,但在 训练-开发集 表现差 → 过拟合
  • 若模型在 训练-开发集 表现好,但在 验证集 表现差 → 数据分布偏移
  1. 解决方案:
  • 调整训练数据,使其更接近真实场景分布
  • 使用数据增强技术模拟验证集特征

Q19. 为什么不能直接用测试集调参?

⚠️ 核心风险:

  • 测试集污染:超参数调优本质上是"学习"测试集分布的过程
  • 性能虚高:模型会过度适配测试集特性,导致:
  • 线上部署效果显著低于测试集指标
  • 失去对真实未知数据的泛化能力

正确流程:

  1. 原始数据 → 划分训练集、验证集、测试集
  2. 用训练集训练模型
  3. 用验证集调参/选择模型
  4. 用测试集仅做最终评估(且仅限一次!)

第三部分:实践篇

「链接」

环境设置

验证Python版本要求(3.7及以上)

import sys

assert sys.version_info >= (3, 7)  # 检查Python版本是否满足要求

验证Scikit-Learn版本要求(≥1.0.1)

from packaging import version
import sklearn

assert version.parse(sklearn.__version__) >= version.parse("1.0.1")  # 检查sklearn版本是否达标

设置matplotlib默认字体大小,优化图表显示效果

import matplotlib.pyplot as plt

plt.rc('font', size=12)            # 全局字体大小
plt.rc('axes', labelsize=14, titlesize=14)  # 坐标轴标签和标题大小
plt.rc('legend', fontsize=12)       # 图例字体大小
plt.rc('xtick', labelsize=10)       # X轴刻度标签大小
plt.rc('ytick', labelsize=10)       # Y轴刻度标签大小

设置随机种子保证结果可复现

import numpy as np

np.random.seed(42)  # 固定随机数生成器的种子

代码示例1-1

导入必要库并加载数据集,展示GDP与生活满意度的关系

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# 下载并准备数据
data_root = "https://github.com/ageron/data/raw/main/"
lifesat = pd.read_csv(data_root + "lifesat/lifesat.csv")
X = lifesat[["GDP per capita (USD)"]].values  # 提取GDP特征
y = lifesat[["Life satisfaction"]].values     # 提取目标变量

# 可视化数据
lifesat.plot(kind='scatter', grid=True,
             x="GDP per capita (USD)", y="Life satisfaction")
plt.axis([23_500, 62_500, 4, 9])  # 设置坐标轴范围
plt.show()

# 选择线性回归模型
model = LinearRegression()

# 训练模型
model.fit(X, y)

# 预测塞浦路斯数据
X_new = [[37_655.2]]  # 塞浦路斯2020年人均GDP
print(model.predict(X_new)) # 输出预测值 [[6.30165767]]

输出结果:

[[6.30165767]]

将线性回归模型替换为K近邻回归(k=3)

from sklearn.neighbors import KNeighborsRegressor

# 创建KNN回归器实例
model = KNeighborsRegressor(n_neighbors=3)

# 训练模型
model.fit(X, y)

# 进行预测
print(model.predict(X_new)) # 输出预测值 [[6.33333333]]

输出结果:

[[6.33333333]]

数据与图表生成

以下是生成lifesat.csv数据集的代码。

创建图像保存函数

from pathlib import Path

# 图像保存路径设置
IMAGES_PATH = Path() / "images" / "fundamentals"
IMAGES_PATH.mkdir(parents=True, exist_ok=True)  # 创建多级目录

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    path = IMAGES_PATH / f"{fig_id}.{fig_extension}"
    if tight_layout:
        plt.tight_layout()  # 自动调整子图间距
    plt.savefig(path, format=fig_extension, dpi=resolution)  # 保存图像

加载和处理生活满意度数据

自动下载原始数据文件

import urllib.request

datapath = Path() / "datasets" / "lifesat"
datapath.mkdir(parents=True, exist_ok=True)  # 创建数据存储目录

data_root = "https://github.com/ageron/data/raw/main/"
for filename in ("oecd_bli.csv", "gdp_per_capita.csv"):
    if not (datapath / filename).is_file():
        print("下载中", filename)
        url = data_root + "lifesat/" + filename
        urllib.request.urlretrieve(url, datapath / filename)  # 下载缺失文件

输出:

下载中 oecd_bli.csv
下载中 gdp_per_capita.csv

加载预处理后的GDP数据(仅保留2020年)

gdp_year = 2020
gdppc_col = "GDP per capita (USD)"
lifesat_col = "Life satisfaction"

gdp_per_capita = gdp_per_capita[gdp_per_capita["Year"] == gdp_year]  # 筛选年份
gdp_per_capita = gdp_per_capita.drop(["Code", "Year"], axis=1)       # 删除冗余列
gdp_per_capita.columns = ["Country", gdppc_col]                      # 重命名列
gdp_per_capita.set_index("Country", inplace=True)                    # 设置国家为索引

gdp_per_capita.head()  # 显示前五行数据

处理OECD BLI数据(提取生活满意度指标)

oecd_bli = oecd_bli[oecd_bli["INEQUALITY"]=="TOT"]       # 筛选总评数据
oecd_bli = oecd_bli.pivot(index="Country", columns="Indicator", values="Value")  # 数据透视

oecd_bli.head()  # 显示预处理后的数据结构

合并OECD生活满意度数据与GDP数据,创建完整数据集

full_country_stats = pd.merge(left=oecd_bli, right=gdp_per_capita,
                              left_index=True, right_index=True)  # 按国家索引进行左连接
full_country_stats.sort_values(by=gdppc_col, inplace=True)        # 按GDP排序
full_country_stats = full_country_stats[[gdppc_col, lifesat_col]] # 保留关键列

full_country_stats.head()  # 显示合并后的前五行数据

设置GDP过滤范围,创建演示用子集(用于避免过拟合示例)

min_gdp = 23_500
max_gdp = 62_500

country_stats = full_country_stats[(full_country_stats[gdppc_col] >= min_gdp) &
                                   (full_country_stats[gdppc_col] <= max_gdp)]  # GDP区间过滤
country_stats.head()  # 显示筛选后的数据

保存处理后的数据集

country_stats.to_csv(datapath / "lifesat.csv")        # 保存筛选数据集
full_country_stats.to_csv(datapath / "lifesat_full.csv")  # 保存完整数据集

绘制带国家标注的散点图

country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 创建基础散点图

# 定义各国标注位置
position_text = {
    "Turkey": (29_500, 4.2),
    "Hungary": (28_000, 6.9),
    "France": (40_000, 5),
    "New Zealand": (28_000, 8.2),
    "Australia": (50_000, 5.5),
    "United States": (59_000, 5.3),
    "Denmark": (46_000, 8.5)
}

# 添加国家标注和指示箭头
for country, pos_text in position_text.items():
    pos_data_x = country_stats[gdppc_col].loc[country]
    pos_data_y = country_stats[lifesat_col].loc[country]
    country = "U.S." if country == "United States" else country  # 简化显示名称
    plt.annotate(country, xy=(pos_data_x, pos_data_y),
                 xytext=pos_text, fontsize=12,
                 arrowprops=dict(facecolor='black', width=0.5,
                                 shrink=0.08, headwidth=5))  # 添加箭头标注
    plt.plot(pos_data_x, pos_data_y, "ro")  # 标记数据点

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])  # 固定坐标范围
save_fig('money_happy_scatterplot')  # 保存图像
plt.show()

提取高亮国家数据并按GDP排序

highlighted_countries = country_stats.loc[list(position_text.keys())]
highlighted_countries[[gdppc_col, lifesat_col]].sort_values(by=gdppc_col)  # GDP升序排列

绘制不同参数组合的线性模型对比图

country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 基础散点图

X = np.linspace(min_gdp, max_gdp, 1000)  # 生成连续GDP值

# 绘制三组不同参数的直线
w1, w2 = 4.2, 0
plt.plot(X, w1 + w2 * 1e-5 * X, "r")  # 红色水平线
plt.text(40_000, 4.9, fr"$\theta_0 = {w1}$", color="r")
plt.text(40_000, 4.4, fr"$\theta_1 = {w2}$", color="r")

w1, w2 = 10, -9
plt.plot(X, w1 + w2 * 1e-5 * X, "g")  # 绿色下降趋势线
plt.text(26_000, 8.5, fr"$\theta_0 = {w1}$", color="g")
plt.text(26_000, 8.0, fr"$\theta_1 = {w2} \times 10^{{-5}}$", color="g")

w1, w2 = 3, 8
plt.plot(X, w1 + w2 * 1e-5 * X, "b")  # 蓝色上升趋势线
plt.text(48_000, 8.5, fr"$\theta_0 = {w1}$", color="b")
plt.text(48_000, 8.0, fr"$\theta_1 = {w2} \times 10^{{-5}}$", color="b")

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])
save_fig('tweaking_model_params_plot')
plt.show()

训练线性回归模型并输出参数

from sklearn import linear_model

X_sample = country_stats[[gdppc_col]].values  # 特征矩阵
y_sample = country_stats[[lifesat_col]].values  # 目标向量

lin1 = linear_model.LinearRegression()  # 创建线性回归器
lin1.fit(X_sample, y_sample)  # 训练模型

t0, t1 = lin1.intercept_[0], lin1.coef_[0][0]  # 获取截距和斜率
print(f"θ0={t0:.2f}, θ1={t1:.2e}")  # 格式化输出参数

输出结果:

θ0=3.75, θ1=6.78e-05

绘制最佳拟合线可视化

country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 基础散点图

X = np.linspace(min_gdp, max_gdp, 1000)
plt.plot(X, t0 + t1 * X, "b")  # 绘制回归线

# 添加参数标注
plt.text(max_gdp - 20_000, min_life_sat + 1.9,
         fr"$\theta_0 = {t0:.2f}$", color="b")
plt.text(max_gdp - 20_000, min_life_sat + 1.3,
         fr"$\theta_1 = {t1 * 1e5:.2f} \times 10^{{-5}}$", color="b")

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])
save_fig('best_fit_model_plot')
plt.show()

获取塞浦路斯的GDP数据

cyprus_gdp_per_capita = gdp_per_capita[gdppc_col].loc["Cyprus"]  # 通过国家名索引
cyprus_gdp_per_capita  # 显示值

使用训练好的模型进行预测

cyprus_predicted_life_satisfaction = lin1.predict([[cyprus_gdp_per_capita]])[0, 0]
cyprus_predicted_life_satisfaction  # 显示预测结果

可视化预测结果

country_stats.plot(kind='scatter', figsize=(5, 3), grid=True,
                   x=gdppc_col, y=lifesat_col)  # 基础散点图

X = np.linspace(min_gdp, max_gdp, 1000)
plt.plot(X, t0 + t1 * X, "b")  # 绘制回归线

# 调整参数标注位置
plt.text(min_gdp + 22_000, max_life_sat - 1.1,
         fr"$\theta_0 = {t0:.2f}$", color="b")
plt.text(min_gdp + 22_000, max_life_sat - 0.6,
         fr"$\theta_1 = {t1 * 1e5:.2f} \times 10^{{-5}}$", color="b")

# 添加预测线标注
plt.plot([cyprus_gdp_per_capita, cyprus_gdp_per_capita],
         [min_life_sat, cyprus_predicted_life_satisfaction], "r--")  # 红色虚线
plt.text(cyprus_gdp_per_capita + 1000, 5.0,
         fr"Prediction = {cyprus_predicted_life_satisfaction:.2f}", color="r")
plt.plot(cyprus_gdp_per_capita, cyprus_predicted_life_satisfaction, "ro")  # 红点标记

plt.axis([min_gdp, max_gdp, min_life_sat, max_life_sat])
plt.show()

提取GDP区间外的异常数据点

missing_data = full_country_stats[(full_country_stats[gdppc_col] < min_gdp) |
                                  (full_country_stats[gdppc_col] > max_gdp)]  # 筛选超出GDP范围的数据
missing_data  # 显示异常数据点

定义缺失国家的标注位置坐标

position_text_missing_countries = {
    "South Africa": (20_000, 4.2),    # 南非标注位置
    "Colombia": (6_000, 8.2),         # 哥伦比亚
    "Brazil": (18_000, 7.8),          # 巴西
    "Mexico": (24_000, 7.4),          # 墨西哥
    "Chile": (30_000, 7.0),           # 智利
    "Norway": (51_000, 6.2),          # 挪威
    "Switzerland": (62_000, 5.7),     # 瑞士
    "Ireland": (81_000, 5.2),         # 爱尔兰
    "Luxembourg": (92_000, 4.7),      # 卢森堡
}

绘制完整数据分布及异常点标注

full_country_stats.plot(kind='scatter', figsize=(8, 3),
                        x=gdppc_col, y=lifesat_col, grid=True)  # 全数据散点图

# 添加异常点标注
for country, pos_text in position_text_missing_countries.items():
    pos_data_x, pos_data_y = missing_data.loc[country]  # 获取实际坐标
    plt.annotate(country, xy=(pos_data_x, pos_data_y),
                 xytext=pos_text, fontsize=12,
                 arrowprops=dict(facecolor='black', width=0.5,
                                 shrink=0.08, headwidth=5))  # 带箭头文本标注
    plt.plot(pos_data_x, pos_data_y, "rs")  # 红色方块标记

# 绘制局部数据拟合线
X = np.linspace(0, 115_000, 1000)
plt.plot(X, t0 + t1 * X, "b:")  # 蓝色虚线表示局部数据模型

# 使用完整数据训练新模型
lin_reg_full = linear_model.LinearRegression()
Xfull = np.c_[full_country_stats[gdppc_col]]  # 二维数组转换
yfull = np.c_[full_country_stats[lifesat_col]]
lin_reg_full.fit(Xfull, yfull)  # 全数据训练

# 绘制全数据拟合线
t0full, t1full = lin_reg_full.intercept_[0], lin_reg_full.coef_[0][0]
plt.plot(X, t0full + t1full * X, "k")  # 黑色实线表示全数据模型

plt.axis([0, 115_000, min_life_sat, max_life_sat])  # 扩展坐标范围
save_fig('representative_training_data_scatterplot')
plt.show()

构建过拟合多项式回归模型

from sklearn import preprocessing
from sklearn import pipeline

full_country_stats.plot(kind='scatter', figsize=(8, 3),
                        x=gdppc_col, y=lifesat_col, grid=True)  # 全数据散点图

# 创建多项式回归流水线
poly = preprocessing.PolynomialFeatures(degree=10, include_bias=False)  # 10次多项式
scaler = preprocessing.StandardScaler()        # 标准化
lin_reg2 = linear_model.LinearRegression()     # 线性回归

pipeline_reg = pipeline.Pipeline([
    ('poly', poly),    # 特征多项式扩展
    ('scal', scaler),  # 数据标准化
    ('lin', lin_reg2)])  # 线性回归
pipeline_reg.fit(Xfull, yfull)  # 训练高阶模型
curve = pipeline_reg.predict(X[:, np.newaxis])  # 生成预测曲线
plt.plot(X, curve)  # 绘制过拟合曲线

plt.axis([0, 115_000, min_life_sat, max_life_sat])
save_fig('overfitting_model_plot')  # 保存过拟合示例图
plt.show()

筛选国家名称含’W’字母的国家数据

w_countries = [c for c in full_country_stats.index if "W" in c.upper()]  # 列表推导式筛选
full_country_stats.loc[w_countries][lifesat_col]  # 显示生活满意度

获取所有含’W’国家的GDP数据

all_w_countries = [c for c in gdp_per_capita.index if "W" in c.upper()]
gdp_per_capita.loc[all_w_countries].sort_values(by=gdppc_col)  # GDP升序排列

对比不同回归模型效果

# 创建复合图表
country_stats.plot(kind='scatter', x=gdppc_col, y=lifesat_col, figsize=(8, 3))
missing_data.plot(kind='scatter', x=gdppc_col, y=lifesat_col,
                  marker="s", color="r", grid=True, ax=plt.gca())  # 红色方块标记异常点

# 绘制三种回归线
X = np.linspace(0, 115_000, 1000)
plt.plot(X, t0 + t1*X, "b:", label="基于部分数据的线性模型")  # 局部数据模型
plt.plot(X, t0full + t1full * X, "k-", label="全数据线性模型")  # 全数据模型

# 训练岭回归模型
ridge = linear_model.Ridge(alpha=10**9.5)  # 设置正则化强度
X_sample = country_stats[[gdppc_col]]
y_sample = country_stats[[lifesat_col]]
ridge.fit(X_sample, y_sample)  # 训练正则化模型
t0ridge, t1ridge = ridge.intercept_[0], ridge.coef_[0][0]
plt.plot(X, t0ridge + t1ridge * X, "b--", 
         label="部分数据正则化线性模型")  # 蓝色虚线

plt.legend(loc="lower right")  # 右下角图例
plt.axis([0, 115_000, min_life_sat, max_life_sat])
save_fig('ridge_model_plot')  # 保存岭回归对比图
plt.show()

DeepSeek 代码分析

在这里插入图片描述

1. 核心功能

  • 任务类型:单变量回归任务(GDP per capita ➔ Life satisfaction)
  • 模型架构:对比线性回归/K近邻/多项式回归/岭回归等多种方法
  • 特色:通过可视化展示模型差异,演示过拟合现象

2. 关键组件

A[环境配置] --> B[数据获取]–> C[数据预处理]–> D[特征工程]–> E[模型训练]–> F[结果可视化]–> G[模型对比]–> H[过拟合分析]

3. 数据流

原始数据 → 筛选年份 → 合并数据集 → 特征标准化 → 模型输入 → 预测输出 → 可视化展示

4. 创新点

  • 采用对比可视化技术:同时展示多个模型的拟合效果
  • 创新标注系统:动态生成带箭头的国家标注注释
  • 混合可视化策略:将原始数据点与模型预测曲线叠加展示
  • 过拟合演示:通过极端多项式回归直观展示过拟合现象

5. 改进建议 ▲ 数据预处理阶段:

  • 缺失值处理未显式体现(假设数据完整)
  • 特征标准化仅应用于部分模型(建议统一预处理流程)
  • 数据拆分未采用标准train-test split(可能引入数据泄露)

▲ 模型评估:

  • 缺乏量化评估指标(如RMSE/R²分数)
  • 未使用交叉验证评估模型稳定性
  • 超参数选择(如KNN的k值)未说明依据

▲ 代码结构:

  • 可封装重复可视化操作为函数
  • 建议使用sklearn Pipeline整合预处理步骤
  • 路径处理可改用pathlib的现代写法

6. 学习要点 重点关注的实现细节:

  • 数据透视操作:pivot()方法重构数据格式
  • 坐标轴标注技巧:fr" θ 0 = t 0 : . 2 f \theta_0 = {t0:.2f} θ0=t0:.2f"的LaTeX公式渲染
  • 混合可视化技术:基础散点图+预测曲线+标注元素的叠加方式
  • 正则化应用:Ridge回归的参数设置(alpha=10**9.5)对模型的影响

批判性思考补充 ▲ 数据代表性风险:

  • 筛选GDP范围时可能引入样本偏差(min_gdp=23_500)
  • 异常值处理采用简单截断法,可能丢失重要信息
  • 国家标注位置手工设定,缺乏自动化定位机制

▲ 模型可解释性:

  • 多项式回归的degree=10选择缺乏理论依据
  • 未进行特征重要性分析
  • 不同模型预测差异未进行统计显著性检验

代码亮点总结

  • 环境配置完整:版本检查+随机种子+可视化预设
  • 数据故事化呈现:通过渐进式可视化构建分析逻辑
  • 模型对比直观:多模型拟合曲线同图对比
  • 可复现设计:自动下载数据+路径创建+随机种子

建议重点关注代码中数据流转换的实现方式(特别是pandas操作)和可视化组件的构建逻辑,这是实现分析型机器学习项目的典型范式。

第四部分:思考篇

1. 预备知识要求

  1. Python编程基础
  2. 科学计算库:NumPy、Pandas、Matplotlib
  3. 基础数学概念:
  • 线性代数(向量、矩阵运算)
  • 微积分(理解神经网络训练原理)
  • 基础概率论
  • 基础统计学

2. 学习建议

  1. 不要过早跳入深度学习
  2. 建议先掌握机器学习基础
  3. 大多数问题可以用更简单的技术解决(如随机森林)
  4. 深度学习最适合:
  • 图像识别
  • 语音识别
  • 自然语言处理
  • 需要大量数据和计算资源的问题

http://www.niftyadmin.cn/n/5840744.html

相关文章

利用腾讯云cloud studio云端免费部署deepseek-R1

1. cloud studio 1.1 cloud studio介绍 Cloud Studio&#xff08;云端 IDE&#xff09;是基于浏览器的集成式开发环境&#xff0c;为开发者提供了一个稳定的云端工作站。支持CPU与GPU的访问。用户在使用 Cloud Studio 时无需安装&#xff0c;随时随地打开浏览器即可使用。Clo…

51单片机红外遥控器模拟控制空调,自动制冷制热定时开关

主要功能是通过红外遥控器模拟控制空调&#xff0c;可以实现根据环境温度制冷和制热&#xff0c;能够通过遥控器设定温度&#xff0c;可以定时开关空调。 1.硬件介绍 硬件是我自己设计的一个通用的51单片机开发平台&#xff0c;可以根据需要自行焊接模块&#xff0c;这是用立创…

【Conda 和 虚拟环境详细指南】

Conda 和 虚拟环境的详细指南 什么是 Conda&#xff1f; Conda 是一个开源的包管理和环境管理系统&#xff0c;支持多种编程语言&#xff08;如Python、R等&#xff09;&#xff0c;最初由Continuum Analytics开发。 主要功能&#xff1a; 包管理&#xff1a;安装、更新、删…

【python】python基于机器学习与数据分析的手机特性关联与分类预测(源码+数据集)【独一无二】

&#x1f449;博__主&#x1f448;&#xff1a;米码收割机 &#x1f449;技__能&#x1f448;&#xff1a;C/Python语言 &#x1f449;专__注&#x1f448;&#xff1a;专注主流机器人、人工智能等相关领域的开发、测试技术。 python基于机器学习与数据分析的手机特性关联与分类…

鸟哥Linux私房菜笔记(三)

鸟哥Linux私房菜笔记&#xff08;三&#xff09; 该第三部分和第四部分主要为原书的第十一章&#xff08;正则表达式与文件格式化处理&#xff09;&#xff0c;第十二章学习shell脚本&#xff0c;第十六章&#xff08;进程管理与SElinux初探部分&#xff09;&#xff0c;第十七…

在 Zemax 中使用布尔对象创建光学光圈

在 Zemax 中&#xff0c;布尔对象用于通过组合或减去较简单的几何形状来创建复杂形状。布尔运算涉及使用集合运算&#xff08;如并集、交集和减集&#xff09;来组合或修改对象的几何形状。这允许用户在其设计中为光学元件或机械部件创建更复杂和定制的形状。 本视频中&#xf…

Vue.js `v-memo` 性能优化技巧

Vue.js v-memo 性能优化技巧 今天我们来聊聊 Vue 3.2 引入的一个性能优化指令&#xff1a;v-memo。如果你在处理大型列表或复杂组件时&#xff0c;遇到性能瓶颈&#xff0c;那么 v-memo 可能会成为你的得力助手。 什么是 v-memo&#xff1f; v-memo 是 Vue 3.2 新增的内置指…

MFC程序设计(六)消息和控件

消息的分类 在上一节消息映射中&#xff0c;我们发现&#xff0c;无论是create消息还是paint消息使用的都是同一个宏ON_MESSAGE。通过ON_MESSAGE宏我们可以在上一节的静态数组中添加任何消息处理函数&#xff0c;因此ON_MESSAGE也被称为通用宏。但在实际的应用中&#xff0c;在…