【深度学习基础模型】稀疏自编码器 (Sparse Autoencoders, SAE)详细理解并附实现代码。

news/2024/9/29 12:11:55 标签: 深度学习, 人工智能, python, 学习, SAE, autoencoder

学习>深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model

学习>深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model


文章目录

  • 学习>深度学习基础模型】Efficient Learning of Sparse Representations with an Energy-Based Model
  • 1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用
    • 1.1 SAE 原理
    • 1.2 SAE 的主要特征:
    • 1.3 SAE 的应用领域:
  • 2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用
    • 2.1 SAE 模型的实现
    • 2.2 代码解释
  • 3. 总结


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://www.cs.toronto.edu/~ranzato/publications/ranzato-nips06.pdf

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!

SAE__11">1. 稀疏自编码器 (Sparse Autoencoders, SAE) 的原理与应用

SAE__12">1.1 SAE 原理

稀疏自编码器(Sparse Autoencoder, SAE)是一种特殊类型的自编码器,其设计目的是在编码过程中引入稀疏性,以鼓励网络学习更多的特征。与标准自编码器不同,SAE 的目标是将输入信息编码到一个比输入更高维的空间中,帮助提取多种小特征。

SAE__14">1.2 SAE 的主要特征:

  • 稀疏性:通过引入稀疏性约束,限制在某一时刻只有少数神经元被激活。这可以通过添加一个稀疏性损失项来实现,该项鼓励大多数神经元在输出中保持静默。
  • 网络结构SAE 通常包含一个较小的中间层,但该中间层的激活仅通过部分神经元实现,从而生成丰富的特征表示。
  • 特征提取SAE 适用于特征提取,尤其是在需要捕捉数据中细微差别的任务中,如图像分类、异常检测等。

SAE__18">1.3 SAE 的应用领域:

  • 图像处理:可以用于从遥感图像中提取细节特征。
  • 异常检测:在数据中识别异常点。
  • 生物信息学:提取基因表达数据中的重要特征。

在遥感领域,SAE 可以用于混合像元的分解,帮助识别和分离不同地物的光谱特征。

SAE__24">2. Python 代码实现 SAE 在遥感图像混合像元分解中的应用

以下是一个简单的稀疏自编码器实现示例,展示如何在遥感图像的混合像元分解中应用 SAE

SAE__26">2.1 SAE 模型的实现

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import matplotlib.pyplot as plt

# 定义稀疏自编码器模型
class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, sparsity_param, beta):
        super(SparseAutoencoder, self).__init__()
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        self.sparsity_param = sparsity_param  # 稀疏性参数
        self.beta = beta  # 稀疏性损失权重

    def forward(self, x):
        encoded = torch.relu(self.encoder(x))  # 编码过程
        decoded = torch.sigmoid(self.decoder(encoded))  # 解码过程
        return decoded, encoded

    def sparsity_loss(self, p_h):
        # 计算稀疏性损失
        return self.beta * torch.sum(self.sparsity_param * torch.log(self.sparsity_param / p_h) +
                                      (1 - self.sparsity_param) * torch.log((1 - self.sparsity_param) / (1 - p_h)))

# 生成模拟遥感图像数据 (64 维特征)
X = np.random.rand(1000, 64)  # 1000 个样本,每个样本有 64 维光谱特征
X = torch.tensor(X, dtype=torch.float32)

# 创建数据加载器
dataset = TensorDataset(X, X)  # 输入和目标均为原始数据
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# 定义模型、优化器
input_size = 64
hidden_size = 32  # 隐藏层大小
sparsity_param = 0.05  # 稀疏性参数
beta = 1  # 稀疏性损失权重
sae = SparseAutoencoder(input_size=input_size, hidden_size=hidden_size, sparsity_param=sparsity_param, beta=beta)
optimizer = optim.Adam(sae.parameters(), lr=0.001)

# 训练 SAE 模型
num_epochs = 50
for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        inputs, original_data = data
        reconstructed_data, encoded_data = sae(inputs)  # 前向传播
        loss = nn.functional.binary_cross_entropy(reconstructed_data, original_data)  # 重构损失
        # 计算稀疏性损失
        p_h = torch.mean(encoded_data, dim=0)
        loss += sae.sparsity_loss(p_h)  # 加入稀疏性损失
        loss.backward()
        optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item()}')

# 使用训练好的模型进行特征提取
with torch.no_grad():
    _, encoded_data = sae(X).numpy()  # 提取编码后的特征

# 可视化原始数据和编码后的特征
plt.figure(figsize=(12, 6))

plt.subplot(1, 2, 1)
plt.title('Original Data')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')

plt.subplot(1, 2, 2)
plt.title('Encoded Features')
plt.imshow(encoded_data[:10], aspect='auto', cmap='hot')

plt.show()

2.2 代码解释

1. 模型定义:

class SparseAutoencoder(nn.Module):
    def __init__(self, input_size, hidden_size, sparsity_param, beta):
        super(SparseAutoencoder, self).__init__()
        
        self.encoder = nn.Linear(input_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, input_size)
        self.sparsity_param = sparsity_param  # 稀疏性参数
        self.beta = beta  # 稀疏性损失权重
  • SparseAutoencoder 类定义了编码器和解码器的结构,同时定义了稀疏性参数和稀疏性损失权重。

