递归神经网络(Recursive Neural Networks, RecNNs) 是一种特殊结构的神经网络,适用于树形结构数据,与时间序列中使用的循环神经网络(RNN)不同。它常用于自然语言处理中的语法树建模、句法分析、情感分析等任务。本文主要介绍PyTorch 递归神经网络(Recursive Neural Networks)。

 1、RecNN vs. RNN 的区别

特性RNN(Recurrent)RecNN(Recursive)
输入结构序列(线性)树结构(如句法树)
应用时间序列、语言建模语义建模、句法分析
参数共享
网络连接方式时间步骤之间语法结构层次之间

2、RecNN 实现

PyTorch 没有内置 RecNN 模块,但可以通过递归构造计算图来自定义实现。例如在句法树结构上自底向上地合并左右子节点。每个 TreeNode 表示一个语法树节点,word_embedding 可由词嵌入层提供(如 nn.Embedding)模型递归地遍历树,合并左右子树向上传递语义信息,可将根节点输出送入分类器完成句子级任务(如情感分析)

import torch
import torch.nn as nn
import torch.nn.functional as F

class TreeNode:
    def __init__(self, left=None, right=None, word_embedding=None):
        self.left = left
        self.right = right
        self.word_embedding = word_embedding  # torch tensor
        self.state = None  # 将被递归计算填充

class RecursiveNN(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(RecursiveNN, self).__init__()
        self.W = nn.Linear(input_dim * 2, hidden_dim)

    def forward(self, node):
        if node.left is None and node.right is None:
            node.state = node.word_embedding
        else:
            left_state = self.forward(node.left)
            right_state = self.forward(node.right)
            combined = torch.cat([left_state, right_state], dim=-1)
            node.state = torch.tanh(self.W(combined))
        return node.state

推荐文档