使用 PyTorch 实现逻辑回归并评估模型性能

news/2025/2/3 6:31:34 标签: pytorch, 逻辑回归, 人工智能

1. 逻辑回归简介

逻辑回归是一种用于解决二分类问题的算法。它通过一个逻辑函数(Sigmoid 函数)将线性回归的输出映射到 [0, 1] 区间内,从而将问题转化为概率预测问题。如果预测概率大于 0.5,则将样本分类为正类;否则分类为负类。

2. 数据准备

为了演示逻辑回归的效果,我们构造了一个简单的二维数据集,包含两类样本。每类样本有 7 个数据点,特征维度为 2。

class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

我们将这两类数据点的特征合并,并为每个数据点分配标签(0 表示第一类,1 表示第二类)。

3. 模型构建

我们使用 PyTorch 框架来实现逻辑回归模型。模型结构非常简单,仅包含一个线性层和一个 Sigmoid 激活函数。

class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

4. 模型训练

我们使用二分类交叉熵损失函数(BCELoss)和随机梯度下降优化器(SGD)来训练模型。训练过程如下:

epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

训练过程中,我们每 100 个 epoch 打印一次损失值,以便观察模型的收敛情况。

5. 模型保存与加载

训练完成后,我们将模型的权重保存到文件中,方便后续加载和使用。

torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")

加载模型时,我们创建一个新的模型实例,并使用 load_state_dict 方法加载保存的权重。

loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth', map_location=torch.device('cpu')))
loaded_model.eval()

6. 模型预测与性能评估

加载模型后,我们使用模型对训练数据进行预测,并计算精确度、召回率和 F1 分数。

with torch.no_grad():
    predictions = loaded_model(X)
    predicted_labels = (predictions > 0.5).float()

print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())

precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())

print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")

7. 运行结果

运行上述代码后,我们得到了以下结果:

  • 实际结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 预测结果:[0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1.]

  • 精确度(Precision):1.0000

  • 召回率(Recall):1.0000

  • F1 分数:1.0000

从结果可以看出,模型在训练集上表现良好,精确度、召回率和 F1 分数均为 1.0000。

8. 完整代码

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score

"""使用pytorch框架实现逻辑回归并保存模型,然后保存模型后再加载模型进行预测,对预测结果计算精确度和召回率及F1分数"""
# 提取特征和标签
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

# 提取两类特征,输入特征维度为2
x1_data = np.concatenate((class1_points[:, 0], class2_points[:, 0]), axis=0)
x2_data = np.concatenate((class1_points[:, 1], class2_points[:, 1]), axis=0)
label = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)

# 将数据转换为 PyTorch 张量
X = torch.tensor(np.column_stack((x1_data, x2_data)), dtype=torch.float32)
y = torch.tensor(label, dtype=torch.float32).view(-1, 1)

# 定义逻辑回归模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)  # 输入特征维度为 2,输出为 1

    def forward(self, x):
        return torch.sigmoid(self.linear(x))

# 初始化模型、损失函数和优化器
model = LogisticRegression()
criterion = nn.BCELoss()  # 二分类交叉熵损失
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 5000
for epoch in range(epochs):
    model.train()
    optimizer.zero_grad()
    outputs = model(X)
    loss = criterion(outputs, y)
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/{epochs}], Loss: {loss.item():.4f}')

# 保存模型
torch.save(model.state_dict(), 'model3.pth')
print("模型已保存")

# 加载模型
loaded_model = LogisticRegression()
loaded_model.load_state_dict(torch.load('model3.pth',
                                        map_location=torch.device('cpu'),
                                        weights_only=True))
loaded_model.eval()

# 进行预测
with torch.no_grad():
    predictions = loaded_model(X)
    predicted_labels = (predictions > 0.5).float()

 # 展示预测结果和实际结果
print("实际结果:", y.numpy().flatten())
print("预测结果:", predicted_labels.numpy().flatten())

# 计算精确度、召回率和 F1 分数
precision = precision_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
recall = recall_score(y.numpy().flatten(), predicted_labels.numpy().flatten())
f1 = f1_score(y.numpy().flatten(), predicted_labels.numpy().flatten())

print(f"精确度(Precision): {precision:.4f}")
print(f"召回率(Recall): {recall:.4f}")
print(f"F1 分数: {f1:.4f}")


http://www.niftyadmin.cn/n/5840555.html

相关文章

leetcode 2856. 删除数对后的最小数组长度

题目如下 数据范围 示例 我们假设存在一个出现频率最高的数a那么我们可以把这个数组分成三部分那么第一部分和第三部分必然可以消去一部分 然后它们剩下的和a再消去 当a的数量是数组的一半那么显然剩下的就是0 当a的数量大于数组的一半那么显然存在无法消去的a 剩2 * count…

Python之Excel操作 - 读取数据

我们将使用 openpyxl 库,它是一个功能强大且易于使用的库,专门用于处理 Excel 文件。 1. 安装 openpyxl 首先,你需要安装 openpyxl 库。你可以使用 pip 命令进行安装: pip install openpyxl2. 读取 Excel 文件 要读取 Excel 文…

每日一个小题

import pygame import random # 初始化 Pygame pygame.init() # 屏幕大小 screen_width 300 screen_height 600 block_size 30 # 颜色定义 colors [ (0, 0, 0), (255, 0, 0), (0, 150, 0), (0, 0, 255), (255, 120, 0), (255, 255, 0), (180, 0, 255), (0, 220, 220)…

用BGP的路由聚合功能聚合大陆路由,效果显著不?

正文共:666 字 11 图,预估阅读时间:1 分钟 之前我们统计过中国境内的IP地址和路由信息(你知道中国大陆一共有多少IPv4地址吗?),不过数量比较多,有8000多条。截止到2021年底&#xff…

【算法】回溯算法专题② ——组合型回溯 + 剪枝 python

目录 前置知识进入正题小试牛刀实战演练总结 前置知识 【算法】回溯算法专题① ——子集型回溯 python 进入正题 组合https://leetcode.cn/problems/combinations/submissions/596357179/ 给定两个整数 n 和 k,返回范围 [1, n] 中所有可能的 k 个数的组合。 你可以…

RocketMQ中的NameServer主要数据结构

1.前言 NameServer是RocketMQ中的一个比较重要的组件,我们这篇博客针对NameSever中包含的组件进行分析,分析一下NameServer中包含的组件以及组件的作用。以前我有一篇博客中rocketMq源码分析之搭建本地环境-CSDN博客,在这篇博客中就简单看了…

【C语言设计模式学习笔记1】面向接口编程/简单工厂模式/多态

面向接口编程可以提供更高级的抽象,实现的时候,外部不需要知道内部的具体实现,最简单的是使用简单工厂模式来进行实现,比如一个Sensor具有多种表示形式,这时候可以在给Sensor结构体添加一个enum类型的type,…

【go语言】grpc 快速入门

一、什么是 grpc 和 protobuf 1.1 grpc gRPC 是由 Google 开发的一个高效、开源的远程过程调用(RPC)框架,用于在分布式系统中进行通信。它是基于 HTTP/2 协议,支持多种语言,能够让不同的系统或应用程序(即…