From 198a070d3d0f6ad2aa6137da10ac3d6a952ff3f8 Mon Sep 17 00:00:00 2001 From: Jintao Date: Thu, 14 Mar 2024 17:07:40 +0800 Subject: [PATCH] fix deepseek-vl eval_loss not found bug (#552) --- swift/llm/utils/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/swift/llm/utils/model.py b/swift/llm/utils/model.py index 02692f3343..3278cd640b 100644 --- a/swift/llm/utils/model.py +++ b/swift/llm/utils/model.py @@ -3,7 +3,7 @@ import os import sys from contextlib import nullcontext -from functools import partial, update_wrapper +from functools import partial, update_wrapper, wraps from types import MethodType from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type @@ -1725,18 +1725,20 @@ def __prepare_inputs_embeds( def _patch_deepseek_vl(model) -> None: model.prepare_inputs_embeds = MethodType(__prepare_inputs_embeds, model) - def get_new_func(func_name: str): + def _get_new_func(func_name: str): + _old_func = getattr(model.language_model, func_name) - def new_func(*args, **kwargs): - return getattr(model.language_model, func_name)(*args, **kwargs) + @wraps(_old_func) + def _new_func(*args, **kwargs): + return _old_func(*args, **kwargs) - return new_func + return _new_func for key in [ 'generate', 'get_input_embeddings', 'gradient_checkpointing_enable', 'forward' ]: - setattr(model, key, get_new_func(key)) + setattr(model, key, _get_new_func(key)) @register_model(