Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 133 additions & 0 deletions paddlemix/examples/ppdocbee2/ppdocbee2_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle

from paddlemix.models.qwen2_vl import MIXQwen2Tokenizer
from paddlemix.models.ppdocbee2 import PPDocBee2ForConditionalGeneration
from paddlemix.processors.qwen2_vl_processing import (
Qwen2VLImageProcessor,
Qwen2VLProcessor,
process_vision_info,
)
from paddlemix.utils.log import logger


def main(args):
    """Run PPDocBee2 inference on one image/question pair and print the answer.

    With ``args.benchmark`` set, generation is repeated several times; the
    first iterations are treated as warm-up and the remaining ones are timed,
    after which the average seconds per iteration and the GPU memory counters
    are printed together with the decoded output.

    Args:
        args: Parsed command-line namespace. The fields read here are
            ``model_path``, ``image_file``, ``question``, ``fp16``,
            ``benchmark``, ``max_new_tokens`` and ``temperature``.
    """
    paddle.seed(seed=0)

    # Prefer bfloat16 unless fp16 was requested; fall back to float32 when the
    # device reports no bfloat16 support. NPU devices skip the capability query.
    compute_dtype = "float16" if args.fp16 else "bfloat16"
    if "npu" in paddle.get_device():
        is_bfloat16_supported = True
    else:
        is_bfloat16_supported = paddle.amp.is_bfloat16_supported()
    if compute_dtype == "bfloat16" and not is_bfloat16_supported:
        logger.warning("bfloat16 is not supported on your device,change to float32")
        compute_dtype = "float32"

    model = PPDocBee2ForConditionalGeneration.from_pretrained(args.model_path, dtype=compute_dtype)

    image_processor = Qwen2VLImageProcessor()
    tokenizer = MIXQwen2Tokenizer.from_pretrained(args.model_path)
    processor = Qwen2VLProcessor(image_processor, tokenizer)

    # To bound the image resolution, pass pixel limits to the processor, e.g.
    # min_pixels = 256*28*28  # 200704
    # max_pixels = 1280*28*28  # 1003520
    # processor = Qwen2VLProcessor(image_processor, tokenizer, min_pixels=min_pixels, max_pixels=max_pixels)

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": f"{args.image_file}",
                },
                {"type": "text", "text": f"{args.question}"},
            ],
        }
    ]

    # Preparation for inference
    image_inputs, video_inputs = process_vision_info(messages)

    # The chat prompt is assembled by hand: system turn, then the user turn
    # holding the image placeholder tokens followed by the question.
    question = messages[0]["content"][1]["text"]
    image_pad_token = "<|vision_start|><|image_pad|><|vision_end|>"
    text = f"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n{image_pad_token}{question}<|im_end|>\n<|im_start|>assistant\n"
    text = [text]

    inputs = processor(
        text=text,
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pd",
    )

    def _generate_and_decode():
        """Run one generation pass without gradients and decode it to text."""
        with paddle.no_grad():
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                temperature=args.temperature,
                top_p=0.001,
                top_k=1,
            )  # already trimmed in paddle
            return processor.batch_decode(
                generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
            )

    if args.benchmark:
        import time

        # The first `warmup_iters` passes are not timed. The number of timed
        # passes must equal the divisor used for the average below; timing
        # only `i > 10` of 20 passes measured 9 runs but divided by 10.
        warmup_iters = 10
        timed_iters = 10
        total = 0.0
        output_text = None
        for i in range(warmup_iters + timed_iters):
            is_timed = i >= warmup_iters
            if is_timed:
                start = time.time()
            output_text = _generate_and_decode()
            if is_timed:
                total += time.time() - start
        print("s/it: ", total / timed_iters)
        print(f"\nGPU memory_allocated: {paddle.device.cuda.memory_allocated() / 1024 ** 3:.2f} GB")
        print(f"\nGPU max_memory_allocated: {paddle.device.cuda.max_memory_allocated() / 1024 ** 3:.2f} GB")
        print(f"\nGPU memory_reserved: {paddle.device.cuda.memory_reserved() / 1024 ** 3:.2f} GB")
        print(f"\nGPU max_memory_reserved: {paddle.device.cuda.max_memory_reserved() / 1024 ** 3:.2f} GB")
        print("output_text:\n", output_text)

    else:
        # Inference: Generation of the output
        output_text = _generate_and_decode()
        print("output_text:\n", output_text[0])


if __name__ == "__main__":
    import argparse

    # Command-line options in declaration order, as
    # (flag, keyword arguments forwarded to ``add_argument``).
    option_specs = [
        ("--model_path", {"type": str, "default": "PaddleMIX/PPDocBee-2B-1129"}),
        ("--question", {"type": str, "default": "识别这份表格的内容"}),
        ("--image_file", {"type": str, "default": "paddlemix/demo_images/medal_table.png"}),
        ("--temperature", {"type": float, "default": 0.1}),
        ("--max_new_tokens", {"type": int, "default": 2048}),
        ("--fp16", {"action": "store_true"}),
        ("--benchmark", {"action": "store_true"}),
    ]

    parser = argparse.ArgumentParser()
    for option_flag, option_kwargs in option_specs:
        parser.add_argument(option_flag, **option_kwargs)

    # Parse the process arguments and hand them to the inference entry point.
    args = parser.parse_args()
    main(args)
Empty file.
17 changes: 17 additions & 0 deletions paddlemix/models/ppdocbee2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.



from .modeling_ppdocbee2 import PPDocBee2ForConditionalGeneration, PPDocBee2TransformerPretrainedModel
104 changes: 104 additions & 0 deletions paddlemix/models/ppdocbee2/modeling_ppdocbee2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import paddle
import paddle.nn.functional as F

