Implement Gated Attention

#271 · Deep Learning · Medium

Problem

Implement Gated Attention, an attention mechanism in which a learned gate modulates the attention output. The gate is a value in (0, 1) that controls how much of the attended information flows through.

Solution

Compute standard scaled dot-product attention for the query, then multiply the resulting context vector by a sigmoid gate computed from a linear projection of the query.
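Written out for a single query, the computation is as follows. The symbols are introduced here for readability and map onto the code's arguments as q = query, k_i = keys[i], v_i = values[i], w_g = gate_weights, b_g = gate_bias, with d = len(query):

```latex
\begin{aligned}
s_i &= \frac{q \cdot k_i}{\sqrt{d}}, \qquad
\alpha_i = \frac{\exp(s_i)}{\sum_{j=1}^{n} \exp(s_j)}, \\
c &= \sum_{i=1}^{n} \alpha_i \, v_i, \\
g &= \sigma(w_g \cdot q + b_g) = \frac{1}{1 + \exp\!\big(-(w_g \cdot q + b_g)\big)}, \\
o &= g \, c .
\end{aligned}
```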

import math

def softmax(values: list[float]) -> list[float]:
    max_v = max(values)
    exps = [math.exp(v - max_v) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def gated_attention(
    query: list[float],
    keys: list[list[float]],
    values: list[list[float]],
    gate_weights: list[float],
    gate_bias: float = 0.0,
) -> dict:
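    dim = len(query)
    n = len(keys)
    # Validate shapes up front so mismatches fail loudly instead of mid-loop
    if n == 0 or len(values) != n:
        raise ValueError("keys and values must be non-empty and the same length")
    if len(gate_weights) != dim:
        raise ValueError("gate_weights must have the same length as query")
    scale = math.sqrt(dim)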

    # Compute attention scores
    scores = []
    for i in range(n):
        dot = sum(query[d] * keys[i][d] for d in range(dim))
        scores.append(dot / scale)

    # Softmax to get attention weights
    attn_weights = softmax(scores)

    # Compute attended value (context vector)
    v_dim = len(values[0])
    context = [0.0] * v_dim
    for i in range(n):
        for d in range(v_dim):
            context[d] += attn_weights[i] * values[i][d]

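    # Compute gate: sigmoid(W_g . query + b_g), evaluated in a numerically
    # stable form so math.exp never overflows for large |gate_input|
    gate_input = sum(gate_weights[d] * query[d] for d in range(dim)) + gate_bias
    if gate_input >= 0:
        gate = 1.0 / (1.0 + math.exp(-gate_input))
    else:
        z = math.exp(gate_input)
        gate = z / (1.0 + z)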

    # Apply gate to context
    gated_output = [gate * context[d] for d in range(v_dim)]

    return {
        "attention_weights": [round(w, 6) for w in attn_weights],
        "context": [round(c, 6) for c in context],
        "gate": round(gate, 6),
        "output": [round(o, 6) for o in gated_output],
    }
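The function above handles one query at a time. A minimal sketch of extending it to several queries follows, assuming each query row gets its own scalar gate from the same gate weights and bias. The helper names `batched_gated_attention` and `_stable_sigmoid` are hypothetical, the batching is an illustration rather than part of the original problem, and the output is left unrounded.

```python
import math

def _stable_sigmoid(x: float) -> float:
    # Branch on the sign so math.exp never receives a large positive argument
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def batched_gated_attention(
    queries: list[list[float]],
    keys: list[list[float]],
    values: list[list[float]],
    gate_weights: list[float],
    gate_bias: float = 0.0,
) -> list[list[float]]:
    """Hypothetical helper: applies single-query gated attention to each row of `queries`."""
    v_dim = len(values[0])
    outputs = []
    for q in queries:
        scale = math.sqrt(len(q))
        # Scaled dot-product scores, then a max-shifted softmax
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        m = max(scores)
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]
        # Context: attention-weighted sum of the value vectors
        context = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(v_dim)]
        # One scalar gate per query, scaling every context component
        gate = _stable_sigmoid(sum(w * x for w, x in zip(gate_weights, q)) + gate_bias)
        outputs.append([gate * c for c in context])
    return outputs
```

Each output row should match what `gated_attention` returns for that query, up to floating-point rounding and the final `round(..., 6)`.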

Explanation

  1. Compute the scaled dot-product scores, score_i = (q · k_i) / sqrt(d), and normalize them with a softmax to get the attention weights.
  2. Compute the context vector as the attention-weighted sum of the value vectors.
  3. Compute a scalar gate in (0, 1) by passing a linear projection of the query, w_g · q + b_g, through a sigmoid.
  4. Scale every component of the context vector by the gate to produce the final output.
  5. Because w_g and b_g are learned, the model can learn, per query, when to pass the attended context through (gate near 1) and when to suppress it (gate near 0).
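To make step 5 concrete, the sketch below runs steps 1 to 4 on small made-up inputs and then varies only the gate bias. The attention weights and context stay fixed while the gate scales the output toward zero or toward the full context. The toy inputs and the bias sweep (-4, 0, +4) are illustrative assumptions, not values from the problem.

```python
import math

def sigmoid(x: float) -> float:
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

# Toy inputs (made up for illustration): one 2-d query, two keys, two 2-d values
query = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
gate_weights = [0.5, 0.5]

# Steps 1-2: scaled scores, softmax weights, and the context vector
scale = math.sqrt(len(query))
scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
exps = [math.exp(s - max(scores)) for s in scores]
weights = [e / sum(exps) for e in exps]
context = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

# Steps 3-5: the same context passed through gates with different biases
results = {}
for gate_bias in (-4.0, 0.0, 4.0):
    gate = sigmoid(sum(w * q for w, q in zip(gate_weights, query)) + gate_bias)
    results[gate_bias] = (gate, [gate * c for c in context])
    print(f"bias={gate_bias:+.1f}  gate={gate:.4f}  output={[round(o, 4) for o in results[gate_bias][1]]}")
```

Only the gate changes across the three runs, so the printed outputs are the same context vector shrunk by different factors.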

Complexity

  • Time: O(n * (d + d_v)), where n is the number of keys, d is the query/key dimension and d_v is the value dimension; this is O(n * d) when d_v = d
  • Space: O(n + d_v) for the scores, attention weights and context vector
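A rough per-step count for one query shows where the time bound comes from. Here d_v is a symbol introduced for the value dimension (the code's v_dim):

```latex
\underbrace{O(n d)}_{\text{scores}}
+ \underbrace{O(n)}_{\text{softmax}}
+ \underbrace{O(n d_v)}_{\text{context}}
+ \underbrace{O(d)}_{\text{gate}}
+ \underbrace{O(d_v)}_{\text{gating}}
= O\big(n (d + d_v)\big)
```

The extra memory is n scores and n weights plus d_v context entries and d_v outputs, i.e. O(n + d_v); with d_v = d this is O(n * d) time and O(n + d) space.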