自定义数据集 ,使用朴素贝叶斯对其进行分类

news/2025/2/3 6:13:21 标签: python, numpy, 开发语言

代码:

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt

# 定义类1的数据点,每个数据点是二维的坐标
class1_points = np.array([[1.9, 1.2],
                          [1.5, 2.1],
                          [1.9, 0.5],
                          [1.5, 0.9],
                          [0.9, 1.2],
                          [1.1, 1.7],
                          [1.4, 1.1]])

# 定义类2的数据点,每个数据点是二维的坐标
class2_points = np.array([[3.2, 3.2],
                          [3.7, 2.9],
                          [3.2, 2.6],
                          [1.7, 3.3],
                          [3.4, 2.6],
                          [4.1, 2.3],
                          [3.0, 2.9]])

# 将两个类的数据点合并为一个大的数据集 x,标签合并为 y
x = np.concatenate((class1_points, class2_points), axis=0)  # x 是一个 (14, 2) 形状的数组,包含所有的点
y = np.concatenate((np.zeros(len(class1_points)), np.ones(len(class2_points))), axis=0)  # y 是一个包含标签的数组,0 表示 class1,1 表示 class2

# 计算先验概率,即每个类在数据集中的概率
prior_prob = [np.sum(y == 0) / len(y), np.sum(y == 1) / len(y)]

# 计算每个类的均值向量(每个类的特征的均值)
class_u = [np.mean(x[y==0], axis=0), np.mean(x[y==1], axis=0)]

# 计算每个类的协方差矩阵
class_cov = [np.cov(x[y==0], rowvar=False), np.cov(x[y==1], rowvar=False)]

# 定义概率密度函数 (PDF),用于计算高斯分布的概率
def pdf(x, mean, cov):
    n = len(mean)  # mean 是特征的维度,这里是2,因为每个数据点有两个特征
    # 计算常数系数部分,注意是协方差矩阵的行列式
    coff = 1 / (2 * np.pi) ** (n / 2) * np.sqrt(np.linalg.det(cov))
    # 计算指数部分,(x - mean) 转置和协方差矩阵的逆相乘,再与 (x - mean) 相乘
    exponent = np.exp(-(1 / 2) * np.dot(np.dot((x - mean).T, np.linalg.inv(cov)), (x - mean)))
    # 返回高斯分布的概率密度
    return coff * exponent

# 创建网格点,用于在平面上绘制决策边界
xx, yy = np.meshgrid(np.arange(0, 5, 0.05), np.arange(0, 4, 0.05))

# 将网格点转换为 (N, 2) 的矩阵,方便后续计算
grid_points = np.c_[xx.ravel(), yy.ravel()]

# 用于存储每个网格点的预测标签
grid_label = []

# 遍历网格中的每一个点,计算其后验概率,决定属于哪个类
for point in grid_points:
    poster_prob = []  # 存储每个类的后验概率
    for i in range(2):  # 遍历两个类
        # 计算每个类在该点的似然度(即高斯分布的概率)
        likelihood = pdf(point, class_u[i], class_cov[i])
        # 计算该类的后验概率,即 先验概率 * 似然度
        poster_prob.append(prior_prob[i] * likelihood)
    # 选择后验概率最大的类作为预测的标签
    pre_class = np.argmax(poster_prob)
    grid_label.append(pre_class)

# 绘制类1的样本点,蓝色标记
plt.scatter(class1_points[:, 0], class1_points[:, 1], c='blue', label='class 1')
# 绘制类2的样本点,红色标记
plt.scatter(class2_points[:, 0], class2_points[:, 1], c='red', label='class 2')
# 添加图例
plt.legend()

# 将 grid_label 转换为一个数组,并重塑为与网格形状一致的矩阵,便于绘制等高线图
grid_label = np.array(grid_label)
pre_grid_label = grid_label.reshape(xx.shape)

# 绘制决策边界,等高线图,绿色线表示决策边界(即类0与类1的分界线)
contour = plt.contour(xx, yy, pre_grid_label, levels=[0.5], colors='green')

# 显示绘图结果
plt.show()

结果:


http://www.niftyadmin.cn/n/5840531.html

相关文章

deepseek v3 搭建个人知识库

目录 知乎完整教程: deepseek-r1本地部署 Chatbox连接ollama服务 知乎完整教程: https://zhuanlan.zhihu.com/p/19848028238 deepseek-r1本地部署 公司数据不泄露,DeepSeek R1本地化部署web端访问个人知识库搭建与使用,喂饭级…

2025年Android开发趋势全景解读

文章目录 一、界面开发:从"手写代码"到"智能拼装"二、AI融合开发:无需炼丹的普惠智能三、车机开发:手机开发者的新蓝海(车企需求拆解)四、生存技能升级:开发者转型路线图五、避坑指南&…

每日 Java 面试题分享【第 18 天】

欢迎来到每日 Java 面试题分享栏目! 订阅专栏,不错过每一天的练习 今日分享 3 道面试题目! 评论区复述一遍印象更深刻噢~ 目录 问题一:什么是 Java 中的双亲委派模型?问题二:Java 中 wait() 和 sleep()…

Linux环境下的Java项目部署技巧:安装 Mysql

查看 myslq 是否安装: rpm -qa|grep mysql 如果已经安装,可执行命令来删除软件包: rpm -e --nodeps 包名 下载 repo 源: http://dev.mysql.com/get/mysql80-community-release-el7-7.noarch.rpm 执行命令安装 rpm 源(根据下载的…

基于机器学习鉴别中药材的方法

基于机器学习鉴别中药材的方法 摘要 由于不同红外光照射药材时会呈现不同的光谱特征,所以本文基于中药材的这一特点来判断其产地和种类。 针对问题一:要对附件一中所给数据对所给中药材进行分类,并就其特征和差异性进行研究。首先,我们读…

【Go - 小心! Go中slice的传递陷阱 】

📢注意:slice 是引用传递 ,传递过去的参数,内存没有重新分配。 示例 package mainimport "fmt"// 引用传递 ,传递过去的地址,内存没有重新分配 func test(abc []int) {abc[0] -1 }func main()…

【AI】探索自然语言处理(NLP):从基础到前沿技术及代码实践

Hi ! 云边有个稻草人-CSDN博客 必须有为成功付出代价的决心,然后想办法付出这个代价。 目录 引言 1. 什么是自然语言处理(NLP)? 2. NLP的基础技术 2.1 词袋模型(Bag-of-Words,BoW&#xff…

Java 9模块开发:Eclipse实战指南

在上一篇教程中,我们已经了解了Java 9模块的基础知识。今天,我们将深入探讨如何使用Eclipse IDE来开发和运行Java 9模块。Eclipse作为一款强大的开发工具,为Java开发提供了丰富的功能支持。不过需要注意的是,对于Eclipse 4.7&…