@@ -422,43 +422,6 @@ def __call__(self, x: jnp.ndarray):
         return x_norm
 
 
-class PrepareForMultiHeadAttention(Module):
-    """
-    <a id="PrepareMHA"></a>
-
-    ## Prepare for multi-head attention
-
-    This module does a linear transformation and splits the vector into the given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        # Linear layer for linear transform
-        self.linear = Linear(rnd_key, d_model, heads * d_k)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: jnp.ndarray):
-        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-        # We apply the linear transformation to the last dimension and split that into
-        # the heads.
-        head_shape = x.shape[:-1]
-
-        # Linear transform
-        x = self.linear(x)
-
-        # Split last dimension into heads
-
-        x = x.reshape(*head_shape, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
-        return x
-
-
 class MultiHeadAttention(Module):
     r"""
     <a id="MHA"></a>
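For reference, the removed `PrepareForMultiHeadAttention` boils down to a single linear transform followed by a reshape that splits the projection into heads. A minimal sketch of that behaviour outside the `Module` framework (the toy dimensions and the bare weight matrix `w` are assumptions for illustration only):

```python
import jax
import jax.numpy as jnp

# Toy dimensions, assumed for illustration
seq_len, d_model, heads = 10, 64, 8
d_k = d_model // heads

rng = jax.random.PRNGKey(0)
k_x, k_w = jax.random.split(rng)
x = jax.random.normal(k_x, (seq_len, d_model))
w = jax.random.normal(k_w, (d_model, heads * d_k))

# Linear transform, then split the last dimension into heads
h = (x @ w).reshape(*x.shape[:-1], heads, d_k)
print(h.shape)  # (10, 8, 8)
```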
@@ -503,9 +466,9 @@ def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
+        self.query = Linear(rnd_keys[0], d_model, d_model)
+        self.key = Linear(rnd_keys[1], d_model, d_model)
+        self.value = Linear(rnd_keys[2], d_model, d_model)
 
         # Output layer
         self.output = Linear(rnd_keys[3], d_model, d_model)
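This replacement only works because `d_k` is defined as `d_model // heads`, so `heads * d_k == d_model` and `Linear(d_model, heads * d_k)` and `Linear(d_model, d_model)` create identically shaped weights. A quick sanity check of that invariant (plain Python, values assumed for illustration):

```python
d_model, heads = 512, 8
d_k = d_model // heads

# The refactor relies on the head dimensions tiling d_model exactly;
# if this failed, the reshape in `__call__` below would error out.
assert heads * d_k == d_model
```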
@@ -537,12 +500,18 @@ def __call__(self, *,
             # Same mask applied to all heads.
             mask = mask[:, :, None]
 
-        # Prepare `query`, `key` and `value` for attention computation.
-        # These will then have shape `[seq_len, heads, d_k]`.
+        # Apply the linear transformations.
         query = self.query(query)
         key = self.key(key)
         value = self.value(value)
 
+        # Reshape to split the last dimension into heads.
+        # Input has shape `[seq_len, d_model]`; we split the last
+        # dimension into `heads` and `d_k`, giving `[seq_len, heads, d_k]`.
+        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
+        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
+        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
+
         # Compute attention scores $Q K^\top$.
         # This gives a tensor of shape `[seq_len, seq_len, heads]`.
         # $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$
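The score computation named in the trailing context lines maps directly onto `jnp.einsum` with the `[seq_len, heads, d_k]` shapes established above (note there is no batch dimension here). A minimal sketch of that step, with the conventional $1/\sqrt{d_k}$ scaling; the dimensions are assumptions for illustration:

```python
import jax
import jax.numpy as jnp

seq_len, heads, d_k = 10, 8, 8
k_q, k_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(k_q, (seq_len, heads, d_k))
k = jax.random.normal(k_k, (seq_len, heads, d_k))

# S_{ijh} = sum_d Q_{ihd} K_{jhd}, giving shape [seq_len, seq_len, heads]
scores = jnp.einsum('ihd,jhd->ijh', q, k) / jnp.sqrt(d_k)
print(scores.shape)  # (10, 10, 8)
```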
@@ -1038,4 +1007,4 @@ def get_loss(params, seq):
 
 #
 if __name__ == '__main__':
-    main()
+    main()