diff --git a/deepspeed/utils/logging.py b/deepspeed/utils/logging.py index d5db29485db3..77173f2839ca 100644 --- a/deepspeed/utils/logging.py +++ b/deepspeed/utils/logging.py @@ -7,7 +7,8 @@ import logging import sys import os -from deepspeed.runtime.compiler import is_compile_supported, is_compiling +import torch +from deepspeed.utils.torch import required_torch_version log_levels = { "debug": logging.DEBUG, @@ -20,31 +21,6 @@ class LoggerFactory: - def create_warning_filter(logger): - warn = False - - def warn_once(record): - nonlocal warn - if is_compile_supported() and is_compiling() and not warn: - warn = True - logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to" - " disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1") - return True - - return warn_once - - @staticmethod - def logging_decorator(func): - - @functools.wraps(func) - def wrapper(*args, **kwargs): - if is_compiling(): - return - else: - return func(*args, **kwargs) - - return wrapper - @staticmethod def create_logger(name=None, level=logging.INFO): """create a logger @@ -70,12 +46,15 @@ def create_logger(name=None, level=logging.INFO): ch.setLevel(level) ch.setFormatter(formatter) logger_.addHandler(ch) - if os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1": - for method in ['info', 'debug', 'error', 'warning', 'critical', 'exception']: + if required_torch_version(min_version=2.6) and os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1": + excluded_set = { + item.strip() + for item in os.getenv("LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE", "").split(",") + } + ignore_set = {'info', 'debug', 'error', 'warning', 'critical', 'exception', 'isEnabledFor'} - excluded_set + for method in ignore_set: original_logger = getattr(logger_, method) - setattr(logger_, method, LoggerFactory.logging_decorator(original_logger)) - else: - logger_.addFilter(LoggerFactory.create_warning_filter(logger_)) + torch._dynamo.config.ignore_logger_methods.add(original_logger) return logger_