搜索
您的当前位置:首页正文

STGCN(时空图卷积网络)详解

来源:易榕旅网

1️⃣ STGCN介绍

前面已经介绍过了。这篇论文《Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action Recognition》将GCN扩展到时空图模型上,用于实现动作识别。下图展示了STGCN的输入,即一系列骨架图。其中每个节点对应于人体的一个关节,有两种类型的边,①符合节点自然连通性的空间边(图1中淡蓝色线条)②跨越连续时间步长连接相同节点的时间边(淡绿色线条)


2️⃣ 网络结构

STGCN简单的网络结构如下所示,由三部分构成:

  • 归一化:对输入数据归一化
  • 时空变化:通过多个ST-GCN块,每个块中交替使用GCN和TCN
  • 输出:使用平均池化和全连接层对特征进行分类

在这里,我们不关注数据部分,只对网络结构进行解析,因此我们来看ST-GCN块的详细结构:

  • 步骤一:引入一个可学习的权重矩阵,与邻接矩阵大小一致,记作Learnable edge importance weight。让它与邻接矩阵 A A A按位相乘,得到加权后的邻接矩阵。其目的是给重要的边较大的权重,给非重要的边较小的权重。
  • 步骤二:将加权后的邻接矩阵与输入数据送到GCN中进行运算
  • 步骤三:利用TCN网络,实现时间维度信息的聚合

3️⃣ 代码

