cjavapy编程之程

PyTorch 卷积神经网络(Convents)的特征提取

PyTorch 中,卷积神经网络(CNN,Convolutional Neural Networks)用于图像等数据的特征提取。CNN 通过一系列卷积层、激活函数(如 ReLU)、池化层等提取局部特征,并通过更深的层次构建高级特征。迁移学习加载预训练模型(如 ResNet、VGG)并截取中间层输出作为特征提取器。图像检索、风格迁移、目标检测等任务中,特征提取作为关键步骤。

1、特征提取流程简介

CNN 能从局部到整体逐步学习图像结构,前几层学到的特征通常是通用的(如边缘),后面是特定任务的(如人脸、动物),在迁移学习中,常使用 CNN 的前几层作为特征提取器,冻结它们权重,仅训练后面的分类器。

步骤名称作用说明
1输入图像输入通常为 RGB 图像,形状为 (C, H, W),例如 (3, 224, 224)
2卷积层使用多个卷积核滑动窗口提取图像的局部特征(如边缘、纹理等)。
3激活函数(ReLU)添加非线性能力,使模型可以学习复杂模式。
4池化层减少空间维度、压缩信息,保留关键特征(如最大池化 MaxPooling)。
5重复卷积+池化多层堆叠提取更高级的特征(从边缘到形状,再到语义特征)。
6展平(Flatten)将多维特征图转为一维向量,用于后续分类或其他任务。
7分类层(可选)如果用于分类,则连接全连接层 + softmax 输出预测类别。

2、特征提取器

特征提取器(Feature Extractor) 是指提取输入数据(通常是图像)的深层次特征的一段神经网络。它可以是自定义的 CNN 模块,也可以是使用预训练模型(如 ResNet、VGG)的一部分。

import torch
import torch.nn as nn

# 用于提取中间特征的模块
class Feature_extractor(nn.Module):
    def __init__(self):
        super(Feature_extractor, self).__init__()
        self.feature = None

    def forward(self, input):
        self.feature = input.clone()
        return input

# 模拟已有 CNN(可替换为 torchvision.models.resnet18 等)
cnn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.MaxPool2d(2),
    nn.Conv2d(16, 32, kernel_size=3, padding=1),
    nn.ReLU(),
)

# 创建新网络
new_net = nn.Sequential().cuda()
target_layers = ["conv_1", "conv_2", "conv_4"]  # 要提取的层
extractors = {}  # 用于保存 extractor 层引用

i = 1
for layer in list(cnn):
    if isinstance(layer, nn.Conv2d):
        name = f"conv_{i}"
        new_net.add_module(name, layer)
        if name in target_layers:
            extractor = Feature_extractor()
            new_net.add_module(f"extractor_{i}", extractor)
            extractors[name] = extractor
        i += 1
    elif isinstance(layer, nn.ReLU):
        new_net.add_module(f"relu_{i}", layer)
    elif isinstance(layer, nn.MaxPool2d):
        new_net.add_module(f"pool_{i}", layer)

# 模拟输入图像
your_image = torch.randn(1, 3, 32, 32).cuda()
output = new_net(your_image)

# 提取特征:例如 conv_4 的特征
print(extractors["conv_4"].feature.shape)  # 输出特征图尺寸

推荐阅读
cjavapy编程之路首页