
Add prestartup script #1136

Open
wants to merge 3 commits into main

Conversation

hjchen2
Contributor

@hjchen2 hjchen2 commented Nov 7, 2024

Summary by CodeRabbit

  • New Features
    • Introduced dynamic module importing from a specified directory for enhanced flexibility.
    • Added functionality to check for NPU availability and conditionally modify device management behavior.
  • Bug Fixes
    • Implemented error handling for module imports to prevent execution interruptions.
  • Refactor
    • Updated existing functions in the model management module to integrate with new NPU functionalities.

Contributor

coderabbitai bot commented Nov 7, 2024

Walkthrough

The changes introduce new functionalities in the onediff_comfy_nodes package, primarily through the addition of the prestartup_script.py, gcu.py, and npu.py files. The prestartup_script.py dynamically imports Python and shared object files from a specified directory, while gcu.py implements error handling for the import of the torch_gcu module. The npu.py file checks for NPU availability, conditionally imports components, and defines several patch functions to override existing functionalities in the comfy.model_management module.
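For orientation, here is a minimal sketch of what prestartup_script.py appears to do, reconstructed from the code quoted later in this review; the directory constant names follow the bot's suggestions below, and whether the path is inserted or appended to sys.path is an assumption:

import importlib
import os
import sys

ONEDIFF_COMFY_NODES_DIR = os.path.dirname(os.path.abspath(__file__))
ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR = os.path.join(ONEDIFF_COMFY_NODES_DIR, "prestartup_scripts")

# Make the package importable, then import every public .py file and every .so file in it.
sys.path.insert(0, ONEDIFF_COMFY_NODES_DIR)  # assumption: insert vs. append
for filename in sorted(os.listdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR)):
    if filename.endswith(".py") and filename[0] != "_":
        importlib.import_module(f"prestartup_scripts.{filename[:-3]}")
    elif filename.endswith(".so"):
        importlib.import_module(f"prestartup_scripts.{filename.split('.')[0]}")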

Changes

  • onediff_comfy_nodes/prestartup_script.py: Introduced a script for dynamically importing Python and shared-object files from prestartup_scripts.
  • onediff_comfy_nodes/prestartup_scripts/gcu.py: Added a try-except block to handle import errors for torch_gcu gracefully.
  • onediff_comfy_nodes/prestartup_scripts/npu.py: Added functionality to check NPU availability, import related components, and define patch functions that override existing functions in comfy.model_management.
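As a rough illustration of the gcu.py entry, and judging from the Ruff findings quoted later in this review (an unused transfer_to_gcu import on line 3 and a bare except on line 4), the file presumably looks something like this sketch; the actual contents may differ:

try:
    import torch_gcu  # Enflame GCU backend for PyTorch
    from torch_gcu import transfer_to_gcu  # side-effect import; typically remaps CUDA calls to the GCU backend
except:
    pass  # GCU support is optional; skip silently if torch_gcu is not installed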

Poem

In the burrow where code does play,
New scripts hop in, brightening the day.
With patches and imports, oh what a sight,
NPU and GCU, ready for flight!
So let’s code and create, with joy we’ll embrace,
A world of new features, in our cozy space! 🐇✨



@hjchen2 hjchen2 requested review from strint, clackhan and ccssu November 7, 2024 09:07
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Outside diff range and nitpick comments (6)
onediff_comfy_nodes/prestartup_script.py (2)

6-9: Add directory validation and improve code documentation.

While the path handling is robust thanks to the use of absolute paths, consider these improvements:

  1. Add directory existence validation
  2. Add docstring explaining the purpose
  3. Add type hints for better IDE support
+from typing import Final
+
+# Directory containing prestartup scripts to be loaded dynamically
+ONEDIFF_COMFY_NODES_DIR: Final[str] = os.path.dirname(os.path.abspath(__file__))
+ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR: Final[str] = os.path.join(
+    ONEDIFF_COMFY_NODES_DIR, "prestartup_scripts"
+)
+
+if not os.path.isdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR):
+    raise FileNotFoundError(
+        f"Prestartup scripts directory not found: {ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR}"
+    )

11-11: Consider using a more isolated approach for module imports.

Modifying sys.path directly can lead to naming conflicts and may cause unintended modules to be loaded. Consider these alternatives:

  1. Use relative imports
  2. Use a context manager to temporarily modify the path
  3. Use importlib.util.spec_from_file_location for more controlled imports

Example of a safer approach:

from contextlib import contextmanager
import importlib.util

@contextmanager
def temporary_sys_path(path):
    sys.path.insert(0, path)
    try:
        yield
    finally:
        sys.path.remove(path)

with temporary_sys_path(ONEDIFF_COMFY_NODES_DIR):
    # perform imports
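
For completeness, a minimal sketch of the third alternative: loading each script by explicit file path with importlib.util, so sys.path is never modified. The helper name load_prestartup_module is illustrative, not part of this PR:

import importlib.util
import os

def load_prestartup_module(path: str):
    """Load a single prestartup script from an explicit file path."""
    name = os.path.splitext(os.path.basename(path))[0]
    spec = importlib.util.spec_from_file_location(name, path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)  # runs the script's side effects (e.g. patching)
    return module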
onediff_comfy_nodes/prestartup_scripts/npu.py (4)

13-17: Remove unused import

The is_intel_xpu import is not used in this file.

 from comfy.model_management import (
     is_device_cpu,
-    is_intel_xpu,
     ENABLE_PYTORCH_ATTENTION,
 )
🧰 Tools
🪛 Ruff

15-15: comfy.model_management.is_intel_xpu imported but unused

Remove unused import: comfy.model_management.is_intel_xpu

(F401)


19-19: Document JIT compilation setting

Please add a comment explaining why JIT compilation is disabled for NPU.

+    # Disable JIT compilation for NPU to [explain reason here]
     torch_npu.npu.set_compile_mode(jit_compile=False)

21-24: Simplify boolean return

The function can be simplified to directly return the boolean value.

 def patch_pytorch_attention_flash_attention():
-    if ENABLE_PYTORCH_ATTENTION:
-        return True
-    return False
+    return bool(ENABLE_PYTORCH_ATTENTION)
🧰 Tools
🪛 Ruff

22-24: Return the condition bool(ENABLE_PYTORCH_ATTENTION) directly

Replace with return bool(ENABLE_PYTORCH_ATTENTION)

(SIM103)


1-59: Consider a more robust patching mechanism

The current approach of directly replacing functions in comfy.model_management could be fragile if the module is imported before this patch is applied. Consider:

  1. Implementing a proper plugin/hook system in comfy for device-specific overrides
  2. Adding version checks to ensure compatibility with the patched module
  3. Adding runtime validation to ensure patches are applied successfully

Would you like help designing a more robust patching mechanism?
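As a starting point, here is a minimal sketch of a more defensive patching helper, assuming patching still happens by attribute replacement; the apply_patch helper and its logging are illustrative, not part of this PR:

import logging
import comfy.model_management as mm

_ORIGINALS = {}

def apply_patch(name, replacement):
    # Fail loudly if the target no longer exists (e.g. after a ComfyUI upgrade).
    if not hasattr(mm, name):
        raise AttributeError(f"comfy.model_management has no attribute {name!r}")
    _ORIGINALS[name] = getattr(mm, name)  # keep the original for rollback and debugging
    setattr(mm, name, replacement)
    logging.info("Patched comfy.model_management.%s", name)

# Example: apply_patch("get_free_memory", patch_get_free_memory)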


📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL

📥 Commits

Reviewing files that changed from the base of the PR and between a6eafe5 and b92151e.

📒 Files selected for processing (3)
  • onediff_comfy_nodes/prestartup_script.py (1 hunks)
  • onediff_comfy_nodes/prestartup_scripts/gcu.py (1 hunks)
  • onediff_comfy_nodes/prestartup_scripts/npu.py (1 hunks)
🧰 Additional context used
🪛 Ruff
onediff_comfy_nodes/prestartup_scripts/gcu.py

3-3: torch_gcu.transfer_to_gcu imported but unused

Remove unused import: torch_gcu.transfer_to_gcu

(F401)


4-4: Do not use bare except

(E722)

onediff_comfy_nodes/prestartup_scripts/npu.py

4-4: torch_npu.contrib.transfer_to_npu imported but unused

Remove unused import: torch_npu.contrib.transfer_to_npu

(F401)


7-7: Do not use bare except

(E722)


15-15: comfy.model_management.is_intel_xpu imported but unused

Remove unused import: comfy.model_management.is_intel_xpu

(F401)


22-24: Return the condition bool(ENABLE_PYTORCH_ATTENTION) directly

Replace with return bool(ENABLE_PYTORCH_ATTENTION)

(SIM103)


44-45: Use a single if statement instead of nested if statements

Combine if statements using and

(SIM102)

🔇 Additional comments (3)
onediff_comfy_nodes/prestartup_script.py (2)

1-4: LGTM! Clean and minimal imports.

The imports are appropriate and well-organized, using only the necessary standard library modules.


13-17: ⚠️ Potential issue

Add security measures and error handling for dynamic imports.

The current implementation has several security and reliability concerns:

  1. Arbitrary code execution risk from untrusted modules
  2. No error handling for failed imports
  3. No validation of module contents
  4. No logging of loaded modules

Consider implementing these safety measures:

+import logging
+import hashlib
+from typing import Set
+
+# Set up logging
+logger = logging.getLogger(__name__)
+
+# Maintain a set of allowed modules (consider moving to a config file)
+ALLOWED_MODULES: Set[str] = {"gcu", "npu"}
+
+def validate_module(filepath: str) -> bool:
+    """Validate module before loading."""
+    # Add your validation logic here (e.g., checksum verification)
+    return True
+
-for filename in sorted(os.listdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR)):
-    if filename.endswith(".py") and filename[0] != "_":
-        importlib.import_module(f"prestartup_scripts.{filename[:-3]}")
-    elif filename.endswith(".so"):
-        importlib.import_module(f"prestartup_scripts.{filename.split('.')[0]}")
+def load_module(filename: str) -> None:
+    """Safely load a module with error handling."""
+    module_name = filename.split('.')[0]
+    if module_name not in ALLOWED_MODULES:
+        logger.warning(f"Skipping unauthorized module: {filename}")
+        return
+
+    filepath = os.path.join(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR, filename)
+    if not validate_module(filepath):
+        logger.error(f"Module validation failed: {filename}")
+        return
+
+    try:
+        if filename.endswith(".py") and filename[0] != "_":
+            module = importlib.import_module(f"prestartup_scripts.{filename[:-3]}")
+        elif filename.endswith(".so"):
+            module = importlib.import_module(f"prestartup_scripts.{module_name}")
+        logger.info(f"Successfully loaded module: {filename}")
+    except Exception as e:
+        logger.error(f"Failed to load module {filename}: {str(e)}")
+        raise
+
+for filename in sorted(os.listdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR)):
+    load_module(filename)

Let's verify what modules are currently in the prestartup_scripts directory:

onediff_comfy_nodes/prestartup_scripts/npu.py (1)

54-59: Verify function patches impact

The model management function replacements look correct, but we should verify that all callers of these functions handle the new behavior appropriately.

#!/bin/bash
# Search for usages of the patched functions
echo "Searching for usages of patched functions..."
rg "pytorch_attention_flash_attention|get_free_memory|should_use_fp16|should_use_bf16" \
  --type py \
  -B 2 -A 2
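
For reference, the replacements at lines 54-59 of npu.py presumably amount to attribute assignments along these lines; the exact set of patched attributes is an assumption based on the function names searched above:

import comfy.model_management as model_management

if _IS_NPU_AVAILABLE:
    # Swap ComfyUI's device-management hooks for the NPU-aware versions.
    model_management.pytorch_attention_flash_attention = patch_pytorch_attention_flash_attention
    model_management.get_free_memory = patch_get_free_memory
    # ...and similarly for should_use_fp16 / should_use_bf16.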

Comment on lines +1 to +9
_IS_NPU_AVAILABLE = False
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    _IS_NPU_AVAILABLE = True
except:
    pass


🛠️ Refactor suggestion

Improve error handling and imports

Several improvements needed in the NPU availability check:

  1. The bare except clause is too broad and could mask important errors
  2. The transfer_to_npu import is unused
  3. Missing error logging for debugging NPU availability issues

Apply this diff:

 _IS_NPU_AVAILABLE = False
 try:
     import torch_npu
-    from torch_npu.contrib import transfer_to_npu
 
     _IS_NPU_AVAILABLE = True
-except:
+except ImportError as e:
+    import logging
+    logging.info(f"NPU support not available: {e}")
     pass
🧰 Tools
🪛 Ruff

4-4: torch_npu.contrib.transfer_to_npu imported but unused

Remove unused import: torch_npu.contrib.transfer_to_npu

(F401)


7-7: Do not use bare except

(E722)

Comment on lines +26 to +40
def patch_get_free_memory(dev=None, torch_free_too=False):
    # stats = torch.npu.memory_stats(dev)
    # mem_active = stats['active_bytes.all.current']
    # mem_reserved = stats['reserved_bytes.all.current']
    # mem_free_npu, _ = torch.npu.mem_get_info(dev)
    # mem_free_torch = mem_reserved - mem_active
    # mem_free_total = mem_free_npu + mem_free_torch
    mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
    mem_free_torch = mem_free_total

    if torch_free_too:
        return (mem_free_total, mem_free_torch)
    else:
        return mem_free_total


💡 Codebase verification

⚠️ Potential issue

Critical: Implement proper NPU memory management

The review comment is correct. The code currently uses a hardcoded 48GB memory value instead of properly calculating available NPU memory. The suggested fix in the review comment is appropriate because:

  1. The commented code shows the correct approach using torch.npu.memory_stats() and torch.npu.mem_get_info() APIs
  2. The file imports torch_npu and has NPU-specific implementations
  3. The hardcoded value (48GB) is not derived from any configuration and is explicitly marked as TODO

The proposed diff correctly:

  • Restores the actual memory calculation logic
  • Adds error handling with logging
  • Includes validation with max(0, ...) to prevent negative memory values
  • Provides a clear error message when memory info cannot be retrieved
🔗 Analysis chain

Critical: Replace hardcoded memory value

The function has several issues:

  1. Uses a hardcoded memory value (48GB) marked as TODO
  2. Actual memory calculation code is commented out
  3. No validation for memory values

Please implement proper memory calculation:

  1. Uncomment and fix the NPU memory calculation code
  2. Add validation for memory values
  3. Consider adding a fallback mechanism if memory info cannot be retrieved
 def patch_get_free_memory(dev=None, torch_free_too=False):
-    # stats = torch.npu.memory_stats(dev)
-    # mem_active = stats['active_bytes.all.current']
-    # mem_reserved = stats['reserved_bytes.all.current']
-    # mem_free_npu, _ = torch.npu.mem_get_info(dev)
-    # mem_free_torch = mem_reserved - mem_active
-    # mem_free_total = mem_free_npu + mem_free_torch
-    mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
-    mem_free_torch = mem_free_total
+    try:
+        stats = torch.npu.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_npu, _ = torch.npu.mem_get_info(dev)
+        mem_free_torch = max(0, mem_reserved - mem_active)
+        mem_free_total = max(0, mem_free_npu + mem_free_torch)
+    except Exception as e:
+        import logging
+        logging.warning(f"Failed to get NPU memory info: {e}")
+        # Fallback to a conservative estimate or raise error
+        raise RuntimeError("Cannot determine available NPU memory")

Let me generate additional scripts to gather more context about the memory management implementation.

The first command will help us understand the original implementation that this patch is trying to override, and the second command will help identify any memory-related configurations or constants that might be relevant to this implementation.


Let me try one more search to find any NPU-specific memory management code or configurations.

This will help us understand if there are any NPU-specific memory management patterns or if the hardcoded value is used elsewhere in the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other memory-related implementations that might help
rg -l "mem_get_info|memory_stats" --type py

Length of output: 4097


Script:

#!/bin/bash
# Search for the original implementation or similar memory management code
rg -B 5 -A 5 "def get_free_memory" --type py

# Look for any configuration or constants related to memory management
rg -l "MEMORY|memory" --type py

Length of output: 1017


Script:

#!/bin/bash
# Look for NPU-specific memory management
rg -B 3 -A 3 "torch.npu" --type py

# Check for any memory-related constants or configurations in NPU files
rg "1024.*1024.*1024" --type py

Length of output: 2176

@hjchen2 hjchen2 enabled auto-merge (squash) November 14, 2024 07:30