# 只包含网络结构,“网络的输入(关节坐标)”和“邻接矩阵”都是随机生成的

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# ---------------------------------------------------------------------------------
# 空域图卷积
# ---------------------------------------------------------------------------------
class SpatialGraphConvolution(nn.Module):
    """
    Args:
    in_channels: 输入通道数,表示每个节点的特征维度(例如 3 对应 x, y, z 坐标)。
    out_channels: 输出通道数,表示经过卷积后每个节点的特征维度。
    s_kernel_size: 空间卷积核的大小,等于邻接矩阵的数量,表示多重图卷积的支持
    """
    def __init__(self, in_channels, out_channels, s_kernel_size):
        super().__init__()
        self.s_kernel_size = s_kernel_size
        self.conv = nn.Conv2d(in_channels=in_channels,
                            out_channels=out_channels * s_kernel_size,
                            kernel_size=1)

    def forward(self, x, A):
        x = self.conv(x)
        n, kc, t, v = x.size()
        x = x.view(n, self.s_kernel_size, kc//self.s_kernel_size, t, v)
        #对邻接矩阵进行GC,相加特征
        x = torch.einsum('nkctv,kvw->nctw', (x, A))
        return x.contiguous()
  
# ---------------------------------------------------------------------------------
# STGCN块
# ---------------------------------------------------------------------------------
class STGCN_block(nn.Module):
    """
    Args:
    in_channels:人体动作中的关节坐标维度,即in_channels=3
    out_channels:经过这个块后每个节点的特征维度
    stride:时间步长
    t_kernel_size:时间卷积核大小为
    A_size:图的邻接矩阵尺寸为
    dropout=0.5: 随机失活的概率,用于防止过拟合
    """
    def __init__(self, in_channels, out_channels, stride, t_kernel_size, A_size, dropout=0.5):
        super().__init__()
        # 空域图卷积
        self.s_gcn = SpatialGraphConvolution(in_channels=in_channels,
                                        out_channels=out_channels,
                                        s_kernel_size=A_size[0])

        # Learnable weight matrix M 给边缘赋予权重。学习哪个边是重要的。
        self.M = nn.Parameter(torch.ones(A_size))
        # 时域图卷积
        self.t_gcn = nn.Sequential(nn.BatchNorm2d(out_channels),
                                nn.ReLU(),
                                nn.Dropout(dropout),
                                nn.Conv2d(out_channels,
                                        out_channels,
                                        (t_kernel_size, 1), # kernel_size
                                        (stride, 1), # stride
                                        ((t_kernel_size - 1) // 2, 0), # padding
                                        ),
                                nn.BatchNorm2d(out_channels),
                                nn.ReLU())

    def forward(self, x, A):
        x = self.t_gcn(self.s_gcn(x, A * self.M))
        return x
  

# ---------------------------------------------------------------------------------
# STGCN网络模型
# ---------------------------------------------------------------------------------
class STGCN(nn.Module):
    """
    Spatio-Temporal Graph Convolutional Network (ST-GCN) 模型。
    Args:
        A (torch.Tensor): 图的邻接矩阵。
        num_classes (int): 分类的类别数。
        in_channels (int): 输入通道数。
        t_kernel_size (int): 时间卷积核的大小。
    Attributes:
        A (torch.Tensor): 图的邻接矩阵。
        bn (nn.BatchNorm1d): 一维批量归一化层。
        stgcn1, stgcn2, stgcn3, stgcn4, stgcn5, stgcn6 (STGCN_block): 时空图卷积块。
        fc (nn.Conv2d): 用于分类的全连接层。
    """
    def __init__(self, A, num_classes, in_channels, t_kernel_size):
        super().__init__()
        # 邻接矩阵
        self.A = A
        # 邻接矩阵的大小
        A_size = A.size()
        
        # 批量归一化
        self.bn = nn.BatchNorm1d(in_channels * A_size[1])  # 75

        # 空间-时间图卷积块
        
        # 第一个时空卷积块
        # 输入通道数:人体动作中的关节坐标维度,即in_channels=3
        # 输出通道数:经过这个块后每个节点的特征维度
        # 时间步长(stride)为 1(时间维度不缩减)
        # 时间卷积核大小为 t_kernel_size(如 9)
        # 图的邻接矩阵尺寸为 A_size
        self.stgcn1 = STGCN_block(in_channels, 32, 1, t_kernel_size, A_size)  # in_channels=3, t_kernel_size=9,步长为1
        self.stgcn2 = STGCN_block(32, 32, 1, t_kernel_size, A_size)
        self.stgcn3 = STGCN_block(32, 32, 1, t_kernel_size, A_size)
        self.stgcn4 = STGCN_block(32, 64, 2, t_kernel_size, A_size)  # 步长为2
        self.stgcn5 = STGCN_block(64, 64, 1, t_kernel_size, A_size)
        self.stgcn6 = STGCN_block(64, 64, 1, t_kernel_size, A_size)

        # 分类层
        self.fc = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        """
        前向传播函数
        Args:
            x (torch.Tensor): 输入张量,形状为 (N, C, T, V),其中 N 是批次大小,C 是通道数,T 是帧数,V 是节点数。
        Returns:
            torch.Tensor: 输出张量,形状为 (N, num_classes)。
        """
        # 原始数据形状为 (N, C, T, V),即(batch size, 维度数,帧数,关节数)
        N, C, T, V = x.size()
        # 原始数据被重塑为 (N, V * C, T),N代表batch size;V*C表示关节数*维度数(3*25=75),即每一帧的所有特征作为一个整体;T表示帧数 
        x = x.permute(0, 3, 1, 2).contiguous().view(N, V * C, T)  
        # 在特征维度上,对每个时间帧的所有节点特征归一化
        x = self.bn(x)  
        x = x.view(N, V, C, T).permute(0, 2, 3, 1).contiguous()  # 重新排列张量

        # 空间-时间图卷积块
        x = self.stgcn1(x, self.A)
        x = self.stgcn2(x, self.A)
        x = self.stgcn3(x, self.A)
        x = self.stgcn4(x, self.A)
        x = self.stgcn5(x, self.A)
        x = self.stgcn6(x, self.A)

        # 分类
        x = F.avg_pool2d(x, x.size()[2:])  # 全局平均池化
        x = x.view(N, -1, 1, 1)  # 重塑张量
        x = self.fc(x)  # 应用分类层
        x = x.view(x.size(0), -1)  # 重塑输出张量
        return x


def random_adjacency_matrix(num_nodes):
    # 生成一个随机的上三角矩阵
    upper_triangle = np.triu(np.random.randint(0, 2, size=(num_nodes, num_nodes)), k=1)
    # 生成对角线上的自环
    diagonal = np.eye(num_nodes, dtype=int)
    # 将上三角矩阵转为对称矩阵
    adjacency_matrix = upper_triangle + upper_triangle.T + diagonal
    return adjacency_matrix



if __name__ == "__main__":
    # 一个图有25个节点,随机生成其邻接矩阵
    num_nodes = 25
    adjacency_matrix = random_adjacency_matrix(num_nodes)
    adjacency_tensor_3d = torch.tensor(adjacency_matrix, dtype=torch.float32).unsqueeze(0)

    # adjacency_tensor_3d: 三维邻接张量,用于定义图结构
    # num_classes=10: 模型输出的类别数为10
    # in_channels=3: 输入通道数为3
    # t_kernel_size=9: 时间卷积核的大小为9
    model = STGCN(adjacency_tensor_3d, 
                   num_classes=10,
                   in_channels=3,
                   t_kernel_size=9, 
                   )

    
    # 数据集分为十个类别 
    # 0:喝
    # 1:投掷
    # 2:坐
    # 3:站起来
    # 4:掌声
    # 5:挥手
    # 6:踢
    # 7:跳跃
    # 8:敬礼
    # 9:倒立
    
    # 数据结构表示为(batch size, 维度数,帧数,关节数)
    
    # 生成一个200, 3, 80, 25的张量
    # 表示输入有200个数据,每个数据有80帧,一帧里有25个关节点,每个关节点有3个维度
    input = torch.randn(200, 3, 80, 25)
    # 经过网络模型,得到输出
    # 输出为[batch,10]表示对应动作类别的概率
    output = model(input)
    print(output.shape)
   

输出为torch.Size([200, 10]),表示对应动作类别的概率


4️⃣ 总结

  • STGCN结合了图卷积神经网络(GCN)和时间卷积网络(TCN)

5️⃣ 参考


因篇幅问题不能全部显示,请点此查看更多更全内容

Top