Add prestartup script #1136
base: main
Conversation
Walkthrough: The changes introduce new functionality.
Actionable comments posted: 2
🧹 Outside diff range and nitpick comments (6)
onediff_comfy_nodes/prestartup_script.py (2)
6-9: Add directory validation and improve code documentation.

While the path handling is robust using absolute paths, consider these improvements:
- Add directory existence validation
- Add docstring explaining the purpose
- Add type hints for better IDE support

```diff
+from typing import Final
+
+# Directory containing prestartup scripts to be loaded dynamically
+ONEDIFF_COMFY_NODES_DIR: Final[str] = os.path.dirname(os.path.abspath(__file__))
+ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR: Final[str] = os.path.join(
+    ONEDIFF_COMFY_NODES_DIR, "prestartup_scripts"
+)
+
+if not os.path.isdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR):
+    raise FileNotFoundError(
+        f"Prestartup scripts directory not found: {ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR}"
+    )
```
11-11: Consider using a more isolated approach for module imports.

Modifying sys.path directly can lead to naming conflicts and potentially load unintended modules. Consider alternatives:
- Use relative imports
- Use a context manager to temporarily modify the path
- Use importlib.util.spec_from_file_location for more controlled imports (see the sketch after the example below)

Example of a safer approach:

```python
import sys
from contextlib import contextmanager

@contextmanager
def temporary_sys_path(path):
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)

with temporary_sys_path(ONEDIFF_COMFY_NODES_DIR):
    # perform imports
    ...
```
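For the importlib.util.spec_from_file_location alternative mentioned above, a minimal sketch could look like the following. It assumes the ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR constant from the PR's prestartup_script.py, and the helper name load_prestartup_script is hypothetical:

```python
import importlib.util
import os
import sys

def load_prestartup_script(filename: str):
    # Hypothetical helper: load one .py prestartup script by file path,
    # without inserting its directory into sys.path.
    module_name = f"prestartup_scripts.{filename.split('.')[0]}"
    filepath = os.path.join(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR, filename)
    spec = importlib.util.spec_from_file_location(module_name, filepath)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {filepath}")
    module = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = module  # register before exec so self-imports work
    spec.loader.exec_module(module)
    return module
```

This keeps each script's import explicit and avoids leaving extra entries on sys.path.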
onediff_comfy_nodes/prestartup_scripts/npu.py (4)

13-17: Remove unused import.

The is_intel_xpu import is not used in this file.

```diff
 from comfy.model_management import (
     is_device_cpu,
-    is_intel_xpu,
     ENABLE_PYTORCH_ATTENTION,
 )
```
19-19: Document JIT compilation setting.

Please add a comment explaining why JIT compilation is disabled for NPU.

```diff
+# Disable JIT compilation for NPU to [explain reason here]
 torch_npu.npu.set_compile_mode(jit_compile=False)
```
21-24: Simplify boolean return.

The function can be simplified to directly return the boolean value.

```diff
 def patch_pytorch_attention_flash_attention():
-    if ENABLE_PYTORCH_ATTENTION:
-        return True
-    return False
+    return bool(ENABLE_PYTORCH_ATTENTION)
```
1-59: Consider a more robust patching mechanism.

The current approach of directly replacing functions in comfy.model_management could be fragile if the module is imported before this patch is applied. Consider:
- Implementing a proper plugin/hook system in comfy for device-specific overrides
- Adding version checks to ensure compatibility with the patched module
- Adding runtime validation to ensure patches are applied successfully (a sketch follows below)

Would you like help designing a more robust patching mechanism?
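As one illustration of the runtime-validation point, here is a minimal sketch; it is not part of the PR, and the helper name apply_patch, the _ORIGINALS registry, and the use of the standard logging module are assumptions:

```python
import logging

logger = logging.getLogger(__name__)
_ORIGINALS = {}  # keep originals so patches can be inspected or reverted

def apply_patch(module, attr_name, replacement):
    """Replace module.attr_name with replacement, verifying the target exists first."""
    if not hasattr(module, attr_name):
        raise AttributeError(
            f"{module.__name__} has no attribute {attr_name!r}; "
            "the patched module's layout may have changed"
        )
    _ORIGINALS[(module.__name__, attr_name)] = getattr(module, attr_name)
    setattr(module, attr_name, replacement)
    logger.info("Patched %s.%s", module.__name__, attr_name)

# Hypothetical usage:
# import comfy.model_management as mm
# apply_patch(mm, "get_free_memory", patch_get_free_memory)
```

Failing fast when an expected attribute is missing makes silent breakage after an upstream comfy update much less likely.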
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
📒 Files selected for processing (3)
- onediff_comfy_nodes/prestartup_script.py (1 hunks)
- onediff_comfy_nodes/prestartup_scripts/gcu.py (1 hunks)
- onediff_comfy_nodes/prestartup_scripts/npu.py (1 hunks)
🧰 Additional context used
🪛 Ruff
onediff_comfy_nodes/prestartup_scripts/gcu.py
- 3-3: torch_gcu.transfer_to_gcu imported but unused. Remove unused import: torch_gcu.transfer_to_gcu (F401)
- 4-4: Do not use bare except (E722)

onediff_comfy_nodes/prestartup_scripts/npu.py
- 4-4: torch_npu.contrib.transfer_to_npu imported but unused. Remove unused import: torch_npu.contrib.transfer_to_npu (F401)
- 7-7: Do not use bare except (E722)
- 15-15: comfy.model_management.is_intel_xpu imported but unused. Remove unused import: comfy.model_management.is_intel_xpu (F401)
- 22-24: Return the condition bool(ENABLE_PYTORCH_ATTENTION) directly. Replace with return bool(ENABLE_PYTORCH_ATTENTION) (SIM103)
- 44-45: Use a single if statement instead of nested if statements. Combine if statements using and (SIM102)
🔇 Additional comments (3)
onediff_comfy_nodes/prestartup_script.py (2)
1-4: LGTM! Clean and minimal imports.

The imports are appropriate and well-organized, using only the necessary standard library modules.
13-17: Add security measures and error handling for dynamic imports.
The current implementation has several security and reliability concerns:
- Arbitrary code execution risk from untrusted modules
- No error handling for failed imports
- No validation of module contents
- No logging of loaded modules
Consider implementing these safety measures:
```diff
+import logging
+import hashlib
+from typing import Set
+
+# Set up logging
+logger = logging.getLogger(__name__)
+
+# Maintain a set of allowed modules (consider moving to a config file)
+ALLOWED_MODULES: Set[str] = {"gcu", "npu"}
+
+def validate_module(filepath: str) -> bool:
+    """Validate module before loading."""
+    # Add your validation logic here (e.g., checksum verification)
+    return True
+
-for filename in sorted(os.listdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR)):
-    if filename.endswith(".py") and filename[0] != "_":
-        importlib.import_module(f"prestartup_scripts.{filename[:-3]}")
-    elif filename.endswith(".so"):
-        importlib.import_module(f"prestartup_scripts.{filename.split('.')[0]}")
+def load_module(filename: str) -> None:
+    """Safely load a module with error handling."""
+    module_name = filename.split('.')[0]
+    if module_name not in ALLOWED_MODULES:
+        logger.warning(f"Skipping unauthorized module: {filename}")
+        return
+
+    filepath = os.path.join(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR, filename)
+    if not validate_module(filepath):
+        logger.error(f"Module validation failed: {filename}")
+        return
+
+    try:
+        if filename.endswith(".py") and filename[0] != "_":
+            module = importlib.import_module(f"prestartup_scripts.{filename[:-3]}")
+        elif filename.endswith(".so"):
+            module = importlib.import_module(f"prestartup_scripts.{module_name}")
+        logger.info(f"Successfully loaded module: {filename}")
+    except Exception as e:
+        logger.error(f"Failed to load module {filename}: {str(e)}")
+        raise
+
+for filename in sorted(os.listdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR)):
+    load_module(filename)
```
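The validate_module stub in the suggestion above leaves the checksum idea unimplemented. A minimal sketch of what such a check could look like follows; the EXPECTED_CHECKSUMS mapping and its values are placeholders, not real digests, and the whole design is an assumption rather than part of the suggestion:

```python
import hashlib
import os

# Placeholder digests; in practice they could be generated at packaging time
# and shipped in a config file.
EXPECTED_CHECKSUMS = {
    "npu.py": "<sha256-hex-digest>",
    "gcu.py": "<sha256-hex-digest>",
}

def validate_module(filepath: str) -> bool:
    """Return True only if the file's SHA-256 digest matches a known value."""
    expected = EXPECTED_CHECKSUMS.get(os.path.basename(filepath))
    if expected is None:
        return False  # unknown file: refuse to load
    with open(filepath, "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    return digest == expected
```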
Let's verify what modules are currently in the prestartup_scripts directory:
onediff_comfy_nodes/prestartup_scripts/npu.py (1)
54-59: Verify function patches impact.
The model management function replacements look correct, but we should verify that all callers of these functions handle the new behavior appropriately.
```bash
#!/bin/bash
# Search for usages of the patched functions
echo "Searching for usages of patched functions..."
rg "pytorch_attention_flash_attention|get_free_memory|should_use_fp16|should_use_bf16" \
  --type py \
  -B 2 -A 2
```
```python
_IS_NPU_AVAILABLE = False
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    _IS_NPU_AVAILABLE = True
except:
    pass
```
🛠️ Refactor suggestion
Improve error handling and imports
Several improvements needed in the NPU availability check:
- The bare except clause is too broad and could mask important errors
- The transfer_to_npu import is unused
- Missing error logging for debugging NPU availability issues
Apply this diff:
```diff
 _IS_NPU_AVAILABLE = False
 try:
     import torch_npu
-    from torch_npu.contrib import transfer_to_npu
     _IS_NPU_AVAILABLE = True
-except:
+except ImportError as e:
+    import logging
+    logging.info(f"NPU support not available: {e}")
     pass
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
_IS_NPU_AVAILABLE = False
try:
    import torch_npu
    _IS_NPU_AVAILABLE = True
except ImportError as e:
    import logging
    logging.info(f"NPU support not available: {e}")
    pass
```
```python
def patch_get_free_memory(dev=None, torch_free_too=False):
    # stats = torch.npu.memory_stats(dev)
    # mem_active = stats['active_bytes.all.current']
    # mem_reserved = stats['reserved_bytes.all.current']
    # mem_free_npu, _ = torch.npu.mem_get_info(dev)
    # mem_free_torch = mem_reserved - mem_active
    # mem_free_total = mem_free_npu + mem_free_torch
    mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
    mem_free_torch = mem_free_total

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    else:
        return mem_free_total
```
💡 Codebase verification
Critical: Implement proper NPU memory management
The review comment is correct. The code currently uses a hardcoded 48GB memory value instead of properly calculating available NPU memory. The suggested fix in the review comment is appropriate because:
- The commented code shows the correct approach using the torch.npu.memory_stats() and torch.npu.mem_get_info() APIs
- The file imports torch_npu and has NPU-specific implementations
- The hardcoded value (48GB) is not derived from any configuration and is explicitly marked as TODO
The proposed diff correctly:
- Restores the actual memory calculation logic
- Adds error handling with logging
- Includes validation with max(0, ...) to prevent negative memory values
- Provides a clear error message when memory info cannot be retrieved
🔗 Analysis chain
Critical: Replace hardcoded memory value
The function has several issues:
- Uses a hardcoded memory value (48GB) marked as TODO
- Actual memory calculation code is commented out
- No validation for memory values
Please implement proper memory calculation:
- Uncomment and fix the NPU memory calculation code
- Add validation for memory values
- Consider adding a fallback mechanism if memory info cannot be retrieved
```diff
 def patch_get_free_memory(dev=None, torch_free_too=False):
-    # stats = torch.npu.memory_stats(dev)
-    # mem_active = stats['active_bytes.all.current']
-    # mem_reserved = stats['reserved_bytes.all.current']
-    # mem_free_npu, _ = torch.npu.mem_get_info(dev)
-    # mem_free_torch = mem_reserved - mem_active
-    # mem_free_total = mem_free_npu + mem_free_torch
-    mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
-    mem_free_torch = mem_free_total
+    try:
+        stats = torch.npu.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_npu, _ = torch.npu.mem_get_info(dev)
+        mem_free_torch = max(0, mem_reserved - mem_active)
+        mem_free_total = max(0, mem_free_npu + mem_free_torch)
+    except Exception as e:
+        import logging
+        logging.warning(f"Failed to get NPU memory info: {e}")
+        # Fallback to a conservative estimate or raise error
+        raise RuntimeError("Cannot determine available NPU memory")
```
Let me generate additional scripts to gather more context about the memory management implementation.
The first command will help us understand the original implementation that this patch is trying to override, and the second command will help identify any memory-related configurations or constants that might be relevant to this implementation.
Let me try one more search to find any NPU-specific memory management code or configurations.
This will help us understand if there are any NPU-specific memory management patterns or if the hardcoded value is used elsewhere in the codebase.
🏁 Scripts executed
The following scripts were executed for the analysis:
Script:

```bash
#!/bin/bash
# Search for other memory-related implementations that might help
rg -l "mem_get_info|memory_stats" --type py
```

Length of output: 4097

Script:

```bash
#!/bin/bash
# Search for the original implementation or similar memory management code
rg -B 5 -A 5 "def get_free_memory" --type py

# Look for any configuration or constants related to memory management
rg -l "MEMORY|memory" --type py
```

Length of output: 1017

Script:

```bash
#!/bin/bash
# Look for NPU-specific memory management
rg -B 3 -A 3 "torch.npu" --type py

# Check for any memory-related constants or configurations in NPU files
rg "1024.*1024.*1024" --type py
```

Length of output: 2176