Commit f346824: cleanup jax

1 parent 96f7b5a
1 file changed: 12 additions & 43 deletions

File tree

labml_nn/transformers/jax_transformer/__init__.py

@@ -422,43 +422,6 @@ def __call__(self, x: jnp.ndarray):
         return x_norm
 
 
-class PrepareForMultiHeadAttention(Module):
-    """
-    <a id="PrepareMHA"></a>
-
-    ## Prepare for multi-head attention
-
-    This module does a linear transformation and splits the vector into the given
-    number of heads for multi-head attention.
-    This is used to transform **key**, **query**, and **value** vectors.
-    """
-
-    def __init__(self, rnd_key: jax.random.PRNGKey, d_model: int, heads: int, d_k: int):
-        super().__init__()
-        # Linear layer for the linear transform
-        self.linear = Linear(rnd_key, d_model, heads * d_k)
-        # Number of heads
-        self.heads = heads
-        # Number of dimensions in vectors in each head
-        self.d_k = d_k
-
-    def __call__(self, x: jnp.ndarray):
-        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
-        # We apply the linear transformation to the last dimension and split that into
-        # the heads.
-        head_shape = x.shape[:-1]
-
-        # Linear transform
-        x = self.linear(x)
-
-        # Split last dimension into heads
-
-        x = x.reshape(*head_shape, self.heads, self.d_k)
-
-        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_k]`
-        return x
-
-
 class MultiHeadAttention(Module):
     r"""
     <a id="MHA"></a>
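The deleted class did nothing beyond a single linear transform followed by a head-splitting reshape, which is why it can be inlined at its call sites in `MultiHeadAttention` below. A minimal standalone sketch of that behavior, using plain `jnp` instead of this repo's `Module`/`Linear` classes and made-up sizes:

```python
import jax
import jax.numpy as jnp

# Made-up sizes for illustration only.
seq_len, batch_size, d_model, heads = 5, 2, 16, 4
d_k = d_model // heads

w_key, x_key = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(w_key, (d_model, heads * d_k))  # the d_model -> heads * d_k weight
x = jax.random.normal(x_key, (seq_len, batch_size, d_model))

# Everything PrepareForMultiHeadAttention did: linear transform,
# then split the last dimension into (heads, d_k).
prepared = (x @ w).reshape(*x.shape[:-1], heads, d_k)
assert prepared.shape == (seq_len, batch_size, heads, d_k)
```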
@@ -503,9 +466,9 @@ def __init__(self, rnd_key: jax.random.PRNGKey, heads: int, d_model: int):
         self.heads = heads
 
         # These transform the `query`, `key` and `value` vectors for multi-headed attention.
-        self.query = PrepareForMultiHeadAttention(rnd_keys[0], d_model, heads, self.d_k)
-        self.key = PrepareForMultiHeadAttention(rnd_keys[1], d_model, heads, self.d_k)
-        self.value = PrepareForMultiHeadAttention(rnd_keys[2], d_model, heads, self.d_k)
+        self.query = Linear(rnd_keys[0], d_model, d_model)
+        self.key = Linear(rnd_keys[1], d_model, d_model)
+        self.value = Linear(rnd_keys[2], d_model, d_model)
 
         # Output layer
         self.output = Linear(rnd_keys[3], d_model, d_model)
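Note that the replacement relies on `heads * self.d_k == d_model`, so a single `d_model -> d_model` linear produces exactly the concatenated per-head features; `self.d_k` is presumably set to `d_model // heads` in this constructor (the standard choice), which makes the two formulations equivalent:

```python
# Illustrative sizes; d_k as presumably defined in this constructor.
d_model, heads = 512, 8
d_k = d_model // heads
# Linear(d_model, d_model) has the same output width as
# PrepareForMultiHeadAttention's Linear(d_model, heads * d_k).
assert heads * d_k == d_model
```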
@@ -537,12 +500,18 @@ def __call__(self, *,
             # Same mask applied to all heads.
             mask = mask[:, :, None]
 
-        # Prepare `query`, `key` and `value` for attention computation.
-        # These will then have shape `[seq_len, heads, d_k]`.
+        # Apply linear transformations
         query = self.query(query)
         key = self.key(key)
         value = self.value(value)
 
+        # Reshape to split into heads
+        # Input has shape `[seq_len, batch_size, d_model]`.
+        # We split the last dimension into `heads` and `d_k`.
+        query = query.reshape(*query.shape[:-1], self.heads, self.d_k)
+        key = key.reshape(*key.shape[:-1], self.heads, self.d_k)
+        value = value.reshape(*value.shape[:-1], self.heads, self.d_k)
+
         # Compute attention scores $Q K^\top$.
         # This gives a tensor of shape `[seq_len, seq_len, heads]`.
         # $$S_{ijh} = \sum_d Q_{ihd} K_{jhd}$$
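As a side note on the score computation in the context lines above: a minimal standalone sketch of $S_{ijh} = \sum_d Q_{ihd} K_{jhd}$ with `jnp.einsum`, using the `[seq_len, heads, d_k]` shape named in the removed comment and made-up sizes:

```python
import jax
import jax.numpy as jnp

seq_len, heads, d_k = 7, 4, 16
q_key, k_key = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(q_key, (seq_len, heads, d_k))
k = jax.random.normal(k_key, (seq_len, heads, d_k))

# S_ijh = sum_d Q_ihd * K_jhd: per-head dot products between all position pairs.
scores = jnp.einsum('ihd,jhd->ijh', q, k)
assert scores.shape == (seq_len, seq_len, heads)
```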
@@ -1038,4 +1007,4 @@ def get_loss(params, seq):
 
 #
 if __name__ == '__main__':
-    main()
+    main()
