EAGLE-Style Draft Model from Hidden States

#431 · Deep Learning · Medium

Source: deep-ml.com

Problem

Implement an EAGLE-style draft model that predicts the next token from the target model's hidden states rather than from token embeddings. Given the target model's hidden state at the current position, apply a lightweight projection network (one linear layer followed by an activation) and then an LM head to produce draft-token logits.
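Written out, the forward pass the solution computes is shown below. The symbol names are introduced here for illustration: h is the hidden state, W_p and b_p are the projection weights and bias, W_lm is the LM head, V is the vocabulary size, and T > 0 is the temperature. As in the solution, the LM head has no bias, and the temperature only rescales the logits, so the argmax token does not depend on T.

$$
\begin{aligned}
u &= W_p h + b_p, && W_p \in \mathbb{R}^{d_{\text{proj}} \times d_{\text{model}}},\; b_p \in \mathbb{R}^{d_{\text{proj}}} \\
\tilde{h} &= \operatorname{SiLU}(u) = u \odot \sigma(u), && \sigma(x) = \frac{1}{1 + e^{-x}} \\
z &= W_{\text{lm}} \tilde{h}, && W_{\text{lm}} \in \mathbb{R}^{V \times d_{\text{proj}}} \\
p_v &= \frac{\exp(z_v / T)}{\sum_{w=1}^{V} \exp(z_w / T)}, && \hat{y} = \arg\max_{v} p_v
\end{aligned}
$$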

Solution

```python
import math


def eagle_draft_from_hidden(
    hidden_state: list[float],
    projection_weights: list[list[float]],
    projection_bias: list[float],
    lm_head_weights: list[list[float]],
    temperature: float = 1.0,
) -> dict:
    """Draft one token from the target model's hidden state.

    hidden_state: (d_model,) hidden state from the target model
    projection_weights: (d_proj, d_model) single-layer projection
    projection_bias: (d_proj,)
    lm_head_weights: (vocab_size, d_proj) maps projected features to logits
    temperature: logits are divided by this value when it is > 0;
        a value <= 0 leaves the logits unscaled

    Returns the (temperature-scaled) logits, the softmax probabilities,
    the argmax token id, and that token's probability.
    """
    d_model = len(hidden_state)
    d_proj = len(projection_weights)
    vocab_size = len(lm_head_weights)

    # Linear projection + SiLU activation: projected = SiLU(W h + b)
    projected = []
    for i in range(d_proj):
        val = projection_bias[i]
        for j in range(d_model):
            val += projection_weights[i][j] * hidden_state[j]
        # SiLU(x) = x * sigmoid(x); for |x| >= 500 use the sigmoid's limit
        # (1 or 0) so math.exp never sees a huge argument and overflows
        if abs(val) < 500:
            sigmoid_val = 1.0 / (1.0 + math.exp(-val))
        else:
            sigmoid_val = 1.0 if val > 0 else 0.0
        projected.append(val * sigmoid_val)

    # LM head: project to vocabulary logits
    logits = []
    for v in range(vocab_size):
        val = 0.0
        for j in range(d_proj):
            val += lm_head_weights[v][j] * projected[j]
        logits.append(val)

    # Temperature scaling (skipped when temperature <= 0)
    if temperature > 0:
        logits = [z / temperature for z in logits]

    # Numerically stable softmax: subtract the max logit before exponentiating
    max_logit = max(logits)
    exps = [math.exp(z - max_logit) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Greedy draft token: argmax of the probabilities (same as argmax of the logits)
    top_token = max(range(vocab_size), key=lambda i: probs[i])

    return {
        "logits": [round(z, 4) for z in logits],
        "probs": [round(p, 6) for p in probs],
        "predicted_token": top_token,
        "confidence": round(probs[top_token], 6),
    }
```
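As a quick sanity check, here is a hand-checkable input with d_model = d_proj = 2 and a 3-token vocabulary (the numbers are made up for illustration). To keep the snippet self-contained, it restates the projection + SiLU + LM head steps in a compact hypothetical helper, `draft_logits`, instead of importing the function above.

```python
import math


def silu(x: float) -> float:
    # SiLU(x) = x * sigmoid(x) = x / (1 + e^(-x))
    return x / (1.0 + math.exp(-x))


def draft_logits(h, w_proj, b_proj, w_lm):
    # Hypothetical compact restatement of the forward pass in the solution
    projected = [silu(b + sum(w * x for w, x in zip(row, h)))
                 for row, b in zip(w_proj, b_proj)]
    return [sum(w * p for w, p in zip(row, projected)) for row in w_lm]


# Identity projection with zero bias: the SiLU inputs are just the hidden state
h = [1.0, -1.0]
w_proj = [[1.0, 0.0], [0.0, 1.0]]
b_proj = [0.0, 0.0]
# Tokens 0 and 1 read one feature each; token 2 sums both features
w_lm = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

logits = draft_logits(h, w_proj, b_proj, w_lm)
print([round(z, 4) for z in logits])  # [0.7311, -0.2689, 0.4621]
```

Since SiLU(1) ≈ 0.7311 is the largest logit, `eagle_draft_from_hidden` on the same inputs (with temperature 1.0) should report `predicted_token` 0.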

Explanation

  1. EAGLE (Extrapolation Algorithm for Greater Language-model Efficiency) uses the target model's hidden states as input to a lightweight draft head.
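  2. Unlike standard speculative decoding, where the draft model is a separate, smaller LM, EAGLE reuses the target model's internal representations.
  3. A single linear projection layer followed by a SiLU activation maps the hidden state to a d_proj-dimensional feature vector.
  4. The LM head maps the projected vector to vocabulary logits.
  5. The motivation is that hidden states carry richer context than the token identity alone, so the draft's predictions tend to agree with the target model more often, which raises the acceptance rate when the target verifies the drafted tokens.
  6. The full EAGLE method is autoregressive at the feature (hidden-state) level: it predicts the next hidden state and feeds it back in to chain multiple draft tokens. The function above drafts only a single token from a single hidden state.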
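Items 5 and 6 describe behavior beyond the single-token function above, so here is a toy sketch of both under loud assumptions. A hypothetical `next_feature` function stands in for EAGLE's learned feature predictor (here just tanh of the current feature plus a token embedding, which is not EAGLE's actual architecture), letting drafting chain for k steps. A greedy verification rule then shows what "acceptance" means: the target keeps the longest prefix of draft tokens that matches its own argmax predictions, then appends its own token at the first mismatch (or a bonus token if every draft matched). All weights, sizes, and function names are made up for illustration.

```python
import math

# Toy weights (hypothetical, for illustration only)
W_PROJ = [[0.5, -0.2], [0.1, 0.8]]                           # (d_proj, d_model)
B_PROJ = [0.0, 0.1]                                          # (d_proj,)
W_LM = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0], [0.5, -0.5]]   # (vocab_size, d_proj)
EMBED = [[0.3, 0.0], [0.0, 0.3], [-0.3, 0.1], [0.2, -0.2]]   # (vocab_size, d_model)


def silu(x: float) -> float:
    return x / (1.0 + math.exp(-x))


def draft_token(h: list[float]) -> int:
    """Greedy draft token: argmax of W_LM @ SiLU(W_PROJ h + B_PROJ)."""
    projected = [silu(b + sum(w * x for w, x in zip(row, h)))
                 for row, b in zip(W_PROJ, B_PROJ)]
    logits = [sum(w * p for w, p in zip(row, projected)) for row in W_LM]
    return max(range(len(logits)), key=logits.__getitem__)


def next_feature(h: list[float], token: int) -> list[float]:
    """Hypothetical stand-in for EAGLE's learned next-feature predictor."""
    return [math.tanh(x + e) for x, e in zip(h, EMBED[token])]


def draft_chain(h: list[float], k: int) -> list[int]:
    """Draft k tokens by autoregressing in feature space."""
    tokens = []
    for _ in range(k):
        tok = draft_token(h)
        tokens.append(tok)
        h = next_feature(h, tok)  # predicted feature, no target-model pass
    return tokens


def greedy_verify(draft: list[int], target: list[int]) -> list[int]:
    """Accept the longest matching prefix, then append the target's own token.

    target[i] is the target model's argmax at draft position i, computed in one
    pass over prefix + draft, so it has len(draft) + 1 entries.
    """
    if len(target) != len(draft) + 1:
        raise ValueError("target must have len(draft) + 1 entries")
    accepted = []
    for d, t in zip(draft, target):
        if d != t:
            return accepted + [t]  # first mismatch: take the target's token
        accepted.append(d)
    return accepted + [target[-1]]  # all drafts accepted: bonus token


draft = draft_chain([1.0, -0.5], k=3)
print(draft)
print(greedy_verify([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
```

Each verification step emits at least one token (the target's own), so a better-aligned draft head simply lengthens the accepted prefix.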

Complexity

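  • Time: O(d_model · d_proj + d_proj · vocab_size) for the two matrix-vector products; the O(vocab_size) softmax is dominated by the second term
  • Space: O(d_proj + vocab_size) for the projected vector, logits, and probabilities (not counting the weights)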
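To put the time bound in numbers, the snippet below counts multiply-adds for each matrix-vector product at illustrative sizes (d_model = 4096, d_proj = 1024, vocab_size = 32000; these are assumptions, not values from the problem). With a vocabulary this large, the LM head term dominates.

```python
def draft_head_macs(d_model: int, d_proj: int, vocab_size: int) -> tuple[int, int]:
    """Multiply-adds for the projection (d_proj x d_model) and the LM head (vocab_size x d_proj)."""
    return d_model * d_proj, d_proj * vocab_size


# Illustrative sizes (assumed, not from the problem)
proj_macs, head_macs = draft_head_macs(d_model=4096, d_proj=1024, vocab_size=32000)
print(proj_macs, head_macs, proj_macs + head_macs)  # 4194304 32768000 36962304
```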