Skip to content

Commit

Permalink
Add testing for phi-3 mlp quantisation
Browse files Browse the repository at this point in the history
  • Loading branch information
SarahByrneIntel committed Jul 16, 2024
1 parent d2fe9fe commit b7825e7
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions test/python/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from transformers.models.phi.modeling_phi import PhiConfig, PhiMLP
from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP
from transformers import AutoTokenizer, AutoModelForCausalLM
from intel_npu_acceleration_library.dtypes import int8, int4
from sklearn.metrics import r2_score
from torch.profiler import profile, ProfilerActivity
import intel_npu_acceleration_library
Expand Down Expand Up @@ -85,19 +86,27 @@ def test_phi2_mlp(seq_len, hidden_size, intermediate_size):
@pytest.mark.parametrize("seq_len", [16, 128, 256])
@pytest.mark.parametrize("hidden_size", [256, 512])
@pytest.mark.parametrize("intermediate_size", [512])
def test_phi3_mlp_compile(seq_len, hidden_size, intermediate_size):
@pytest.mark.parametrize("dtype", ["float16", "int8", "int4"])
def test_phi3_mlp_compile(seq_len, hidden_size, intermediate_size, dtype):
conf = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
conf.num_hidden_layers = 1
conf.hidden_size = hidden_size
conf.intermediate_size = intermediate_size

if dtype == "int8":
dtype = int8
elif dtype == "int4":
dtype = int4
else:
dtype = torch.float16

mlp = Phi3MLP(conf)

hidden_states = torch.rand((seq_len, conf.hidden_size))

reference = mlp(hidden_states.to(torch.float32)).to(torch.float16).detach().numpy()

model = intel_npu_acceleration_library.compile(mlp)
model = intel_npu_acceleration_library.compile(mlp, dtype)

assert model

Expand All @@ -116,4 +125,7 @@ def test_phi3_mlp_compile(seq_len, hidden_size, intermediate_size):
assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
assert np.isfinite(out).all(), "NPU output contains NaN or Inf"

assert 1 - r2_score(reference, out) < 0.001
if dtype == int4:
assert 1 - r2_score(reference, out) < 0.05
else:
assert 1 - r2_score(reference, out) < 0.001

0 comments on commit b7825e7

Please sign in to comment.