Module text_embeddings.pruning.ltp

Expand source code
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Date    : 2021-07-18 14:39:02
# @Author  : Chenghao Mou (mouchenghao@gmail.com)

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class LTPMultiHeadAttention(nn.MultiheadAttention):
    def __init__(
        self,
        temperature,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        device=None,
        dtype=None,
    ) -> None:
        """
        Examples
        --------
        >>> attention = LTPMultiHeadAttention(1, 512, 8, 0.5, batch_first=True)
        >>> x = torch.rand((10, 128, 512))
        >>> output, weights, norm = attention(x, x, x)
        >>> output.shape
        torch.Size([10, 128, 512])
        >>> weights.shape
        torch.Size([10, 128, 128])
        """
        super().__init__(
            embed_dim,
            num_heads,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            kdim,
            vdim,
            batch_first,
            device,
            dtype,
        )
        self.temperature = temperature
        self.soft_threshold = nn.Parameter(torch.rand(1), requires_grad=True)

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        key_padding_mask: Optional[Tensor] = None,
        need_weights: bool = True,
        attn_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor], Tensor]:
        r"""
        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
                the batches while a 3D mask allows to specify a different mask for the entries of each batch.

        Shapes for inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the zero
              positions will be unchanged. If a BoolTensor is provided, positions with the
              value of ``True`` will be ignored while positions with the value of ``False`` will be unchanged.
            - attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the
              source sequence length.

              If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
              length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
              the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.

        Shapes for outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: :math:`(N, L, S)` where N is the batch size,
              L is the target sequence length, S is the source sequence length.
            - pruning_norm: a scalar tensor containing the L1 norm of the soft token-pruning mask,
              divided by the number of heads.
        """
        if self.batch_first:
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
        
        if not self._qkv_same_embed_dim:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj_weight,
                k_proj_weight=self.k_proj_weight,
                v_proj_weight=self.v_proj_weight,
            )
        else:
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                self.in_proj_weight,
                self.in_proj_bias,
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout,
                self.out_proj.weight,
                self.out_proj.bias,
                training=self.training,
                key_padding_mask=key_padding_mask,
                need_weights=need_weights,
                attn_mask=attn_mask,
            )

        # attn_output_weights: (N, L, S) -> scores: (N, S), the average attention each
        # token receives over all query positions. NOTE: the pruning mask is computed
        # from the attention weights, so need_weights must be True.
        scores = torch.mean(attn_output_weights, dim=1)
        # Soft pruning mask: tokens whose score falls below the learned threshold are driven toward 0.
        pruning_mask = torch.sigmoid((scores - self.soft_threshold) / self.temperature)
        attn_output = attn_output.transpose(1, 0)  # (L, N, E) -> (N, L, E)
        # Assumes self-attention (L == S) so the (N, S, 1) mask broadcasts over the embedding dim.
        attn_output = pruning_mask[:, :, None] * attn_output

        if self.batch_first:
            return (
                attn_output,
                attn_output_weights,
                torch.sum(torch.norm(pruning_mask, p=1) / self.num_heads),
            )
        else:
            return (
                attn_output.transpose(1, 0),
                attn_output_weights,
                torch.sum(torch.norm(pruning_mask, p=1) / self.num_heads),
            )
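
For orientation, here is a minimal training-step sketch showing how the three return values of LTPMultiHeadAttention.forward fit together. Only the import path and the layer signature come from the source above; the dummy regression loss, the tensor sizes, and the regularization weight lambda_reg are illustrative assumptions.

import torch
import torch.nn.functional as F

from text_embeddings.pruning.ltp import LTPMultiHeadAttention

attention = LTPMultiHeadAttention(temperature=0.1, embed_dim=512, num_heads=8, batch_first=True)
optimizer = torch.optim.Adam(attention.parameters(), lr=1e-4)

x = torch.rand(10, 128, 512)          # (N, L, E) dummy inputs
target = torch.rand(10, 128, 512)     # dummy regression target

output, weights, prune_norm = attention(x, x, x)

# prune_norm is the L1 norm of the soft pruning mask (divided by num_heads).
# Adding it to the task loss with a small weight pushes the learned
# soft_threshold to drive more token scores toward zero, i.e. to prune more.
lambda_reg = 0.01
loss = F.mse_loss(output, target) + lambda_reg * prune_norm
loss.backward()
optimizer.step()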

Classes

class LTPMultiHeadAttention (temperature, embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

Allows the model to jointly attend to information from different representation subspaces as described in the paper Attention Is All You Need (https://arxiv.org/abs/1706.03762).

Multi-Head Attention is defined as:

:math:`\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O` where :math:`\text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`.

forward() will use a special optimized implementation if all of the following conditions are met:

  • self attention is being computed (i.e., query, key, and value are the same tensor. This restriction will be loosened in the future.)
  • Either autograd is disabled (using torch.inference_mode or torch.no_grad) or no tensor argument requires_grad
  • training is disabled (using .eval())
  • dropout is 0
  • add_bias_kv is False
  • add_zero_attn is False
  • batch_first is True and the input is batched
  • kdim and vdim are equal to embed_dim
  • at most one of key_padding_mask or attn_mask is passed
  • if a NestedTensor (https://pytorch.org/docs/stable/nested.html) is passed, neither key_padding_mask nor attn_mask is passed

If the optimized implementation is in use, a NestedTensor can be passed for query/key/value to represent padding more efficiently than using a padding mask. In this case, a NestedTensor will be returned, and an additional speedup proportional to the fraction of the input that is padding can be expected.

Args

temperature
Temperature of the sigmoid that produces the soft token-pruning mask in forward(); smaller values make the mask closer to a hard keep/drop decision. The pruning threshold itself (soft_threshold) is a learned parameter.
embed_dim
Total dimension of the model.
num_heads
Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
dropout
Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
bias
If specified, adds bias to input / output projection layers. Default: True.
add_bias_kv
If specified, adds bias to the key and value sequences at dim=0. Default: False.
add_zero_attn
If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
kdim
Total number of features for keys. Default: None (uses kdim=embed_dim).
vdim
Total number of features for values. Default: None (uses vdim=embed_dim).
batch_first
If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).

Examples

>>> attention = LTPMultiHeadAttention(1, 512, 8, 0.5, batch_first=True)
>>> x = torch.rand((10, 128, 512))
>>> output, weights, norm = attention(x, x, x)
>>> output.shape
torch.Size([10, 128, 512])
>>> weights.shape
torch.Size([10, 128, 128])
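
The effect of temperature follows directly from the mask formula used in forward(), sigmoid((scores - soft_threshold) / temperature): the smaller the temperature, the closer the soft mask gets to a hard keep/drop decision. A standalone sketch with made-up scores and threshold (not values from the source):

import torch

scores = torch.tensor([0.02, 0.05, 0.20])  # made-up per-token importance scores
threshold = 0.10                           # stands in for the learned soft_threshold

for temperature in (1.0, 0.1, 0.01):
    mask = torch.sigmoid((scores - threshold) / temperature)
    print(temperature, mask)

# temperature=1.0 leaves every entry near 0.5, while temperature=0.01 is nearly
# binary: tokens scoring below the threshold go toward 0, tokens above it toward 1.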

Ancestors

  • torch.nn.modules.activation.MultiheadAttention
  • torch.nn.modules.module.Module

Class variables

var bias_k : Union[torch.Tensor, NoneType]
var bias_v : Union[torch.Tensor, NoneType]

Methods

def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None, need_weights: bool = True, attn_mask: Optional[torch.Tensor] = None) ‑> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]

Args

query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask
if provided, specified padding elements in the key will be ignored by the attention. When given a binary mask, positions with a value of True will be ignored by the attention layer; when given a byte mask, non-zero positions will be ignored (see the construction sketch after this argument list).
need_weights
output attn_output_weights.
attn_mask
2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all the batches while a 3D mask allows to specify a different mask for the entries of each batch.
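
For the key_padding_mask argument above, a minimal construction sketch; the sequence lengths are made-up values, not from the source:

import torch

lengths = torch.tensor([5, 3, 7])                # hypothetical per-sequence lengths, batch of 3
max_len = int(lengths.max())                     # sequences padded to length 7
# Boolean mask of shape (N, S); True marks padding positions the attention should ignore.
key_padding_mask = torch.arange(max_len)[None, :] >= lengths[:, None]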

Shapes for inputs:
  • query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
  • key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
  • value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
  • key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions will be unchanged. If a BoolTensor is provided, positions with the value of ``True`` will be ignored while positions with the value of ``False`` will be unchanged.
  • attn_mask: if a 2D mask: :math:`(L, S)` where L is the target sequence length, S is the source sequence length.

  If a 3D mask: :math:`(N\cdot\text{num\_heads}, L, S)` where N is the batch size, L is the target sequence
  length, S is the source sequence length. ``attn_mask`` ensures that position i is allowed to attend
  the unmasked positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
  while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
  are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
  is provided, it will be added to the attention weight.

Shapes for outputs:
  • attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
  • attn_output_weights: :math:`(N, L, S)` where N is the batch size, L is the target sequence length, S is the source sequence length.
  • pruning_norm: a scalar tensor containing the L1 norm of the soft token-pruning mask, divided by the number of heads.
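
A quick check of the documented output contracts above; the tensor sizes here are arbitrary assumptions:

import torch
from text_embeddings.pruning.ltp import LTPMultiHeadAttention

attn = LTPMultiHeadAttention(temperature=0.1, embed_dim=64, num_heads=4, batch_first=True)
x = torch.rand(2, 16, 64)                  # (N, L, E)
out, weights, prune_norm = attn(x, x, x)

assert out.shape == (2, 16, 64)            # (N, L, E) because batch_first=True
assert weights.shape == (2, 16, 16)        # (N, L, S), attention weights averaged over heads
assert prune_norm.dim() == 0               # scalar: L1 norm of the soft pruning mask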
