Skip to content

Commit

Permalink
disable_logger_for_pt2.6
Browse files Browse the repository at this point in the history
  • Loading branch information
SNahir committed Feb 4, 2025
1 parent 241bffd commit a9dc0f9
Showing 1 changed file with 10 additions and 31 deletions.
41 changes: 10 additions & 31 deletions deepspeed/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import logging
import sys
import os
from deepspeed.runtime.compiler import is_compile_supported, is_compiling
import torch
from deepspeed.utils.torch import required_torch_version

log_levels = {
"debug": logging.DEBUG,
Expand All @@ -20,31 +21,6 @@

class LoggerFactory:

def create_warning_filter(logger):
warn = False

def warn_once(record):
nonlocal warn
if is_compile_supported() and is_compiling() and not warn:
warn = True
logger.warning("To avoid graph breaks caused by logger in compile-mode, it is recommended to"
" disable logging by setting env var DISABLE_LOGS_WHILE_COMPILING=1")
return True

return warn_once

@staticmethod
def logging_decorator(func):

@functools.wraps(func)
def wrapper(*args, **kwargs):
if is_compiling():
return
else:
return func(*args, **kwargs)

return wrapper

@staticmethod
def create_logger(name=None, level=logging.INFO):
"""create a logger
Expand All @@ -70,12 +46,15 @@ def create_logger(name=None, level=logging.INFO):
ch.setLevel(level)
ch.setFormatter(formatter)
logger_.addHandler(ch)
if os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
for method in ['info', 'debug', 'error', 'warning', 'critical', 'exception']:
if required_torch_version(min_version=2.6) and os.getenv("DISABLE_LOGS_WHILE_COMPILING", "0") == "1":
excluded_set = {
item.strip()
for item in os.getenv("LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE", "").split(",")
}
ignore_set = {'info', 'debug', 'error', 'warning', 'critical', 'exception', 'isEnabledFor'} - excluded_set
for method in ignore_set:
original_logger = getattr(logger_, method)
setattr(logger_, method, LoggerFactory.logging_decorator(original_logger))
else:
logger_.addFilter(LoggerFactory.create_warning_filter(logger_))
torch._dynamo.config.ignore_logger_methods.add(original_logger)
return logger_


Expand Down

0 comments on commit a9dc0f9

Please sign in to comment.