Enhance Llama model Integration (#22)
tanganke authored Oct 31, 2024
1 parent 82bdd3a commit 16dda8e
Showing 81 changed files with 3,964 additions and 149 deletions.
2 changes: 2 additions & 0 deletions config/fabric_model_fusion.yaml
@@ -6,6 +6,8 @@ defaults:
- modelpool: CLIPVisionModelPool/clip-vit-base-patch32_TA8
- method: dummy
- taskpool: dummy
- _self_

_target_: fusion_bench.programs.FabricModelFusionProgram
_recursive_: false
fast_dev_run: false # Run a single batch of data to test the model or method
2 changes: 1 addition & 1 deletion config/llama_model_fusion.yaml
@@ -14,4 +14,4 @@ modelpool:
model_kwargs:
torch_dtype: float16
low_cpu_mem_usage: true
device_map: "auto"
# device_map: "auto"
23 changes: 23 additions & 0 deletions config/method/adamerging/clip.yaml
@@ -0,0 +1,23 @@
# this option can be "clip_task_wise_adamerging"
name: ???
# the weights can be a list of floats, or a string pointing to a *.np or *.pt file containing the weights
# if weights is specified, the test-time adaptation training is skipped
weights: null
# learning rate
optimizer: adam
lr: 1e-3
init_values: 0.3
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
clamp_weights: false
# arguments of `functional_call`
tie_weights: true
strict: false
# this is overridden by `fabric.devices` if launched from the `fusion_bench` CLI.
devices: 1
batch_size: 16
num_workers: 8
max_steps: 1000
fast_dev_run: ${fast_dev_run}
# the path for saving the merging weights
save_merging_weights: 'merging_weights.pt'
cache_dir: outputs
33 changes: 33 additions & 0 deletions config/method/adamerging/llama_sft.yaml
@@ -0,0 +1,33 @@
_target_: fusion_bench.method.adamerging.llama_adamerging.LayerWiseAdaMergingForLlamaSFT

seed: 0
output_dir: null
# path to initialize the merging weights
# the weights can be a list of floats, or a string pointing to a *.np or *.pt file containing the weights
# if weights is specified, the test-time adaptation training is skipped
init_weights_path: null
sparsity_ratio: null
# average attention modules instead of learning merging weights
average_attntion: true
# start_layer_idx is a float in [0, 1], an int, or null; if null, start from the first layer
start_layer_idx: 0.3
# learning rate
optimizer: adam
lr: 1e-3
init_values: 0.5
# if `clamp_weights` is true, the weights will be clamped to [0, 1]
clamp_weights: false
normalized_merging_weights: true
# arguments of `functional_call`
tie_weights: true
strict: false
max_steps: 1000
fast_dev_run: ${fast_dev_run}
# checkpointing: how often to save the merging weights, and whether to save the merged model
save_interval: 100
save_merged_model: true

dataloader_kwargs:
batch_size: 24
num_workers: 0
shuffle: true
6 changes: 6 additions & 0 deletions config/method/analysis/task_vector_violin_plot.yaml
@@ -0,0 +1,6 @@
_target_: fusion_bench.method.TaskVectorViolinPlot

trainable_only: true
max_points_per_model: 1000
fig_kwargs: null
output_path: null
270 changes: 270 additions & 0 deletions docs/guides/nlp/question_answering.md
@@ -0,0 +1,270 @@
# Question Answering

## Key Concepts

### Overlapping Tokens

**Overlapping tokens** are segments of text that are repeated between consecutive chunks when a long text has to be split into smaller pieces because of the model's maximum token limit.

Here's a detailed explanation:

1. Why we need overlapping:
- When a text is too long for the model's context window (max_length)
- To maintain continuity and context between chunks
- To avoid losing information that might be split between chunks

2. Key parameters in the code:
- `max_length`: Maximum number of tokens allowed per chunk
- `stride`: Number of overlapping tokens between chunks
- `return_overflowing_tokens`: Tells the tokenizer to return multiple chunks
- `truncation="only_second"`: Only truncates the context, never the question

Let's illustrate with an example:

Suppose we have a text: *"The quick brown fox jumps over the lazy sleeping dog"*.
The tokenization might look like this:

```
Chunk 1: [The quick brown fox jumps over]
↓ overlap ↓
Chunk 2: [brown fox jumps over the lazy]
↓ overlap ↓
Chunk 3: [jumps over the lazy sleeping dog]
```

Real-world example with actual tokens:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

question = "What did the fox do?"
context = "The quick brown fox jumps over the lazy sleeping dog. It was a beautiful sunny day."

tokenized = tokenizer(
    question,
    context,
    max_length=16,
    truncation="only_second",
    return_overflowing_tokens=True,
    stride=4
)

# Print the decoded tokens for each chunk
for encoding in tokenized["input_ids"]:
    print(tokenizer.decode(encoding))
```
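
The exact chunk boundaries depend on the tokenizer and vocabulary, but each decoded chunk repeats the full question (because only the context is truncated) and overlaps its neighbour by roughly `stride` context tokens, along these lines (illustrative, not verbatim output):

```
[CLS] what did the fox do? [SEP] the quick brown fox jumps over the [SEP]
[CLS] what did the fox do? [SEP] jumps over the lazy sleeping dog. [SEP]
[CLS] what did the fox do? [SEP] sleeping dog. it was a beautiful sunny day. [SEP]
```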

### Offset Mapping

**Offset mapping** is a feature that provides the character-level mapping between the original text and the tokenized output. It returns a list of tuples (start, end) where:

- start: starting character position in the original text
- end: ending character position in the original text

Here's a detailed breakdown:

1. Structure of offset_mapping:

```python
[(0, 0), # [CLS] token - special token, maps to nothing
 (0, 3),  # "how" - maps to characters 0-3 in the original text
 (4, 8),  # "many" - maps to characters 4-8
 ...]
```

2. Special tokens mapping:

- [CLS], [SEP], [PAD]: represented as (0, 0)
- These tokens don't correspond to any actual text in the input

3. Usage example:

```python
# Example showing how to use offset_mapping
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

text = "How many cats?"
tokenized = tokenizer(text, return_offsets_mapping=True)

for token_id, offset in zip(tokenized["input_ids"], tokenized["offset_mapping"]):
    token = tokenizer.decode([token_id])
    start, end = offset
    original_text = text[start:end] if start != end else "[SPECIAL]"
    print(f"Token: {token}, Offset: {offset}, Original text: {original_text}")
```
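
For this snippet, the printed mapping should look roughly like the following (the offsets are worked out by hand from "How many cats?", so treat them as an illustration rather than exact tokenizer output):

```
Token: [CLS], Offset: (0, 0), Original text: [SPECIAL]
Token: how, Offset: (0, 3), Original text: How
Token: many, Offset: (4, 8), Original text: many
Token: cats, Offset: (9, 13), Original text: cats
Token: ?, Offset: (13, 14), Original text: ?
Token: [SEP], Offset: (0, 0), Original text: [SPECIAL]
```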

Main purposes of offset_mapping:

1. Answer span location:
- Helps locate exact position of answers in QA tasks
- Maps token positions back to original text positions

2. Token-text alignment:
- Enables precise tracking of which parts of original text correspond to which tokens
- Useful for tasks requiring character-level precision

3. Handling overlapping chunks:
- Helps maintain correct position information when text is split into chunks
- Essential for combining predictions from multiple chunks

Common operations with offset_mapping:
```python
# Finding original text for a token
def get_original_text(text, offset):
    start, end = offset
    return text[start:end] if start != end else "[SPECIAL]"

# Finding token position for a text span
def find_token_position(offset_mapping, char_start, char_end):
    for idx, (start, end) in enumerate(offset_mapping):
        if start == char_start and end == char_end:
            return idx
    return None
```

This feature is particularly important in Question Answering tasks where you need to:

- Map predicted token positions back to the original text (see the sketch after this list)
- Handle answer spans across multiple chunks
- Maintain precise position information for answer extraction
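
To make the first point concrete, here is a minimal sketch of mapping a predicted answer span back to text with `offset_mapping`. The `start_token`/`end_token` indices are hypothetical stand-ins for a QA model's predictions, not outputs of the code above:

```python
def extract_answer(context, offset_mapping, start_token, end_token):
    """Map predicted start/end token indices back to a substring of the original context."""
    char_start = offset_mapping[start_token][0]  # first character covered by the start token
    char_end = offset_mapping[end_token][1]      # last character covered by the end token
    return context[char_start:char_end]

# Hypothetical prediction: tokens 3 and 4 bound the answer span
# answer = extract_answer(context, tokenized["offset_mapping"], 3, 4)
```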

### `overflow_to_sample_mapping`

`overflow_to_sample_mapping` is an index list that maps each feature in the overflowing tokens back to its original sample. It's particularly useful when processing multiple examples with overflow.

Here's a detailed explanation:

- When a text is split into multiple chunks due to length
- Each chunk needs to be traced back to its original example
- `overflow_to_sample_mapping` provides this tracking mechanism

Here's a comprehensive example:

```python
from transformers import AutoTokenizer
import pandas as pd

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Multiple examples
examples = {
"question": [
"What is the capital?",
"Who won the game?"
],
"context": [
"Paris is the capital of France. It is known for the Eiffel Tower. The city has many historic monuments." * 5, # Made longer by repeating
"The Lakers won the game against the Bulls. It was a close match." * 2
]
}

# Tokenize with overflow
tokenized_examples = []
for q, c in zip(examples["question"], examples["context"]):
    tokenized = tokenizer(
        q,
        c,
        max_length=50,  # Small max_length for demonstration
        stride=10,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        truncation="only_second"
    )
    tokenized_examples.append(tokenized)

# Let's see how many chunks each example was split into
for i, tokenized in enumerate(tokenized_examples):
print(f"\nExample {i}:")
print(f"Number of chunks: {len(tokenized['input_ids'])}")
print(f"Overflow to sample mapping: {tokenized.overflow_to_sample_mapping}")
```

This might output something like:

```
Example 0:
Number of chunks: 4
Overflow to sample mapping: [0, 0, 0, 0]  # all 4 chunks come from the single example in this call

Example 1:
Number of chunks: 2
Overflow to sample mapping: [0, 0]  # both chunks come from the single example in this call
```

Because each example is tokenized in its own call here, every index in the mapping is 0. When a whole batch is tokenized at once (as in the next example), the indices identify which element of the batch each chunk came from.

Practical Use Case:

```python
def prepare_train_features(examples):
    # Tokenize our examples with truncation and padding, but keep the overflows using a stride
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        truncation="only_second",
        max_length=384,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features if it has a long context,
    # this mapping tells us which example each feature came from
    sample_mapping = tokenized_examples.overflow_to_sample_mapping

    # For each feature, record the ID of the example it came from and
    # keep offsets only for context tokens so answer spans can be mapped back
    tokenized_examples["example_id"] = []
    for i, sample_idx in enumerate(sample_mapping):
        # sequence_ids distinguishes question tokens (0) from context tokens (1)
        sequence_ids = tokenized_examples.sequence_ids(i)

        # Set example ID for this feature
        tokenized_examples["example_id"].append(examples["id"][sample_idx])

        # Mask out offsets that do not belong to the context
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == 1 else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
        ]

    return tokenized_examples
```
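
In practice, a function like this is applied to a whole dataset with `datasets.map`. The sketch below assumes a SQuAD-style dataset with `id`, `question`, and `context` columns; the dataset name is only illustrative:

```python
from datasets import load_dataset

squad = load_dataset("squad")  # illustrative; any QA dataset with id/question/context columns works

train_features = squad["train"].map(
    prepare_train_features,
    batched=True,  # one example may expand into several features
    remove_columns=squad["train"].column_names,
)
```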

Key Benefits:

1. Tracking Features:
- Maps each feature back to its source example
- Maintains relationship between chunks and original data

2. Data Processing:
- Helps in maintaining example-level information
- Essential for combining predictions from multiple chunks

3. Batch Processing:
- Enables proper batching of features
- Maintains data integrity during training

Common Use Pattern:

```python
# Example of using overflow_to_sample_mapping in a training loop
for i, sample_idx in enumerate(tokenized_examples.overflow_to_sample_mapping):
    # Get original example ID
    original_example_id = examples["id"][sample_idx]

    # Get original answer
    original_answer = examples["answers"][sample_idx]

    # Process feature while maintaining connection to original example
    process_feature(tokenized_examples[i], original_example_id, original_answer)
```

This feature is particularly important in Question Answering tasks where:

- Long contexts need to be split into multiple chunks
- Each chunk needs to be processed separately
- Results need to be combined while keeping a reference to the original examples (see the sketch below)
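
As a final illustration, here is a minimal sketch of combining per-chunk predictions back at the example level; `chunk_answers` and `chunk_scores` are hypothetical per-feature outputs of a QA model, not variables defined above:

```python
from collections import defaultdict

def combine_chunk_predictions(sample_mapping, chunk_answers, chunk_scores):
    """Group per-chunk answers by their original example and keep the highest-scoring one."""
    best = defaultdict(lambda: (None, float("-inf")))  # example index -> (answer, score)
    for feature_idx, example_idx in enumerate(sample_mapping):
        answer, score = chunk_answers[feature_idx], chunk_scores[feature_idx]
        if score > best[example_idx][1]:
            best[example_idx] = (answer, score)
    return {example_idx: answer for example_idx, (answer, _) in best.items()}
```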

2 changes: 1 addition & 1 deletion fusion_bench/__init__.py
@@ -13,7 +13,7 @@
tasks,
utils,
)
from .method import BaseModelFusionAlgorithm
from .method import BaseAlgorithm, BaseModelFusionAlgorithm
from .modelpool import BaseModelPool
from .models import separate_io
from .taskpool import BaseTaskPool
11 changes: 11 additions & 0 deletions fusion_bench/dataset/imdb.py
@@ -0,0 +1,11 @@
from typing import Any, Dict, List, Optional

from datasets import load_dataset, load_from_disk
from transformers import PreTrainedTokenizer

import fusion_bench
import os
import logging
from trl import SFTConfig, SFTTrainer

log = logging.getLogger(__name__)