Implementing a Long Short-Term Memory Network (LSTM) in PyTorch



An LSTM controls the flow of information through a memory cell and three gating mechanisms (a forget gate, an input gate, and an output gate); the update equations are sketched below:

 Memory cell (Cell State)

  • Stores long-term information; the gating mechanisms decide which information is kept and which is discarded.

 Forget gate ($f_t$)

  • Decides how much of the previous cell state $C_{t-1}$ is thrown away.

 Input gate ($i_t$)

  • Decides how much of the new candidate information is written into the cell state.

 Output gate ($o_t$)

  • Decides how much of the cell state is exposed as the hidden state $h_t$.
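
For reference, one common formulation of the LSTM update (with $\sigma$ the sigmoid function, $\odot$ element-wise multiplication, $[h_{t-1}, x_t]$ the concatenation of the previous hidden state and the current input, and $W_*, b_*$ learned weights and biases) is:

$$
\begin{aligned}
f_t &= \sigma(W_f [h_{t-1}, x_t] + b_f) \\
i_t &= \sigma(W_i [h_{t-1}, x_t] + b_i) \\
\tilde{C}_t &= \tanh(W_C [h_{t-1}, x_t] + b_C) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma(W_o [h_{t-1}, x_t] + b_o) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$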

Feature                       | Vanilla RNN      | LSTM
Memory                        | Short-term only  | Both short- and long-term
Computational cost            | Lower            | Higher
Mitigates vanishing gradients | No               | Yes
Typical use cases             | Short sequences  | Long sequences
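
In PyTorch the recurrent layer itself is provided by nn.LSTM. Below is a minimal sketch of calling it; the sizes (batch 4, sequence length 10, 3 input features, 32 hidden units) are illustrative choices and are not the ones used in the example later in this post:

import torch
import torch.nn as nn

# Illustrative sizes: a batch of 4 sequences, 10 time steps, 3 features per step
lstm = nn.LSTM(input_size=3, hidden_size=32, num_layers=1, batch_first=True)

x = torch.randn(4, 10, 3)   # (batch, seq_len, input_size)
h0 = torch.zeros(1, 4, 32)  # initial hidden state: (num_layers, batch, hidden_size)
c0 = torch.zeros(1, 4, 32)  # initial cell state: same shape as h0

out, (hn, cn) = lstm(x, (h0, c0))
print(out.shape)  # torch.Size([4, 10, 32]) -- hidden state at every time step
print(hn.shape)   # torch.Size([1, 4, 32])  -- final hidden state of each layer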

LSTM application scenarios

  • Natural language processing (NLP): text generation, sentiment analysis, machine translation
  • Time-series forecasting: stock prediction, weather forecasting, sensor data analysis
  • Speech recognition: automatic caption generation, speech-to-text (ASR)
  • Robotics and control systems: agent decision making, autonomous driving

Example:

The example below implements an LSTM-based reinforcement learning agent that moves through a 1D grid environment and learns to reach the goal along an optimal path.
At the end, we plot 5 test trajectories and highlight the best path in red.

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


# ========== 1. Define the LSTM policy network ==========
class LSTMPolicy(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
        super(LSTMPolicy, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, hidden_state):
        batch_size = x.size(0)

        # Defensive check: if the hidden state lost its batch dimension, expand it back to (num_layers, batch, hidden_size)
        if hidden_state[0].dim() == 2:
            hidden_state = (hidden_state[0].unsqueeze(1).repeat(1, batch_size, 1),
                            hidden_state[1].unsqueeze(1).repeat(1, batch_size, 1))

        out, hidden_state = self.lstm(x, hidden_state)
        out = self.fc(out[:, -1, :])  # take the output of the last time step
        action_prob = self.softmax(out)  # normalize into an action probability distribution (the policy)
        return action_prob, hidden_state

    def init_hidden(self, batch_size=1):
        return (torch.zeros(self.num_layers, batch_size, self.hidden_size),
                torch.zeros(self.num_layers, batch_size, self.hidden_size))
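
# A quick shape check of the policy network (an optional sketch, not part of
# the original example). With input_size=1, hidden_size=64, output_size=2,
# a (batch=1, seq_len=1, features=1) state tensor maps to a (1, 2) action
# distribution:
#
#   policy = LSTMPolicy(input_size=1, hidden_size=64, output_size=2)
#   hidden = policy.init_hidden(batch_size=1)
#   probs, hidden = policy(torch.zeros(1, 1, 1), hidden)
#   print(probs.shape)         # torch.Size([1, 2])
#   print(probs.sum().item())  # 1.0 (softmax output)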


# ========== 2. Define the grid environment ==========
class GridWorld:
    def __init__(self, grid_size=10, goal_position=9):
        self.grid_size = grid_size
        self.goal_position = goal_position
        self.reset()

    def reset(self):
        self.position = 0
        return self.position

    def step(self, action):
        if action == 0:
            self.position = max(0, self.position - 1)
        elif action == 1:
            self.position = min(self.grid_size - 1, self.position + 1)

        reward = 1 if self.position == self.goal_position else -0.1
        done = self.position == self.goal_position
        return self.position, reward, done
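
# The environment in brief (a commented sketch, not part of the original
# example): the agent starts at cell 0, action 0 moves one cell left, action 1
# moves one cell right, every non-goal step is rewarded -0.1, and reaching
# cell 9 gives +1 and ends the episode:
#
#   env = GridWorld()
#   state = env.reset()                 # 0
#   state, reward, done = env.step(1)   # -> 1, -0.1, False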


# ========== 3. Train the agent ==========
def train(num_episodes=500, max_steps=50):
    env = GridWorld()
    input_size = 1
    hidden_size = 64
    output_size = 2
    num_layers = 1

    policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)
    optimizer = optim.Adam(policy.parameters(), lr=0.01)
    gamma = 0.99

    for episode in range(num_episodes):
        state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)
        hidden_state = policy.init_hidden(batch_size=1)

        log_probs = []
        rewards = []

        for step in range(max_steps):
            action_probs, hidden_state = policy(state, hidden_state)
            action = torch.multinomial(action_probs, 1).item()
            log_prob = torch.log(action_probs.squeeze(0)[action])
            log_probs.append(log_prob)

            next_state, reward, done = env.step(action)
            rewards.append(reward)

            if done:
                break

            state = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)

        # Compute discounted returns G_t = r_t + gamma * G_{t+1} and update the policy with the REINFORCE loss
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)

        returns = torch.tensor(returns, dtype=torch.float32)
        returns = (returns - returns.mean()) / (returns.std() + 1e-9)

        loss = sum([-log_prob * R for log_prob, R in zip(log_probs, returns)])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if (episode + 1) % 50 == 0:
            print(f"Episode {episode + 1}/{num_episodes}, Total Reward: {sum(rewards)}")

    torch.save(policy.state_dict(), "policy.pth")


# Train the agent
train(500)


# ========== 4. Test the agent and plot the best path ==========
def test(num_episodes=5):
    env = GridWorld()
    input_size = 1
    hidden_size = 64
    output_size = 2
    num_layers = 1

    policy = LSTMPolicy(input_size, hidden_size, output_size, num_layers)
    policy.load_state_dict(torch.load("policy.pth"))

    plt.figure(figsize=(10, 5))
    best_path = None
    best_steps = float('inf')

    for episode in range(num_episodes):
        state = torch.tensor([[env.reset()]], dtype=torch.float32).unsqueeze(0)  # (1, 1, input_size)
        hidden_state = policy.init_hidden(batch_size=1)
        positions = [env.position]  # record the positions visited in this episode

        while True:
            action_probs, hidden_state = policy(state, hidden_state)
            action = torch.argmax(action_probs, dim=-1).item()
            next_state, reward, done = env.step(action)
            positions.append(next_state)

            if done:
                break

            state = torch.tensor([[next_state]], dtype=torch.float32).unsqueeze(0)

        # Track the best (shortest) path
        if len(positions) < best_steps:
            best_steps = len(positions)
            best_path = positions

        # Plot this episode's path in blue
        plt.plot(range(len(positions)), positions, marker='o', linestyle='-', color='blue', alpha=0.6,
                 label=f'Episode {episode + 1}' if episode == 0 else "")

    # Plot the best path in red
    if best_path:
        plt.plot(range(len(best_path)), best_path, marker='o', linestyle='-', color='red', linewidth=2,
                 label="Best Path")

    # Print the best path
    print(f"Best Path (steps={best_steps}): {best_path}")

    plt.xlabel("Time Steps")
    plt.ylabel("Agent Position")
    plt.title("Agent's Movement Path (Best Path in Red)")
    plt.legend()
    plt.grid(True)
    plt.show()


# Test the agent and plot its movement paths
test(5)
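
A note on the reward design: the -0.1 penalty per step combined with the +1 reward at the goal nudges the policy toward the shortest route to cell 9, so after successful training the highlighted red path should move roughly monotonically to the right. Exact trajectories will vary between runs, since actions are sampled during training and the network is initialized randomly.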