2. 前向传播:

def forward(self, x):
    encoded = torch.relu(self.encoder(x))  # 编码过程
    decoded = torch.sigmoid(self.decoder(encoded))  # 解码过程
    return decoded, encoded
  • 输入数据通过编码器和解码器处理,输出重构数据和编码数据。

3. 稀疏性损失计算:

def sparsity_loss(self, p_h):
    return self.beta * torch.sum(self.sparsity_param * torch.log(self.sparsity_param / p_h) +
                                  (1 - self.sparsity_param) * torch.log((1 - self.sparsity_param) / (1 - p_h)))
  • 计算稀疏性损失,用于约束神经元的激活。

4. 数据生成:

X = np.random.rand(1000, 64)  # 生成 1000 个样本,每个样本有 64 维光谱特征
  • 模拟生成随机的遥感光谱数据。

5. 数据加载器:

dataset = TensorDataset(X, X)  # 输入和目标均为原始数据
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 使用 DataLoader 创建批处理数据集。

6. 模型训练:

for epoch in range(num_epochs):
    for data in dataloader:
        optimizer.zero_grad()
        inputs, original_data = data
        reconstructed_data, encoded_data = sae(inputs)  # 前向传播
        loss = nn.functional.binary_cross_entropy(reconstructed_data, original_data)  # 重构损失
        p_h = torch.mean(encoded_data, dim=0)  # 计算激活的平均值
        loss += sae.sparsity_loss(p_h)  # 加入稀疏性损失
        loss.backward()
        optimizer.step()
  • 在 50 个 epoch 内进行训练,计算重构损失并加入稀疏性损失,更新模型权重。

7. 特征提取:

with torch.no_grad():
    _, encoded_data = sae(X).numpy()  # 提取编码后的特征
  • 在测试阶段,使用训练好的模型提取编码后的特征。

8. 可视化:

plt.subplot(1, 2, 1)
plt.title('Original Data')
plt.imshow(X.numpy()[:10], aspect='auto', cmap='hot')
  • 可视化原始数据和编码后的特征进行比较。

3. 总结

稀疏自编码器(SAE)是一种强大的特征学习模型,能够提取数据中的细微特征。通过引入稀疏性约束,SAE 有效地鼓励网络学习有用的特征表示。

在遥感领域,SAE 可用于混合像元的分解,帮助识别和分离不同地物的光谱特征。通过简单的 Python 实现,我们展示了如何使用 SAE 处理遥感数据,并可视化其效果。


http://www.niftyadmin.cn/n/5682998.html

相关文章

用Python绘制爱心形状的网络流量分析工具

摘要 在本文中,我们将探讨如何使用Python编程语言创建一个网络流量分析工具,该工具以爱心形状展示网络流量数据。这种创新的可视化方法不仅增加了数据分析的趣味性,同时也提高了数据解读的直观性。文章将详细介绍实现步骤和代码示例。 1. 引…

智能绘画,体现非凡想象力以文生图功能简单操作

智能绘画,体现非凡想象力以文生图功能简单操作 智能绘画技术突破了人类自身的极限,让绘画分析进入到一个更为广泛的视野中。通过输入描述性的文字,便可生成便可生成同一主题、不同风格的画作,体现出非凡的想象力,象征未…

【已解决】【Hadoop】找到java环境路径

在 Hadoop 环境中,Java 环境路径通常指的是 Java 的安装目录,因为 Hadoop 是用 Java 编写的,并且需要 Java 运行时环境(JRE)或 Java 开发工具(JDK)来运行。以下是几种方法来找到 Java 环境路径&…

【Qt笔记】QFrame控件详解

目录 引言 一、QFrame的基本特性 二、QFrame的常用方法 2.1 边框形状(Frame Shape) 2.2 阴影样式(Frame Shadow) 2.3 线条宽度(Line Width) 2.4 样式表(styleSheet) 三、QFrame的应用场景 四、应用…

【Springboot入门- RESTful服务的支持】

使用 RestController 和 RequestMapping RestController 注解表明该类中的所有方法都会返回结果直接写入HTTP响应体中,而不是使用视图解析器。RequestMapping 注解用于将HTTP请求映射到控制器的处理方法上。 RestController public class MyController {RequestMap…

ass字幕文件怎么导入视频mp4?ass字幕怎么编辑?视频加字幕超简单!

ass字幕文件怎么导入视频mp4?ass字幕怎么编辑?在视频制作和观看过程中,添加字幕是一项常见的需求,特别是对于外语视频或需要辅助阅读的场景。ASS(Advanced SubStation Alpha)字幕文件是一种常用的字幕格式&…

第四章-课后练习5:修正指数曲线模型——excel和python应用(2)

一、修正指数曲线模型——excel应用 对于以下数据: 年份销售量201346000201449000201551400201653320201754856201856085201957088202057900202158563 1、数据特征描述 首先使用excel绘制散点图,观察数据分布情况: 由图像得知&#xff0…

LeetCode 算法:只出现一次的数字 c++

原题链接🔗:只出现一次的数字难度:简单⭐️ 题目 给你一个 非空 整数数组 nums ,除了某个元素只出现一次以外,其余每个元素均出现两次。找出那个只出现了一次的元素。 你必须设计并实现线性时间复杂度的算法来解决此…