

class TransformerDecoderLayer(nn.Module):
15- """
15+ """Transformer Decoder layer block in Pre-Norm style.
16+ Pre-Norm style is an improvement w.r.t. Original paper's Post-Norm style,
17+ providing better converge speed and performance. This is also the actual
18+ implementation in tensor2tensor and also avalable in fairseq.
19+ See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
20+
21+ .. mermaid::
22+
23+ graph LR
24+ %% "*SubLayer" can be self-attn, src-attn or feed forward block
25+ A(input) --> B[Norm]
26+ B --> C["*SubLayer"]
27+ C --> D[Drop]
28+ D --> E((+))
29+ A --> E
30+ E --> F(out)
31+
32+
    Args:
-        d_model (int): the dimension of keys/values/queries in
-            :class:`MultiHeadedAttention`, also the input size of
-            the first-layer of the :class:`PositionwiseFeedForward`.
-        heads (int): the number of heads for MultiHeadedAttention.
-        d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
-        dropout (float): dropout probability.
-        self_attn_type (string): type of self-attention scaled-dot, average
+        d_model (int): the dimension of keys/values/queries in
+            :class:`MultiHeadedAttention`, also the input size of
+            the first-layer of the :class:`PositionwiseFeedForward`.
+        heads (int): the number of heads for MultiHeadedAttention.
+        d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
+        dropout (float): dropout in residual, self-attn(dot) and feed-forward
+        attention_dropout (float): dropout in context_attn (and self-attn(avg))
+        self_attn_type (string): type of self-attention: scaled-dot or average
+        max_relative_positions (int):
+            Max distance between inputs in relative positions representations
+        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+        full_context_alignment (bool):
+            whether to enable an extra full-context decoder forward for alignment
+        alignment_heads (int):
+            N. of cross attention heads to use for alignment guiding
2449 """
2550
    def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
                 self_attn_type="scaled-dot", max_relative_positions=0,
                 aan_useffn=False, full_context_alignment=False,
-                alignment_heads=None):
+                alignment_heads=0):
        super(TransformerDecoderLayer, self).__init__()

        if self_attn_type == "scaled-dot":
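For reference, the Pre-Norm residual wiring shown in the mermaid graph above amounts to `out = x + Drop(SubLayer(Norm(x)))`. Below is a minimal sketch of that pattern, assuming a generic sublayer module; the `PreNormResidual` name and the toy sizes are illustrative and not part of this diff:

```python
import torch
import torch.nn as nn


class PreNormResidual(nn.Module):
    """Illustrative Pre-Norm residual wrapper: out = x + Drop(SubLayer(Norm(x)))."""

    def __init__(self, d_model, sublayer, dropout=0.1):
        super(PreNormResidual, self).__init__()
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
        self.sublayer = sublayer      # e.g. self-attn, context-attn or FFN block
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # Norm -> *SubLayer -> Drop -> residual add, as in the graph above
        return x + self.drop(self.sublayer(self.norm(x)))


# toy usage: a feed-forward "sublayer" on a (batch, len, d_model) tensor
block = PreNormResidual(512, nn.Linear(512, 512))
out = block(torch.rand(2, 5, 512))   # same shape as the input
```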
@@ -48,10 +73,10 @@ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
        self.alignment_heads = alignment_heads

    def forward(self, *args, **kwargs):
51- """ Extend _forward for (possibly) multiple decoder pass:
52- 1. Always a default (future masked) decoder forward pass,
53- 2. Possibly a second future aware decoder pass for joint learn
54- full context alignement.
76+ """ Extend ` _forward` for (possibly) multiple decoder pass:
77+ Always a default (future masked) decoder forward pass,
78+ Possibly a second future aware decoder pass for joint learn
79+ full context alignement, :cite:`garg2019jointly` .

        Args:
            * All arguments of _forward.
@@ -60,9 +85,9 @@ def forward(self, *args, **kwargs):
        Returns:
            (FloatTensor, FloatTensor, FloatTensor or None):

-            * output ``(batch_size, 1, model_dim)``
-            * top_attn ``(batch_size, 1, src_len)``
-            * attn_align ``(batch_size, 1, src_len)`` or None
+            * output ``(batch_size, T, model_dim)``
+            * top_attn ``(batch_size, T, src_len)``
+            * attn_align ``(batch_size, T, src_len)`` or None
        """
        with_align = kwargs.pop('with_align', False)
        output, attns = self._forward(*args, **kwargs)
@@ -73,7 +98,7 @@ def forward(self, *args, **kwargs):
                # return _, (B, Q_len, K_len)
                _, attns = self._forward(*args, **kwargs, future=True)

-            if self.alignment_heads is not None:
+            if self.alignment_heads > 0:
                attns = attns[:, :self.alignment_heads, :, :].contiguous()
            # layer average attention across heads, get ``(B, Q, K)``
            # Case 1: no full_context, no align heads -> layer avg baseline
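The switch from `is not None` to `> 0` makes `alignment_heads=0` mean "average over all heads". A small sketch of the head slicing and layer-averaging described by the comments above, with made-up tensor sizes:

```python
import torch

# cross-attention weights: (batch, heads, tgt_len, src_len)
attns = torch.rand(2, 8, 5, 7)
alignment_heads = 2          # 0 would fall through to averaging all heads

if alignment_heads > 0:
    # keep only the first N heads reserved for alignment guiding
    attns = attns[:, :alignment_heads, :, :].contiguous()

# average across the (possibly sliced) head dimension -> (B, Q, K)
attn_align = attns.mean(dim=1)
print(attn_align.shape)      # torch.Size([2, 5, 7])
```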
@@ -85,18 +110,23 @@ def forward(self, *args, **kwargs):
    def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
                 layer_cache=None, step=None, future=False):
        """A naive forward pass for transformer decoder.
-        # TODO: change 1 to T as T could be 1 or tgt_len
+
+        # T: could be 1 in the case of stepwise decoding or tgt_len
+
        Args:
-            inputs (FloatTensor): ``(batch_size, 1, model_dim)``
+            inputs (FloatTensor): ``(batch_size, T, model_dim)``
            memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
            src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
-            tgt_pad_mask (LongTensor): ``(batch_size, 1, 1)``
+            tgt_pad_mask (LongTensor): ``(batch_size, 1, T)``
+            layer_cache (dict or None): cached layer info for stepwise decoding
+            step (int or None): stepwise decoding counter
+            future (bool): If set True, do not apply future_mask.

        Returns:
            (FloatTensor, FloatTensor):

-            * output ``(batch_size, 1, model_dim)``
-            * attns ``(batch_size, head, 1, src_len)``
+            * output ``(batch_size, T, model_dim)``
+            * attns ``(batch_size, head, T, src_len)``

        """
        dec_mask = None
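The docstring above implies that `dec_mask` combines the target padding mask with a causal (future) mask unless `future=True`. Here is a hedged sketch of how such a mask could be assembled for the full-sequence case (`step is None`); the boolean dtype and the helper name are assumptions for illustration only:

```python
import torch

def make_dec_mask(tgt_pad_mask, future=False):
    """Illustrative: combine a padding mask (B, 1, T) with a causal mask (T, T)."""
    tgt_len = tgt_pad_mask.size(-1)
    if future:
        # future-aware pass: only padded positions are masked
        return tgt_pad_mask
    # upper-triangular mask: position i must not attend to positions j > i
    future_mask = torch.triu(
        torch.ones(tgt_len, tgt_len, dtype=torch.bool,
                   device=tgt_pad_mask.device), diagonal=1)
    # broadcast to (B, T, T): a position is masked if padded OR in the future
    return tgt_pad_mask | future_mask.unsqueeze(0)

pad = torch.zeros(2, 1, 4, dtype=torch.bool)   # toy batch with no padding
print(make_dec_mask(pad).shape)                # torch.Size([2, 4, 4])
```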
@@ -166,22 +196,31 @@ class TransformerDecoder(DecoderBase):


    Args:
-        num_layers (int): number of encoder layers.
-        d_model (int): size of the model
-        heads (int): number of heads
-        d_ff (int): size of the inner FF layer
-        copy_attn (bool): if using a separate copy attention
-        self_attn_type (str): type of self-attention scaled-dot, average
-        dropout (float): dropout parameters
-        embeddings (onmt.modules.Embeddings):
-            embeddings to use, should have positional encodings
+        num_layers (int): number of decoder layers.
+        d_model (int): size of the model
+        heads (int): number of heads
+        d_ff (int): size of the inner FF layer
+        copy_attn (bool): if using a separate copy attention
+        self_attn_type (str): type of self-attention: scaled-dot or average
+        dropout (float): dropout in residual, self-attn(dot) and feed-forward
+        attention_dropout (float): dropout in context_attn (and self-attn(avg))
+        embeddings (onmt.modules.Embeddings):
+            embeddings to use, should have positional encodings
+        max_relative_positions (int):
+            Max distance between inputs in relative positions representations
+        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+        full_context_alignment (bool):
+            whether to enable an extra full-context decoder forward for alignment
+        alignment_layer (int): index of the layer to supervise for alignment guiding
+        alignment_heads (int):
+            N. of cross attention heads to use for alignment guiding
178217 """
179218
    def __init__(self, num_layers, d_model, heads, d_ff,
                 copy_attn, self_attn_type, dropout, attention_dropout,
                 embeddings, max_relative_positions, aan_useffn,
                 full_context_alignment, alignment_layer,
-                alignment_heads=None):
+                alignment_heads):
        super(TransformerDecoder, self).__init__()

        self.embeddings = embeddings
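Given the updated signature, a layer might be constructed as below; all hyper-parameter values are placeholders, only the argument names come from this diff:

```python
# hypothetical instantiation of the updated layer (values are illustrative)
layer = TransformerDecoderLayer(
    d_model=512, heads=8, d_ff=2048,
    dropout=0.1, attention_dropout=0.1,
    self_attn_type="scaled-dot",
    max_relative_positions=0,
    aan_useffn=False,
    full_context_alignment=True,  # run the extra future-aware pass
    alignment_heads=1,            # supervise alignment with 1 cross-attn head
)
```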