
Commit d72096e

Merge pull request #15 from probcomp/alexlew-hf-updates

Use BitsAndBytesConfig instead of load_in_8bit and token instead of u…

2 parents 4921bfe + a8d3c6d, commit d72096e

1 file changed (+9, -5 lines)

hfppl/llms.py

Lines changed: 9 additions & 5 deletions
@@ -1,7 +1,7 @@
 """Utilities for working with HuggingFace language models, including caching and auto-batching."""
 
 import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import asyncio
 import string
 
@@ -266,18 +266,22 @@ def from_pretrained(cls, model_id, auth_token=False, load_in_8bit=True):
         Returns:
             model (hfppl.llms.CachedCausalLM): the LLaMPPL-compatible interface to the HuggingFace model.
         """
+        bnb_config = BitsAndBytesConfig(load_in_8bit=load_in_8bit)
+
         if not auth_token:
             tok = AutoTokenizer.from_pretrained(model_id)
             mod = AutoModelForCausalLM.from_pretrained(
-                model_id, device_map="auto", load_in_8bit=load_in_8bit
+                model_id,
+                device_map="auto",
+                quantization_config=bnb_config,
             )
         else:
-            tok = AutoTokenizer.from_pretrained(model_id, use_auth_token=auth_token)
+            tok = AutoTokenizer.from_pretrained(model_id, token=auth_token)
             mod = AutoModelForCausalLM.from_pretrained(
                 model_id,
-                use_auth_token=auth_token,
+                token=auth_token,
                 device_map="auto",
-                load_in_8bit=load_in_8bit,
+                quantization_config=bnb_config,
             )
 
         return CachedCausalLM(mod, tok)
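
For context, a minimal usage sketch of the interface after this change, assuming hfppl is installed and bitsandbytes is available; the model id and token value below are illustrative placeholders, not taken from the commit:

    from hfppl.llms import CachedCausalLM

    # After this change, load_in_8bit=True is wrapped internally in
    # BitsAndBytesConfig(load_in_8bit=True), and auth_token is forwarded
    # as token= to AutoTokenizer and AutoModelForCausalLM.
    lm = CachedCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # illustrative model id
        auth_token="hf_...",         # illustrative placeholder token
        load_in_8bit=True,
    )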
