@@ -36,11 +36,13 @@ def __init__(
         prompt_path="",
         prompt_template="",
         max_txt_len=32,
+        low_resource=False,  # use 8 bit and put vit in cpu
         end_sym='\n',
     ):
         super().__init__()

         self.tokenizer = self.init_tokenizer()
+        self.low_resource = low_resource

         print('Loading VIT')
         self.visual_encoder, self.ln_vision = self.init_vision_encoder(
@@ -83,10 +85,19 @@ def __init__(
         self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_model, use_fast=False)
         self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

-        self.llama_model = LlamaForCausalLM.from_pretrained(
-            llama_model, torch_dtype=torch.float16,
-            load_in_8bit=True, device_map="auto"
-        )
+        if self.low_resource:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+                load_in_8bit=True,
+                device_map="auto"
+            )
+        else:
+            self.llama_model = LlamaForCausalLM.from_pretrained(
+                llama_model,
+                torch_dtype=torch.float16,
+            )
+
         for name, param in self.llama_model.named_parameters():
             param.requires_grad = False
         print('Loading LLAMA Done')
@@ -107,18 +118,22 @@ def __init__(
         else:
             self.prompt_list = []

-    def encode_img(self, image):
-        device = image.device
+    def vit_to_cpu(self):
         self.ln_vision.to("cpu")
         self.ln_vision.float()
         self.visual_encoder.to("cpu")
         self.visual_encoder.float()
-        image = image.to("cpu")

-        image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
-        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+    def encode_img(self, image):
+        device = image.device
+        if self.low_resource:
+            self.vit_to_cpu()
+            image = image.to("cpu")

         with self.maybe_autocast():
+            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
+            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)
+
             query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
             query_output = self.Qformer.bert(
                 query_embeds=query_tokens,
@@ -216,6 +231,7 @@ def from_config(cls, cfg):
         vit_precision = cfg.get("vit_precision", "fp16")
         freeze_vit = cfg.get("freeze_vit", True)
         freeze_qformer = cfg.get("freeze_qformer", True)
+        low_resource = cfg.get("low_resource", False)

         prompt_path = cfg.get("prompt_path", "")
         prompt_template = cfg.get("prompt_template", "")
@@ -236,6 +252,7 @@ def from_config(cls, cfg):
             prompt_path=prompt_path,
             prompt_template=prompt_template,
             max_txt_len=max_txt_len,
+            low_resource=low_resource,
             end_sym=end_sym
         )

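Note (not part of the commit): the new low_resource flag is read in from_config via cfg.get("low_resource", False), so it can be switched on from the model config without further code changes. The standalone sketch below restates the branching LLaMA load for clarity; the load_llama helper name and the model path argument are illustrative assumptions, not code from the repository.

# Hypothetical standalone sketch of the low_resource branch added above.
# The helper name is illustrative; only the from_pretrained arguments mirror the diff.
import torch
from transformers import LlamaForCausalLM


def load_llama(llama_model: str, low_resource: bool = False):
    if low_resource:
        # 8-bit quantization (bitsandbytes) with weights sharded automatically
        # across the available devices; uses roughly half the memory of fp16.
        return LlamaForCausalLM.from_pretrained(
            llama_model,
            torch_dtype=torch.float16,
            load_in_8bit=True,
            device_map="auto",
        )
    # Default path: full fp16 weights; the caller moves the model to a device.
    return LlamaForCausalLM.from_pretrained(
        llama_model,
        torch_dtype=torch.float16,
    )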