Eagle speculator #31

Open · wants to merge 3 commits into main
Conversation

@rahul-tuli rahul-tuli commented Jun 12, 2025

  • Adds a unified LlamaEagleSpeculator class covering the EAGLE1, EAGLE2, and HASS speculators
  • Adds serialization/deserialization via from_pretrained and save_pretrained
  • Adds LlamaEagleSpeculatorConfig to control instantiation of the model
  • Adds unit and integration tests

The PR description includes a verification script that checks that pretrained EAGLE weights can be loaded into the new model definition.
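Serialization follows the familiar Hugging Face from_pretrained/save_pretrained pattern; a minimal round-trip sketch (the checkpoint directory is hypothetical):

```python
from speculators.models.llama_eagle import LlamaEagleSpeculator

# Hypothetical checkpoint directory produced by an earlier save_pretrained call.
model = LlamaEagleSpeculator.from_pretrained("./eagle_checkpoint")
model.save_pretrained("./eagle_checkpoint_copy")
```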

Open question: does the final norm need to be included?

Verification Script for EAGLE1

```python
#!/usr/bin/env python3
"""
Script to verify that EAGLE1 model weights from HuggingFace can be loaded
into our LlamaEagleSpeculator model definition.

This script:
1. Downloads the EAGLE1 model from HuggingFace
2. Creates a LlamaEagleSpeculatorConfig
3. Instantiates a LlamaEagleSpeculator
4. Attempts to load the weights and verifies compatibility
"""

import json
import sys
from pathlib import Path
from typing import Optional

import torch
from huggingface_hub import snapshot_download
from safetensors.torch import load_file

# Add parent directory to path to import speculators
sys.path.insert(0, str(Path(__file__).parent.parent.parent / "src"))

from speculators.models.llama_eagle import (
    LlamaDecoderParameters,
    LlamaEagleSpeculator,
    LlamaEagleSpeculatorConfig,
)


def download_model(model_id: str, cache_dir: Optional[str] = None):
    """Download model from HuggingFace."""
    if cache_dir is None:
        cache_dir = str(Path(__file__).parent.parent / "eagle_refactoring" / "cache" / "model_cache")
    print(f"Downloading model {model_id} from HuggingFace...")
    model_path = snapshot_download(
        repo_id=model_id,
        cache_dir=cache_dir,
        allow_patterns=["*.safetensors", "*.bin", "*.json", "*.txt"],
    )
    print(f"Model downloaded to: {model_path}")
    return Path(model_path)


def load_original_config(model_path: Path):
    """Load the original model configuration."""
    config_path = model_path / "config.json"
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found at {config_path}")
    
    with open(config_path, "r") as f:
        config_dict = json.load(f)
    
    print("\nOriginal config keys:")
    for key in sorted(config_dict.keys()):
        print(f"  - {key}: {config_dict[key]}")
    
    return config_dict


def create_llama_eagle_config(original_config: dict):
    """Create LlamaEagleSpeculatorConfig from original config."""
    print("\nCreating LlamaEagleSpeculatorConfig...")
    
    # Extract EOS token ID - handle both list and single int
    eos_token_id = original_config.get("eos_token_id", 128001)
    if isinstance(eos_token_id, list):
        eos_token_id = eos_token_id[0]  # Take first value if it's a list
    
    # Now create the LlamaEagleSpeculatorConfig with LlamaDecoderParameters
    # Most defaults in LlamaDecoderParameters match Llama 3.1 8B
    config = LlamaEagleSpeculatorConfig(
        num_hidden_layers=original_config.get("num_hidden_layers", 1),
        llama_decoder_params=LlamaDecoderParameters(
            # Only override non-matching values
            vocab_size=original_config.get("vocab_size", 128256),
            hidden_size=original_config.get("hidden_size", 4096),
            intermediate_size=original_config.get("intermediate_size", 14336),
            num_attention_heads=original_config.get("num_attention_heads", 32),
            num_key_value_heads=original_config.get("num_key_value_heads", 8),
            max_position_embeddings=original_config.get("max_position_embeddings", 131072),
            rope_theta=original_config.get("rope_theta", 500000.0),
            rms_norm_eps=original_config.get("rms_norm_eps", 1e-5),
            pad_token_id=original_config.get("pad_token_id"),
            bos_token_id=original_config.get("bos_token_id", 128000),
            eos_token_id=eos_token_id,
            # Use defaults for: hidden_act, attention_bias, attention_dropout, mlp_bias
        ),
        # EAGLE-specific parameters
        fc_bias=False,  # EAGLE1 doesn't use bias
        use_extra_layernorms=False,  # EAGLE1 doesn't have extra layernorms
        replace_first_layer_norm=True,  # Key EAGLE optimization
    )
    
    print(f"Created config with:")
    print(f"  - num_hidden_layers: {config.num_hidden_layers}")
    print(f"  - hidden_size: {config.llama_decoder_params.hidden_size}")
    print(f"  - vocab_size: {config.llama_decoder_params.vocab_size}")
    
    return config


def load_model_weights(model_path: Path):
    """Load model weights from safetensors or bin files."""
    # Check for safetensors files
    safetensors_files = list(model_path.glob("*.safetensors"))
    bin_files = list(model_path.glob("*.bin"))
    
    if safetensors_files:
        print(f"\nLoading weights from safetensors files: {safetensors_files}")
        state_dict = {}
        for file in safetensors_files:
            state_dict.update(load_file(file))
    elif bin_files:
        print(f"\nLoading weights from bin files: {bin_files}")
        state_dict = {}
        for file in bin_files:
            # weights_only=True avoids executing arbitrary pickle code when
            # loading untrusted checkpoints (available in torch >= 1.13).
            state_dict.update(torch.load(file, map_location="cpu", weights_only=True))
    else:
        raise FileNotFoundError("No model weights found (.safetensors or .bin)")
    
    print(f"\nLoaded {len(state_dict)} weight tensors")
    print("\nWeight keys:")
    for key in sorted(state_dict.keys()):
        print(f"  - {key}: {state_dict[key].shape}")
    
    return state_dict


def map_weights_to_our_model(state_dict: dict, our_model: LlamaEagleSpeculator):
    """Map original weights to our model structure."""
    print("\n\nMapping weights to our model...")
    
    our_state_dict = our_model.state_dict()
    mapped_weights = {}
    missing_keys = []
    unexpected_keys = []
    
    # Check if the weights have "model." prefix or not
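    # (Released EAGLE checkpoints vary: some nest weights under the HF
    # LlamaModel "model." namespace, others store them at the top level.)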
    has_model_prefix = any(k.startswith("model.") for k in state_dict.keys())
    
    # Expected mappings (with conditional prefix)
    def add_prefix(key):
        return f"model.{key}" if has_model_prefix else key
    
    weight_mappings = {
        # Embeddings
        add_prefix("embed_tokens.weight"): "embed_tokens.weight",
        
        # Fusion layer (EAGLE-specific)
        add_prefix("fc.weight"): "fc.weight",
        add_prefix("fc.bias"): "fc.bias",
        "fc.weight": "fc.weight",  # Also try without prefix
        "fc.bias": "fc.bias",
        
        # Decoder layers
        add_prefix("norm.weight"): "norm.weight",
        add_prefix("lm_head.weight"): "lm_head.weight",
        "lm_head.weight": "lm_head.weight",  # Also try without prefix
    }
    
    # Add decoder layer mappings
    for i in range(our_model.config.num_hidden_layers):
        layer_mappings = {
            add_prefix(f"layers.{i}.self_attn.q_proj.weight"): f"layers.{i}.self_attn.q_proj.weight",
            add_prefix(f"layers.{i}.self_attn.k_proj.weight"): f"layers.{i}.self_attn.k_proj.weight",
            add_prefix(f"layers.{i}.self_attn.v_proj.weight"): f"layers.{i}.self_attn.v_proj.weight",
            add_prefix(f"layers.{i}.self_attn.o_proj.weight"): f"layers.{i}.self_attn.o_proj.weight",
            add_prefix(f"layers.{i}.mlp.gate_proj.weight"): f"layers.{i}.mlp.gate_proj.weight",
            add_prefix(f"layers.{i}.mlp.up_proj.weight"): f"layers.{i}.mlp.up_proj.weight",
            add_prefix(f"layers.{i}.mlp.down_proj.weight"): f"layers.{i}.mlp.down_proj.weight",
            add_prefix(f"layers.{i}.input_layernorm.weight"): f"layers.{i}.input_layernorm.weight",
            add_prefix(f"layers.{i}.post_attention_layernorm.weight"): f"layers.{i}.post_attention_layernorm.weight",
        }
        weight_mappings.update(layer_mappings)
    
    # Try direct mapping first
    for orig_key, our_key in weight_mappings.items():
        if orig_key in state_dict and our_key in our_state_dict:
            if state_dict[orig_key].shape == our_state_dict[our_key].shape:
                mapped_weights[our_key] = state_dict[orig_key]
                print(f"✓ Mapped {orig_key} -> {our_key}")
            else:
                print(f"✗ Shape mismatch for {orig_key}: {state_dict[orig_key].shape} vs {our_state_dict[our_key].shape}")
        elif orig_key in state_dict:
            unexpected_keys.append(orig_key)
        elif our_key in our_state_dict:
            missing_keys.append(our_key)
    
    # Check for any unmapped weights in original model
    for key in state_dict:
        if key not in weight_mappings:
            # Try alternative mappings
            if key == "fc.weight" and "fc.weight" in our_state_dict:
                if state_dict[key].shape == our_state_dict["fc.weight"].shape:
                    mapped_weights["fc.weight"] = state_dict[key]
                    print(f"✓ Mapped {key} -> fc.weight (direct)")
            else:
                print(f"? Unmapped weight in original model: {key}")
    
    
    return mapped_weights, missing_keys, unexpected_keys


def verify_model_loading(model_id: str = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"):
    """Main function to verify model loading."""
    print(f"Verifying EAGLE model loading for: {model_id}")
    print("=" * 80)
    
    try:
        # Step 1: Download model
        model_path = download_model(model_id)
        
        # Step 2: Load original config
        original_config = load_original_config(model_path)
        
        # Step 3: Create our config
        our_config = create_llama_eagle_config(original_config)
        
        # Step 4: Create our model
        print("\nInstantiating LlamaEagleSpeculator...")
        our_model = LlamaEagleSpeculator(our_config)
        print(f"Model created with {sum(p.numel() for p in our_model.parameters())} parameters")
        
        # Step 5: Load original weights
        original_weights = load_model_weights(model_path)
        
        # Step 6: Map weights
        mapped_weights, missing_keys, unexpected_keys = map_weights_to_our_model(
            original_weights, our_model
        )
        
        # Step 7: Load mapped weights into our model
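        # strict=False: EAGLE1 checkpoints typically omit lm_head.weight,
        # which the speculator is expected to reuse from the verifier.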
        print(f"\nLoading {len(mapped_weights)} mapped weights into our model...")
        our_model.load_state_dict(mapped_weights, strict=False)
        
        # Step 8: Test forward pass
        print("\nTesting forward pass...")
        test_input_ids = torch.randint(0, our_config.llama_decoder_params.vocab_size, (1, 10))
        test_hidden_states = torch.randn(1, 10, our_config.llama_decoder_params.hidden_size)
        
        with torch.no_grad():
            output = our_model(input_ids=test_input_ids, hidden_states=test_hidden_states)
        
        print(f"✓ Forward pass successful! Output shape: {output.shape}")
        
        # Summary
        print("\n" + "=" * 80)
        print("SUMMARY:")
        print(f"  - Successfully mapped: {len(mapped_weights)} weights")
        print(f"  - Missing keys: {len(missing_keys)}")
        print(f"  - Unexpected keys: {len(unexpected_keys)}")
        
        if missing_keys:
            print("\nMissing keys that need attention:")
            for key in missing_keys[:10]:  # Show first 10
                print(f"  - {key}")
            if len(missing_keys) > 10:
                print(f"  ... and {len(missing_keys) - 10} more")
        
        print("\n✓ Model loading verification complete!")
        
        # Save our config for reference
        save_path = Path(__file__).parent.parent / "eagle_refactoring" / "configs" / "generated_eagle_config"
        save_path.mkdir(exist_ok=True, parents=True)
        our_config.save_pretrained(save_path)
        print(f"\nSaved example config to: {save_path}")
        
    except Exception as e:
        print(f"\n✗ Error during verification: {type(e).__name__}: {e}")
        import traceback
        traceback.print_exc()
        return False
    
    return True


if __name__ == "__main__":
    success = verify_model_loading()
    sys.exit(0 if success else 1)
```

Verification Output

```
python local/model_definition/verifiy_eagle.py
Verifying EAGLE model loading for: yuhuili/EAGLE-LLaMA3.1-Instruct-8B
================================================================================
Downloading model yuhuili/EAGLE-LLaMA3.1-Instruct-8B from HuggingFace...
Fetching 2 files: 100%|████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33288.13it/s]
Model downloaded to: ./model_cache/models--yuhuili--EAGLE-LLaMA3.1-Instruct-8B/snapshots/89073acba22a03994aee0c76774a10ca941e4706

Original config keys:
  - architectures: ['LlamaForCausalLM']
  - bias: False
  - bos_token_id: 128000
  - eos_token_id: 128001
  - hidden_act: silu
  - hidden_size: 4096
  - initializer_range: 0.02
  - intermediate_size: 14336
  - max_position_embeddings: 2048
  - model_type: llama
  - num_attention_heads: 32
  - num_hidden_layers: 1
  - num_key_value_heads: 8
  - pad_token_id: 0
  - rms_norm_eps: 1e-05
  - rope_theta: 500000.0
  - tie_word_embeddings: False
  - torch_dtype: float16
  - transformers_version: 4.28.1
  - use_cache: True
  - vocab_size: 128256

Creating LlamaEagleSpeculatorConfig...
Created config with:
  - eagle_variant: eagle
  - num_hidden_layers: 1
  - llama_config.hidden_size: 4096
  - llama_config.vocab_size: 128256

Instantiating LlamaEagleSpeculator...
Model created with 1302335488 parameters

Loading weights from bin files: [PosixPath('model_cache/models--yuhuili--EAGLE-LLaMA3.1-Instruct-8B/snapshots/89073acba22a03994aee0c76774a10ca941e4706/pytorch_model.bin')]

Loaded 10 weight tensors

Weight keys:
  - embed_tokens.weight: torch.Size([128256, 4096])
  - fc.weight: torch.Size([4096, 8192])
  - layers.0.mlp.down_proj.weight: torch.Size([4096, 14336])
  - layers.0.mlp.gate_proj.weight: torch.Size([14336, 4096])
  - layers.0.mlp.up_proj.weight: torch.Size([14336, 4096])
  - layers.0.post_attention_layernorm.weight: torch.Size([4096])
  - layers.0.self_attn.k_proj.weight: torch.Size([1024, 4096])
  - layers.0.self_attn.o_proj.weight: torch.Size([4096, 4096])
  - layers.0.self_attn.q_proj.weight: torch.Size([4096, 4096])
  - layers.0.self_attn.v_proj.weight: torch.Size([1024, 4096])


Mapping weights to our model...
✓ Mapped embed_tokens.weight -> embed_tokens.weight
✓ Mapped fc.weight -> fc.weight
✓ Mapped layers.0.self_attn.q_proj.weight -> layers.0.self_attn.q_proj.weight
✓ Mapped layers.0.self_attn.k_proj.weight -> layers.0.self_attn.k_proj.weight
✓ Mapped layers.0.self_attn.v_proj.weight -> layers.0.self_attn.v_proj.weight
✓ Mapped layers.0.self_attn.o_proj.weight -> layers.0.self_attn.o_proj.weight
✓ Mapped layers.0.mlp.gate_proj.weight -> layers.0.mlp.gate_proj.weight
✓ Mapped layers.0.mlp.up_proj.weight -> layers.0.mlp.up_proj.weight
✓ Mapped layers.0.mlp.down_proj.weight -> layers.0.mlp.down_proj.weight
✓ Mapped layers.0.post_attention_layernorm.weight -> layers.0.post_attention_layernorm.weight

Loading 10 mapped weights into our model...

Testing forward pass...
✓ Forward pass successful! Output shape: torch.Size([1, 10, 128256])

================================================================================
SUMMARY:
  - Successfully mapped: 10 weights
  - Missing keys: 1
  - Unexpected keys: 0

Missing keys that need attention:
  - lm_head.weight

✓ Model loading verification complete!
```
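The single missing key, lm_head.weight, matches expectations: the EAGLE1 checkpoint ships no output head of its own, and the use_verifier_lm_head flag in the config below suggests the head comes from the verifier. A sketch of how it could be copied over (the verifier ID, checkpoint path, and attribute paths are assumptions for illustration, not part of this PR):

```python
import torch
from transformers import AutoModelForCausalLM

from speculators.models.llama_eagle import LlamaEagleSpeculator

# Hypothetical identifiers; substitute the actual verifier and speculator paths.
verifier = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct", torch_dtype=torch.float16
)
speculator = LlamaEagleSpeculator.from_pretrained("./eagle_checkpoint")

# EAGLE1 reuses the verifier's output projection as the speculator's lm_head.
with torch.no_grad():
    speculator.lm_head.weight.copy_(verifier.lm_head.weight)
```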
Eagle Config

```json
{
  "eagle_variant": "eagle",
  "extra_layernorm_positions": null,
  "fc_bias": false,
  "has_no_defaults_at_init": false,
  "inputs": [
    "input_ids",
    "hidden_states[-1]"
  ],
  "inputs_hidden_states_normalized": false,
  "llama_decoder_params": {
    "attention_bias": false,
    "attention_dropout": 0.0,
    "bos_token_id": null,
    "eos_token_id": null,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "intermediate_size": 14336,
    "max_position_embeddings": 2048,
    "mlp_bias": false,
    "num_attention_heads": 32,
    "num_key_value_heads": 8,
    "pad_token_id": null,
    "rms_norm_eps": 1e-05,
    "rope_theta": 500000.0,
    "vocab_size": 128256
  },
  "model_type": "speculator_model",
  "num_hidden_layers": 1,
  "replace_first_layer_norm": true,
  "speculators_config": null,
  "speculators_model_type": "llama_eagle",
  "speculators_version": "0.1.0.dev7",
  "transformer_input_type": "linear_no_bias",
  "transformer_layer_type": "LlamaDecoderLayer",
  "transformer_remove_last_layer_norm": false,
  "transformers_version": "4.52.4",
  "use_extra_layernorms": false,
  "use_verifier_lm_head": false
}
```
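As a sanity check on serialization, the generated config should load back with from_pretrained (a sketch, assuming the config exposes its JSON keys as attributes in the usual HF PretrainedConfig style; the path is the one the script's save_pretrained call writes):

```python
from speculators.models.llama_eagle import LlamaEagleSpeculatorConfig

# Path written by the verification script's save_pretrained call.
config = LlamaEagleSpeculatorConfig.from_pretrained(
    "local/eagle_refactoring/configs/generated_eagle_config"
)
assert config.num_hidden_layers == 1
assert config.speculators_model_type == "llama_eagle"
```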



@rahul-tuli rahul-tuli changed the base branch from main to models-config-enablement June 12, 2025 14:25
@rahul-tuli rahul-tuli self-assigned this Jun 12, 2025


@brian-dellabetta brian-dellabetta left a comment

Not terribly familiar with speculators, but looking through this specifically for coding style and class definitions, things look good to me! One question on structure.


@markurtz markurtz force-pushed the models-config-enablement branch from 116bb5b to db86eb6 Compare June 23, 2025 17:19
Base automatically changed from models-config-enablement to main June 23, 2025 17:34
Add LlamaEagleSpeculator model with serialization support

  - Add LlamaEagleSpeculator model class
  - Implement model serialization/deserialization
  - Add torch as a project dependency
  - Include unit and integration tests
  - Fixes for serialization handling

  This introduces the core LlamaEagleSpeculator functionality with full serialization support.


Add configuration classes for EAGLE1 and HASS speculator models without the model implementation. This separates the config from the actual model code.

- Add LlamaDecoderParameters for Llama decoder configuration
- Add LlamaEagleSpeculatorConfig with support for both EAGLE1 and HASS variants
- Include config-only unit tests

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
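
For illustration, a HASS-variant config might then be built by flipping the EAGLE1-specific toggles (a sketch; the HASS settings are inferred from the verification script's comments, not confirmed by this PR):

```python
from speculators.models.llama_eagle import (
    LlamaDecoderParameters,
    LlamaEagleSpeculatorConfig,
)

# Inferred HASS settings; verify against actual HASS checkpoints.
hass_config = LlamaEagleSpeculatorConfig(
    num_hidden_layers=1,
    llama_decoder_params=LlamaDecoderParameters(),  # Llama 3.1 8B defaults
    fc_bias=True,               # EAGLE1 omits fc bias; HASS presumably uses it
    use_extra_layernorms=True,  # likewise the extra layernorms
    replace_first_layer_norm=True,
)
```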

