Commit b76d5c5

add argument to switch 8bit
1 parent 3e03c83 commit b76d5c5

File tree: 2 files changed, +27 -9 lines changed

  eval_configs/minigpt4_eval.yaml
  minigpt4/models/mini_gpt4.py

eval_configs/minigpt4_eval.yaml

Lines changed: 1 addition & 0 deletions

@@ -5,6 +5,7 @@ model:
   freeze_qformer: True
   max_txt_len: 160
   end_sym: "###"
+  low_resource: True
   prompt_path: "prompts/alignment.txt"
   prompt_template: '###Human: {} ###Assistant: '
   ckpt: '/path/to/pretrained/ckpt/'
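For reference, a quick way to check that the new key is picked up from the eval config. This is a minimal sketch, not part of the commit; it assumes OmegaConf (which the repo's config utilities build on) is installed and that the script runs from the repository root.

    # Minimal sketch (not part of this commit): inspect the new key with OmegaConf.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("eval_configs/minigpt4_eval.yaml")

    # MiniGPT4.from_config falls back to False when the key is absent (see the
    # from_config hunk below), so removing the line restores the fp16 load path.
    print(cfg.model.get("low_resource", False))  # True after this commit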

minigpt4/models/mini_gpt4.py

Lines changed: 26 additions & 9 deletions

@@ -36,11 +36,13 @@ def __init__(
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
+        low_resource=False,  # use 8 bit and put vit in cpu
         end_sym='\n',
     ):
         super().__init__()

         self.tokenizer = self.init_tokenizer()
+        self.low_resource = low_resource

         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -83,10 +85,19 @@ def __init__(
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

-        self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16,
-            load_in_8bit=True, device_map="auto"
-        )
+        if self.low_resource:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto"
+            )
+        else:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+            )
+
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
@@ -107,18 +118,22 @@ def __init__(
         else:
             self.prompt_list = []

-    def encode_img(self, image):
-        device = image.device
+    def vit_to_cpu(self):
         self.ln_vision.to("cpu")
         self.ln_vision.float()
         self.visual_encoder.to("cpu")
         self.visual_encoder.float()
-        image = image.to("cpu")

-        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+    def encode_img(self, image):
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")

         with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
@@ -216,6 +231,7 @@ def from_config(cls, cfg):
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
+        low_resource = cfg.get("low_resource", False)

         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
@@ -236,6 +252,7 @@ def from_config(cls, cfg):
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
+            low_resource=low_resource,
             end_sym=end_sym
         )
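The comment on the new parameter ("use 8 bit and put vit in cpu") summarizes the two effects of the flag: the LLaMA weights are loaded 8-bit quantized, and encode_img moves the ViT and the input image to CPU via the new vit_to_cpu helper. The 8-bit branch follows the standard Hugging Face loading recipe; below is a standalone sketch of that pattern for trying it outside the repo. It is not repo code: it assumes transformers, accelerate, and bitsandbytes are installed, and the weight path is a placeholder.

    # Standalone sketch of the loading pattern this commit introduces (not repo code).
    import torch
    from transformers import LlamaForCausalLM

    llama_model = "/path/to/llama/weights"   # placeholder, as in the eval config
    low_resource = True                      # mirrors the new config key

    if low_resource:
        # 8-bit quantized weights via bitsandbytes; device_map="auto" lets
        # accelerate place the quantized layers on the available GPU(s)/CPU.
        model = LlamaForCausalLM.from_pretrained(
            llama_model,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            device_map="auto",
        )
    else:
        # Plain float16 load; device placement is left to the caller.
        model = LlamaForCausalLM.from_pretrained(
            llama_model,
            torch_dtype=torch.float16,
        )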
