1、RecNN vs. RNN 的区别
特性 | RNN(Recurrent) | RecNN(Recursive) |
---|---|---|
输入结构 | 序列(线性) | 树结构(如句法树) |
应用 | 时间序列、语言建模 | 语义建模、句法分析 |
参数共享 | 是 | 是 |
网络连接方式 | 时间步骤之间 | 语法结构层次之间 |
2、RecNN 实现
PyTorch 没有内置 RecNN
模块,但可以通过递归构造计算图来自定义实现。例如在句法树结构上自底向上地合并左右子节点。每个 TreeNode
表示一个语法树节点,word_embedding
可由词嵌入层提供(如 nn.Embedding
)模型递归地遍历树,合并左右子树向上传递语义信息,可将根节点输出送入分类器完成句子级任务(如情感分析)
import torch import torch.nn as nn import torch.nn.functional as F class TreeNode: def __init__(self, left=None, right=None, word_embedding=None): self.left = left self.right = right self.word_embedding = word_embedding # torch tensor self.state = None # 将被递归计算填充 class RecursiveNN(nn.Module): def __init__(self, input_dim, hidden_dim): super(RecursiveNN, self).__init__() self.W = nn.Linear(input_dim * 2, hidden_dim) def forward(self, node): if node.left is None and node.right is None: node.state = node.word_embedding else: left_state = self.forward(node.left) right_state = self.forward(node.right) combined = torch.cat([left_state, right_state], dim=-1) node.state = torch.tanh(self.W(combined)) return node.state