1、RecNN vs. RNN 的区别
特性 | RNN(Recurrent) | RecNN(Recursive) |
---|---|---|
输入结构 | 序列(线性) | 树结构(如句法树) |
应用 | 时间序列、语言建模 | 语义建模、句法分析 |
参数共享 | 是 | 是 |
网络连接方式 | 时间步骤之间 | 语法结构层次之间 |
2、RecNN 实现
PyTorch 没有内置 RecNN
模块,但可以通过递归构造计算图来自定义实现。例如在句法树结构上自底向上地合并左右子节点。每个 TreeNode
表示一个语法树节点,word_embedding
可由词嵌入层提供(如 nn.Embedding
)模型递归地遍历树,合并左右子树向上传递语义信息,可将根节点输出送入分类器完成句子级任务(如情感分析)
import torch
import torch.nn as nn
import torch.nn.functional as F
class TreeNode:
def __init__(self, left=None, right=None, word_embedding=None):
self.left = left
self.right = right
self.word_embedding = word_embedding # torch tensor
self.state = None # 将被递归计算填充
class RecursiveNN(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(RecursiveNN, self).__init__()
self.W = nn.Linear(input_dim * 2, hidden_dim)
def forward(self, node):
if node.left is None and node.right is None:
node.state = node.word_embedding
else:
left_state = self.forward(node.left)
right_state = self.forward(node.right)
combined = torch.cat([left_state, right_state], dim=-1)
node.state = torch.tanh(self.W(combined))
return node.state