
Commit f44d1d2

Zenglinxiao authored and francoishernandez committed
More documentation for transformer decoder, default alignment_heads (#1692)
1 parent 9a4f7a4 commit f44d1d2

3 files changed

Lines changed: 92 additions & 33 deletions


docs/source/refs.bib

Lines changed: 20 additions & 0 deletions
@@ -445,3 +445,23 @@ @inproceedings{garg2019jointly
     url = {https://arxiv.org/abs/1909.02074},
     year = {2019},
 }
+
+@inproceedings{DeeperTransformer,
+  title = "Learning Deep Transformer Models for Machine Translation",
+  author = "Wang, Qiang and
+    Li, Bei and
+    Xiao, Tong and
+    Zhu, Jingbo and
+    Li, Changliang and
+    Wong, Derek F. and
+    Chao, Lidia S.",
+  booktitle = "Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics",
+  month = jul,
+  year = "2019",
+  address = "Florence, Italy",
+  publisher = "Association for Computational Linguistics",
+  url = "https://www.aclweb.org/anthology/P19-1176",
+  doi = "10.18653/v1/P19-1176",
+  pages = "1810--1822",
+  abstract = "Transformer is the state-of-the-art model in recent machine translation evaluations. Two strands of research are promising to improve models of this kind: the first uses wide networks (a.k.a. Transformer-Big) and has been the de facto standard for development of the Transformer system, and the other uses deeper language representation but faces the difficulty arising from learning deep networks. Here, we continue the line of research on the latter. We claim that a truly deep Transformer model can surpass the Transformer-Big counterpart by 1) proper use of layer normalization and 2) a novel way of passing the combination of previous layers to the next. On WMT{'}16 English-German and NIST OpenMT{'}12 Chinese-English tasks, our deep system (30/25-layer encoder) outperforms the shallow Transformer-Big/Base baseline (6-layer encoder) by 0.4-2.4 BLEU points. As another bonus, the deep model is 1.6X smaller in size and 3X faster in training than Transformer-Big.",
+}

onmt/decoders/transformer.py

Lines changed: 71 additions & 32 deletions
@@ -12,21 +12,46 @@
 
 
 class TransformerDecoderLayer(nn.Module):
-    """
+    """Transformer Decoder layer block in Pre-Norm style.
+    Pre-Norm style is an improvement w.r.t. the original paper's Post-Norm
+    style, providing better convergence speed and performance. This is also
+    the actual implementation in tensor2tensor and is also available in
+    fairseq. See https://tunz.kr/post/4 and :cite:`DeeperTransformer`.
+
+    .. mermaid::
+
+        graph LR
+        %% "*SubLayer" can be self-attn, src-attn or feed forward block
+            A(input) --> B[Norm]
+            B --> C["*SubLayer"]
+            C --> D[Drop]
+            D --> E((+))
+            A --> E
+            E --> F(out)
+
+
     Args:
-        d_model (int): the dimension of keys/values/queries in
-            :class:`MultiHeadedAttention`, also the input size of
-            the first-layer of the :class:`PositionwiseFeedForward`.
-        heads (int): the number of heads for MultiHeadedAttention.
-        d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
-        dropout (float): dropout probability.
-        self_attn_type (string): type of self-attention scaled-dot, average
+        d_model (int): the dimension of keys/values/queries in
+            :class:`MultiHeadedAttention`, also the input size of
+            the first-layer of the :class:`PositionwiseFeedForward`.
+        heads (int): the number of heads for MultiHeadedAttention.
+        d_ff (int): the second-layer of the :class:`PositionwiseFeedForward`.
+        dropout (float): dropout in residual, self-attn(dot) and feed-forward
+        attention_dropout (float): dropout in context_attn (and self-attn(avg))
+        self_attn_type (string): type of self-attention scaled-dot, average
+        max_relative_positions (int):
+            Max distance between inputs in relative positions representations
+        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+        full_context_alignment (bool):
+            whether to enable an extra full context decoder forward for alignment
+        alignment_heads (int):
+            N. of cross attention heads to use for alignment guiding
     """
 
     def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
                  self_attn_type="scaled-dot", max_relative_positions=0,
                  aan_useffn=False, full_context_alignment=False,
-                 alignment_heads=None):
+                 alignment_heads=0):
         super(TransformerDecoderLayer, self).__init__()
 
         if self_attn_type == "scaled-dot":
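For readers skimming the new docstring, the following is a minimal, self-contained sketch of the Pre-Norm residual pattern that the mermaid graph describes (Norm -> *SubLayer -> Drop -> residual add). The class name PreNormResidual and its arguments are illustrative only and not part of OpenNMT-py; in Post-Norm, by contrast, the LayerNorm is applied after the residual addition.

import torch.nn as nn

class PreNormResidual(nn.Module):
    """Pre-Norm wrapping: out = x + Drop(SubLayer(Norm(x))).

    `sublayer` stands for any block in the graph above
    (self-attn, src-attn or the feed-forward block).
    """
    def __init__(self, d_model, sublayer, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
        self.sublayer = sublayer
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        # normalize first (Pre-Norm), run the sub-layer, apply dropout,
        # then add the untouched input back as the residual connection
        return x + self.drop(self.sublayer(self.norm(x)))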
@@ -48,10 +73,10 @@ def __init__(self, d_model, heads, d_ff, dropout, attention_dropout,
         self.alignment_heads = alignment_heads
 
     def forward(self, *args, **kwargs):
-        """ Extend _forward for (possibly) multiple decoder pass:
-        1. Always a default (future masked) decoder forward pass,
-        2. Possibly a second future aware decoder pass for joint learn
-           full context alignement.
+        """ Extend `_forward` for (possibly) multiple decoder passes:
+        Always a default (future masked) decoder forward pass;
+        possibly a second future-aware decoder pass for jointly learning
+           full context alignment, :cite:`garg2019jointly`.
 
         Args:
             * All arguments of _forward.
@@ -60,9 +85,9 @@ def forward(self, *args, **kwargs):
         Returns:
             (FloatTensor, FloatTensor, FloatTensor or None):
 
-            * output ``(batch_size, 1, model_dim)``
-            * top_attn ``(batch_size, 1, src_len)``
-            * attn_align ``(batch_size, 1, src_len)`` or None
+            * output ``(batch_size, T, model_dim)``
+            * top_attn ``(batch_size, T, src_len)``
+            * attn_align ``(batch_size, T, src_len)`` or None
         """
         with_align = kwargs.pop('with_align', False)
         output, attns = self._forward(*args, **kwargs)
@@ -73,7 +98,7 @@ def forward(self, *args, **kwargs):
             # return _, (B, Q_len, K_len)
             _, attns = self._forward(*args, **kwargs, future=True)
 
-            if self.alignment_heads is not None:
+            if self.alignment_heads > 0:
                 attns = attns[:, :self.alignment_heads, :, :].contiguous()
             # layer average attention across heads, get ``(B, Q, K)``
             # Case 1: no full_context, no align heads -> layer avg baseline
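As a rough illustration of what the new default means: with alignment_heads=0 the guided-alignment attention is the average over all cross-attention heads of the chosen layer, while alignment_heads=N > 0 first keeps only the first N heads, as in the branch above. A hedged sketch (function name hypothetical, not the OpenNMT-py API):

import torch

def heads_to_alignment(attns, alignment_heads=0):
    # attns: (batch, heads, tgt_len, src_len) cross-attention weights from
    # the (possibly future-aware) decoder pass
    if alignment_heads > 0:
        # keep only the first N heads for the alignment supervision
        attns = attns[:, :alignment_heads, :, :].contiguous()
    # average across the remaining heads -> (batch, tgt_len, src_len)
    return attns.mean(dim=1)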
@@ -85,18 +110,23 @@ def forward(self, *args, **kwargs):
     def _forward(self, inputs, memory_bank, src_pad_mask, tgt_pad_mask,
                  layer_cache=None, step=None, future=False):
         """ A naive forward pass for transformer decoder.
-        # TODO: change 1 to T as T could be 1 or tgt_len
+
+        # T: could be 1 in the case of stepwise decoding or tgt_len
+
         Args:
-            inputs (FloatTensor): ``(batch_size, 1, model_dim)``
+            inputs (FloatTensor): ``(batch_size, T, model_dim)``
             memory_bank (FloatTensor): ``(batch_size, src_len, model_dim)``
             src_pad_mask (LongTensor): ``(batch_size, 1, src_len)``
-            tgt_pad_mask (LongTensor): ``(batch_size, 1, 1)``
+            tgt_pad_mask (LongTensor): ``(batch_size, 1, T)``
+            layer_cache (dict or None): cached layer info for stepwise decoding
+            step (int or None): stepwise decoding counter
+            future (bool): If set True, do not apply future_mask.
 
         Returns:
             (FloatTensor, FloatTensor):
 
-            * output ``(batch_size, 1, model_dim)``
-            * attns ``(batch_size, head, 1, src_len)``
+            * output ``(batch_size, T, model_dim)``
+            * attns ``(batch_size, head, T, src_len)``
 
         """
         dec_mask = None
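The future flag above controls whether the usual subsequent-position mask is stacked on top of the target padding mask; in the extra alignment pass it is skipped so the decoder attends over the full target context. A minimal sketch of that masking logic (helper name hypothetical; the library's own code may differ in detail):

import torch

def build_dec_mask(tgt_pad_mask, future=False):
    # tgt_pad_mask: bool (batch_size, 1, T), True at padding positions
    tgt_len = tgt_pad_mask.size(-1)
    if future:
        # future-aware pass: only padding is masked
        return tgt_pad_mask.expand(-1, tgt_len, -1)
    # default pass: also block attention to positions after the query
    future_mask = torch.triu(
        torch.ones(tgt_len, tgt_len, device=tgt_pad_mask.device),
        diagonal=1).bool()
    return tgt_pad_mask | future_mask.unsqueeze(0)  # (batch_size, T, T)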
@@ -166,22 +196,31 @@ class TransformerDecoder(DecoderBase):
 
 
     Args:
-        num_layers (int): number of encoder layers.
-        d_model (int): size of the model
-        heads (int): number of heads
-        d_ff (int): size of the inner FF layer
-        copy_attn (bool): if using a separate copy attention
-        self_attn_type (str): type of self-attention scaled-dot, average
-        dropout (float): dropout parameters
-        embeddings (onmt.modules.Embeddings):
-            embeddings to use, should have positional encodings
+        num_layers (int): number of decoder layers.
+        d_model (int): size of the model
+        heads (int): number of heads
+        d_ff (int): size of the inner FF layer
+        copy_attn (bool): if using a separate copy attention
+        self_attn_type (str): type of self-attention scaled-dot, average
+        dropout (float): dropout in residual, self-attn(dot) and feed-forward
+        attention_dropout (float): dropout in context_attn (and self-attn(avg))
+        embeddings (onmt.modules.Embeddings):
+            embeddings to use, should have positional encodings
+        max_relative_positions (int):
+            Max distance between inputs in relative positions representations
+        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+        full_context_alignment (bool):
+            whether to enable an extra full context decoder forward for alignment
+        alignment_layer (int): N° of layer to supervise for alignment guiding
+        alignment_heads (int):
+            N. of cross attention heads to use for alignment guiding
     """
 
     def __init__(self, num_layers, d_model, heads, d_ff,
                  copy_attn, self_attn_type, dropout, attention_dropout,
                  embeddings, max_relative_positions, aan_useffn,
                  full_context_alignment, alignment_layer,
-                 alignment_heads=None):
+                 alignment_heads):
         super(TransformerDecoder, self).__init__()
 
         self.embeddings = embeddings
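To tie the two docstrings together, here is a hedged usage sketch of TransformerDecoderLayer with the documented arguments; the hyper-parameter values are illustrative, not recommended settings. Note that TransformerDecoder itself now takes alignment_heads as a required argument rather than a keyword default.

from onmt.decoders.transformer import TransformerDecoderLayer

layer = TransformerDecoderLayer(
    d_model=512, heads=8, d_ff=2048,
    dropout=0.1, attention_dropout=0.1,
    self_attn_type="scaled-dot", max_relative_positions=0,
    aan_useffn=False, full_context_alignment=False,
    alignment_heads=0)  # 0 (the new default) averages all cross-attention heads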

onmt/opts.py

Lines changed: 1 addition & 1 deletion
@@ -162,7 +162,7 @@ def model_opts(parser):
               "https://arxiv.org/abs/1909.02074")
     group.add('--alignment_layer', '-alignment_layer', type=int, default=-3,
               help='Layer number which has to be supervised.')
-    group.add('--alignment_heads', '-alignment_heads', type=int, default=None,
+    group.add('--alignment_heads', '-alignment_heads', type=int, default=0,
               help='N. of cross attention heads per layer to supervised with')
     group.add('--full_context_alignment', '-full_context_alignment',
               action="store_true",
