Check torch and lightning versions before running models #444

Open · wants to merge 4 commits into base: main

4 changes: 3 additions & 1 deletion evaluate/adapter.py
@@ -16,7 +16,7 @@
 
 from lit_llama import Tokenizer
 from lit_llama.adapter import LLaMA
-from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
+from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup, check_python_packages
 from scripts.prepare_alpaca import generate_prompt
 
 from datasets import load_dataset
@@ -80,6 +80,8 @@ def main(
     assert checkpoint_path.is_file()
     assert tokenizer_path.is_file()
 
+    check_python_packages()
+
     fabric = L.Fabric(accelerator=accelerator, devices=1)
 
     dt = getattr(torch, dtype, None)
4 changes: 3 additions & 1 deletion evaluate/adapter_v2.py
@@ -16,7 +16,7 @@
 
 from lit_llama import Tokenizer
 from lit_llama.adapter import LLaMA
-from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
+from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup, check_python_packages
 from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
 from scripts.prepare_alpaca import generate_prompt
 
@@ -76,6 +76,8 @@ def main(
     assert checkpoint_path.is_file()
     assert tokenizer_path.is_file()
 
+    check_python_packages()
+
     fabric = L.Fabric(accelerator=accelerator, devices=1)
 
     dt = getattr(torch, dtype, None)
4 changes: 3 additions & 1 deletion evaluate/full.py
@@ -15,7 +15,7 @@
 sys.path.append(str(wd))
 
 from lit_llama import LLaMA, Tokenizer
-from lit_llama.utils import EmptyInitOnDevice
+from lit_llama.utils import EmptyInitOnDevice, check_python_packages
 
 from datasets import load_dataset
 
@@ -74,6 +74,8 @@ def main(
     assert checkpoint_path.is_file()
     assert tokenizer_path.is_file()
 
+    check_python_packages()
+
     fabric = L.Fabric(accelerator=accelerator, devices=1)
 
     dt = getattr(torch, dtype, None)
4 changes: 3 additions & 1 deletion evaluate/lora.py
@@ -15,7 +15,7 @@
 sys.path.append(str(wd))
 
 from lit_llama import LLaMA, Tokenizer
-from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup
+from lit_llama.utils import EmptyInitOnDevice, lazy_load, llama_model_lookup, check_python_packages
 from lit_llama.lora import lora
 from scripts.prepare_alpaca import generate_prompt
 
@@ -83,6 +83,8 @@ def main(
     assert checkpoint_path.is_file()
     assert tokenizer_path.is_file()
 
+    check_python_packages()
+
     if quantize is not None:
         raise NotImplementedError("Quantization in LoRA is not supported yet")
 
3 changes: 3 additions & 0 deletions finetune/adapter.py
@@ -28,6 +28,7 @@
 from generate import generate
 from lit_llama.adapter import LLaMA, LLaMAConfig, mark_only_adapter_as_trainable, adapter_state_from_state_dict
 from lit_llama.tokenizer import Tokenizer
+from lit_llama.utils import check_python_packages
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
 
@@ -65,6 +66,8 @@ def main(
     out_dir: str = "out/adapter/alpaca",
 ):
 
+    check_python_packages()
+
     fabric = L.Fabric(
         accelerator="cuda",
         devices=devices,
4 changes: 3 additions & 1 deletion finetune/adapter_v2.py
@@ -36,7 +36,7 @@
 from lit_llama.tokenizer import Tokenizer
 from scripts.prepare_alpaca import generate_prompt
 from lightning.fabric.strategies import DeepSpeedStrategy
-
+from lit_llama.utils import check_python_packages
 
 eval_interval = 600
 save_interval = 1000
@@ -69,6 +69,8 @@ def main(
     pretrained_path: str = "checkpoints/lit-llama/7B/lit-llama.pth",
     out_dir: str = "out/adapter_v2/alpaca",
 ):
 
+    check_python_packages()
+
     fabric = L.Fabric(
         accelerator="cuda",
4 changes: 3 additions & 1 deletion finetune/full.py
@@ -23,7 +23,7 @@
 from generate import generate
 from lit_llama.model import Block, LLaMA, LLaMAConfig
 from lit_llama.tokenizer import Tokenizer
-from lit_llama.utils import save_model_checkpoint
+from lit_llama.utils import save_model_checkpoint, check_python_packages
 from scripts.prepare_alpaca import generate_prompt
 
 
@@ -54,6 +54,8 @@ def main(
     out_dir: str = "out/full/alpaca",
 ):
 
+    check_python_packages()
+
     auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
     strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block, limit_all_gathers=True)
 
5 changes: 4 additions & 1 deletion finetune/lora.py
@@ -21,6 +21,7 @@
 from lit_llama.lora import mark_only_lora_as_trainable, lora, lora_state_dict
 from lit_llama.model import LLaMA, LLaMAConfig
 from lit_llama.tokenizer import Tokenizer
+from lit_llama.utils import check_python_packages
 from scripts.prepare_alpaca import generate_prompt
 
 
@@ -51,6 +52,8 @@ def main(
     tokenizer_path: str = "checkpoints/lit-llama/tokenizer.model",
     out_dir: str = "out/lora/alpaca",
 ):
 
+    check_python_packages()
+
     fabric = L.Fabric(accelerator="cuda", devices=1, precision="bf16-true")
     fabric.launch()
@@ -70,7 +73,7 @@ def main(
         model = LLaMA(config)
     # strict=False because missing keys due to LoRA weights not contained in checkpoint state
     model.load_state_dict(checkpoint, strict=False)
-
+
     mark_only_lora_as_trainable(model)
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
4 changes: 3 additions & 1 deletion generate/adapter.py
@@ -14,7 +14,7 @@
 from generate import generate
 from lit_llama import Tokenizer
 from lit_llama.adapter import LLaMA
-from lit_llama.utils import lazy_load, llama_model_lookup, quantization
+from lit_llama.utils import lazy_load, llama_model_lookup, quantization, check_python_packages
 from scripts.prepare_alpaca import generate_prompt
 
 
@@ -52,6 +52,8 @@ def main(
     assert pretrained_path.is_file()
     assert tokenizer_path.is_file()
 
+    check_python_packages()
+
     precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
     fabric = L.Fabric(devices=1, precision=precision)
 
4 changes: 3 additions & 1 deletion generate/adapter_v2.py
@@ -14,7 +14,7 @@
 from generate import generate
 from lit_llama import Tokenizer
 from lit_llama.adapter import LLaMA
-from lit_llama.utils import lazy_load, llama_model_lookup, quantization
+from lit_llama.utils import lazy_load, llama_model_lookup, quantization, check_python_packages
 from lit_llama.adapter_v2 import add_adapter_v2_parameters_to_linear_layers
 from scripts.prepare_alpaca import generate_prompt
 
@@ -53,6 +53,8 @@ def main(
     assert pretrained_path.is_file()
     assert tokenizer_path.is_file()
 
+    check_python_packages()
+
     precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
     fabric = L.Fabric(devices=1, precision=precision)
 
4 changes: 3 additions & 1 deletion generate/full.py
@@ -12,7 +12,7 @@
 sys.path.append(str(wd))
 
 from lit_llama import LLaMA, Tokenizer
-from lit_llama.utils import quantization
+from lit_llama.utils import quantization, check_python_packages
 from scripts.prepare_alpaca import generate_prompt
 from generate import generate
 
@@ -50,6 +50,8 @@ def main(
     assert checkpoint_path.is_file(), checkpoint_path
     assert tokenizer_path.is_file(), tokenizer_path
 
+    check_python_packages()
+
     precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
     fabric = L.Fabric(devices=1, precision=precision)
 
4 changes: 3 additions & 1 deletion generate/lora.py
@@ -14,7 +14,7 @@
 from generate import generate
 from lit_llama import Tokenizer, LLaMA
 from lit_llama.lora import lora
-from lit_llama.utils import lazy_load, llama_model_lookup
+from lit_llama.utils import lazy_load, llama_model_lookup, check_python_packages
 from scripts.prepare_alpaca import generate_prompt
 
 lora_r = 8
@@ -58,6 +58,8 @@ def main(
 
     if quantize is not None:
         raise NotImplementedError("Quantization in LoRA is not supported yet")
 
+    check_python_packages()
+
     precision = "bf16-true" if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else "32-true"
     fabric = L.Fabric(devices=1, precision=precision)
11 changes: 11 additions & 0 deletions lit_llama/utils.py
@@ -5,6 +5,7 @@
 import warnings
 from io import BytesIO
 from pathlib import Path
+from lightning_utilities.core.imports import RequirementCache
 from contextlib import contextmanager
 
 import torch
@@ -494,3 +495,13 @@ def _write_storage_and_return_key(self, storage):
 
     def __exit__(self, type, value, traceback):
         self.zipfile.write_end_of_file()
+
+
+def check_python_packages():
+    """Raise an ImportError unless compatible torch and lightning versions are installed."""
+    torch_ = RequirementCache('torch>=2.0.0')
+    lit_ = RequirementCache('lightning>=2.1.0')
+    if not torch_ or not lit_:
+        raise ImportError(
+            "lit-llama requires torch>=2.0.0 and lightning>=2.1.0; please upgrade the outdated package(s)."
+        )
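
For reference, RequirementCache evaluates a pip-style requirement string against the installed distributions, so the helper above fails fast before any model code runs. The following is a minimal standalone sketch of the same gate; the script name and module path are assumptions for illustration, not part of this PR:

# check_versions.py -- hypothetical standalone sketch mirroring check_python_packages()
# Assumes lightning_utilities is available (it ships as a dependency of lightning>=2.0).
from lightning_utilities.core.imports import RequirementCache

def versions_ok() -> bool:
    """Return True when torch>=2.0.0 and lightning>=2.1.0 are both installed."""
    return bool(RequirementCache("torch>=2.0.0")) and bool(RequirementCache("lightning>=2.1.0"))

if __name__ == "__main__":
    if versions_ok():
        print("torch and lightning satisfy the version requirements")
    else:
        raise SystemExit("Upgrade torch to >=2.0.0 and lightning to >=2.1.0 before running the models.")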