from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPreTrainedModel, Qwen2_5_VLModel, Qwen2LMHead
from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
from ..qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VisionTransformerPretrainedModel

class PPDocBee2TransformerPretrainedModel(Qwen2_5_VisionTransformerPretrainedModel):
    """Vision transformer that fuses intermediate block outputs into its final features.

    The forward pass runs the inherited blocks and, before calling the merger,
    adds the output of the block(s) selected by ``layer_idx`` to the output of
    the last block.
    """

    # Index (int) or list of indices of the blocks whose outputs are fused with
    # the last block's output. A list is averaged before being added. Negative
    # values index from the end, as with a Python list.
    layer_idx = 15

    def _fusion_layer_indices(self):
        """Validate ``layer_idx`` and return it as a list of non-negative block indices.

        Raises:
            AttributeError: if ``layer_idx`` is neither an ``int`` nor a ``list``.
            ValueError: if ``layer_idx`` is an empty list.
            IndexError: if an index does not refer to an existing block.
        """
        layer_idx = self.layer_idx
        if isinstance(layer_idx, int):
            requested = [layer_idx]
        elif isinstance(layer_idx, list):
            if not layer_idx:
                raise ValueError("layer_idx must not be an empty list")
            requested = layer_idx
        else:
            raise AttributeError(f'{type(self.layer_idx), self.layer_idx}')

        num_blocks = len(self.blocks)
        resolved = []
        for idx in requested:
            if not -num_blocks <= idx < num_blocks:
                raise IndexError(f"layer_idx {idx} is out of range for {num_blocks} blocks")
            # Map negative indices onto their positive equivalent so they can
            # be compared with the loop counter in ``forward``.
            resolved.append(idx % num_blocks)
        return resolved

    def forward(self, hidden_states: paddle.Tensor, grid_thw: paddle.Tensor) -> paddle.Tensor:
        """
        Args:
            hidden_states (`paddle.Tensor`):
                Input passed to `self.patch_embed`; after embedding it is
                two-dimensional, `(seq_len, hidden_size)`.
            grid_thw (`paddle.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `paddle.Tensor`: hidden_states after the merger, restored to the
            original (pre-window) ordering.
        """
        # Fail on a bad ``layer_idx`` before doing any expensive work.
        fusion_layers = self._fusion_layer_indices()
        layers_to_keep = set(fusion_layers)

        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = paddle.to_tensor(data=cu_window_seqlens, dtype="int32", place=hidden_states.place)
        cu_window_seqlens = paddle.unique_consecutive(x=cu_window_seqlens)

        # Reorder tokens (in groups of ``spatial_merge_unit``) into window order;
        # the rotary embeddings are permuted identically.
        seq_len, _ = tuple(hidden_states.shape)
        hidden_states = hidden_states.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1])
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape([seq_len, -1])
        rotary_pos_emb = rotary_pos_emb.reshape([seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1])
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape([seq_len, -1])

        # Cumulative sequence boundaries used by the full-attention blocks.
        cu_seqlens = paddle.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            axis=0, dtype="int32"
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        # Only the outputs that are fused later are kept, instead of cloning
        # the output of every block.
        kept_outputs = {}
        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens
            if self.enable_recompute and self.training:
                hidden_states = self.recompute_training_full(blk, hidden_states, cu_seqlens_now, rotary_pos_emb)
            else:
                hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, rotary_pos_emb=rotary_pos_emb)

            if layer_num in layers_to_keep:
                kept_outputs[layer_num] = hidden_states.clone()

        if isinstance(self.layer_idx, int):
            fused = kept_outputs[fusion_layers[0]]
        else:
            # Duplicated indices in the list are counted once per occurrence.
            fused = sum(kept_outputs[idx] for idx in fusion_layers) / len(fusion_layers)
        hidden_states = self.merger(hidden_states + fused)

        # Undo the window permutation applied above.
        reverse_indices = paddle.argsort(x=window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states



class PPDocBee2ForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
    """Conditional-generation model built from a PPDocBee2 visual encoder and a Qwen2.5-VL language model.

    The constructor assembles `visual` (PPDocBee2TransformerPretrainedModel),
    `model` (Qwen2_5_VLModel) and `lm_head` (Qwen2LMHead) itself.
    """

    def __init__(self, config, attn_implementation="flash_attention_2"):
        """Build the visual encoder, the language model and the LM head.

        Args:
            config: Model configuration. Read here: `vision_config`,
                `vocab_size` and `tie_word_embeddings`. `_attn_implementation`
                is written on both `config` and `config.vision_config`.
            attn_implementation: Attention implementation name stored on the
                configs before the sub-modules are created.
        """
        # Start the initializer chain at the class that follows
        # Qwen2_5_VLForConditionalGeneration in the MRO, i.e. the direct
        # parent's own __init__ is skipped.
        # NOTE(review): presumably so the parent does not build its own visual
        # encoder before it is replaced below — confirm.
        super(Qwen2_5_VLForConditionalGeneration, self).__init__(config)

        # Record the attention implementation before constructing sub-modules.
        config._attn_implementation = attn_implementation
        config.vision_config._attn_implementation = attn_implementation

        self.visual = PPDocBee2TransformerPretrainedModel._from_config(config.vision_config)
        self.model = Qwen2_5_VLModel(config)
        self.vocab_size = config.vocab_size
        if config.tie_word_embeddings:
            # Tied embeddings: the head is created from the input embedding weight.
            self.lm_head = Qwen2LMHead(config, embedding_weights=self.model.embed_tokens.weight, transpose_y=True)
            self.tie_weights()
        else:
            self.lm_head = Qwen2LMHead(config)
        self.padding_side = "left"  # set it to left by default, user can use setter to change padding_sides

        # Recompute is off by default.
        self.enable_recompute = False