Add prestartup script #1136

Open · wants to merge 3 commits into base: main
Changes from 2 commits
17 changes: 17 additions & 0 deletions onediff_comfy_nodes/prestartup_script.py
@@ -0,0 +1,17 @@
import os
import sys
import importlib


ONEDIFF_COMFY_NODES_DIR = os.path.dirname(os.path.abspath(__file__))
ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR = os.path.join(
    ONEDIFF_COMFY_NODES_DIR, "prestartup_scripts"
)

sys.path.append(ONEDIFF_COMFY_NODES_DIR)

for filename in sorted(os.listdir(ONEDIFF_COMFY_PRESTARTUP_SCRIPTS_DIR)):
    if filename.endswith(".py") and filename[0] != "_":
        importlib.import_module(f"prestartup_scripts.{filename[:-3]}")
    elif filename.endswith(".so"):
        importlib.import_module(f"prestartup_scripts.{filename.split('.')[0]}")
5 changes: 5 additions & 0 deletions onediff_comfy_nodes/prestartup_scripts/gcu.py
@@ -0,0 +1,5 @@
try:
    import torch_gcu
    import torch_gcu.transfer_to_gcu
except:
    pass
ccssu marked this conversation as resolved.
59 changes: 59 additions & 0 deletions onediff_comfy_nodes/prestartup_scripts/npu.py
@@ -0,0 +1,59 @@
_IS_NPU_AVAILABLE = False
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu

    _IS_NPU_AVAILABLE = True
except:
    pass

Comment on lines +1 to +9
🛠️ Refactor suggestion

Improve error handling and imports

Several improvements needed in the NPU availability check:

  1. The bare except clause is too broad and could mask important errors
  2. The transfer_to_npu import is unused
  3. Missing error logging for debugging NPU availability issues

Apply this diff:

 _IS_NPU_AVAILABLE = False
 try:
     import torch_npu
-    from torch_npu.contrib import transfer_to_npu
 
     _IS_NPU_AVAILABLE = True
-except:
+except ImportError as e:
+    import logging
+    logging.info(f"NPU support not available: {e}")
     pass
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-_IS_NPU_AVAILABLE = False
-try:
-    import torch_npu
-    from torch_npu.contrib import transfer_to_npu
-    _IS_NPU_AVAILABLE = True
-except:
-    pass
+_IS_NPU_AVAILABLE = False
+try:
+    import torch_npu
+    _IS_NPU_AVAILABLE = True
+except ImportError as e:
+    import logging
+    logging.info(f"NPU support not available: {e}")
+    pass
🧰 Tools
🪛 Ruff

4-4: torch_npu.contrib.transfer_to_npu imported but unused

Remove unused import: torch_npu.contrib.transfer_to_npu

(F401)


7-7: Do not use bare except

(E722)


if _IS_NPU_AVAILABLE:
    import comfy
    from comfy.model_management import (
        is_device_cpu,
        is_intel_xpu,
        ENABLE_PYTORCH_ATTENTION,
    )

    torch_npu.npu.set_compile_mode(jit_compile=False)

    def patch_pytorch_attention_flash_attention():
        if ENABLE_PYTORCH_ATTENTION:
            return True
        return False

    def patch_get_free_memory(dev=None, torch_free_too=False):
        # stats = torch.npu.memory_stats(dev)
        # mem_active = stats['active_bytes.all.current']
        # mem_reserved = stats['reserved_bytes.all.current']
        # mem_free_npu, _ = torch.npu.mem_get_info(dev)
        # mem_free_torch = mem_reserved - mem_active
        # mem_free_total = mem_free_npu + mem_free_torch
        mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
        mem_free_torch = mem_free_total

        if torch_free_too:
            return (mem_free_total, mem_free_torch)
        else:
            return mem_free_total

Comment on lines +26 to +40
💡 Codebase verification

⚠️ Potential issue

Critical: Implement proper NPU memory management

The review comment is correct. The code currently uses a hardcoded 48GB memory value instead of properly calculating available NPU memory. The suggested fix in the review comment is appropriate because:

  1. The commented code shows the correct approach using torch.npu.memory_stats() and torch.npu.mem_get_info() APIs
  2. The file imports torch_npu and has NPU-specific implementations
  3. The hardcoded value (48GB) is not derived from any configuration and is explicitly marked as TODO

The proposed diff correctly:

  • Restores the actual memory calculation logic
  • Adds error handling with logging
  • Includes validation with max(0, ...) to prevent negative memory values
  • Provides a clear error message when memory info cannot be retrieved
🔗 Analysis chain

Critical: Replace hardcoded memory value

The function has several issues:

  1. Uses a hardcoded memory value (48GB) marked as TODO
  2. Actual memory calculation code is commented out
  3. No validation for memory values

Please implement proper memory calculation:

  1. Uncomment and fix the NPU memory calculation code
  2. Add validation for memory values
  3. Consider adding a fallback mechanism if memory info cannot be retrieved
 def patch_get_free_memory(dev=None, torch_free_too=False):
-    # stats = torch.npu.memory_stats(dev)
-    # mem_active = stats['active_bytes.all.current']
-    # mem_reserved = stats['reserved_bytes.all.current']
-    # mem_free_npu, _ = torch.npu.mem_get_info(dev)
-    # mem_free_torch = mem_reserved - mem_active
-    # mem_free_total = mem_free_npu + mem_free_torch
-    mem_free_total = 48 * 1024 * 1024 * 1024  # TODO
-    mem_free_torch = mem_free_total
+    try:
+        stats = torch.npu.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_npu, _ = torch.npu.mem_get_info(dev)
+        mem_free_torch = max(0, mem_reserved - mem_active)
+        mem_free_total = max(0, mem_free_npu + mem_free_torch)
+    except Exception as e:
+        import logging
+        logging.warning(f"Failed to get NPU memory info: {e}")
+        # Fallback to a conservative estimate or raise error
+        raise RuntimeError("Cannot determine available NPU memory")

Let me generate additional scripts to gather more context about the memory management implementation.

The first command will help us understand the original implementation that this patch is trying to override, and the second command will help identify any memory-related configurations or constants that might be relevant to this implementation.


Let me try one more search to find any NPU-specific memory management code or configurations.

This will help us understand if there are any NPU-specific memory management patterns or if the hardcoded value is used elsewhere in the codebase.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for other memory-related implementations that might help
rg -l "mem_get_info|memory_stats" --type py

Length of output: 4097


Script:

#!/bin/bash
# Search for the original implementation or similar memory management code
rg -B 5 -A 5 "def get_free_memory" --type py

# Look for any configuration or constants related to memory management
rg -l "MEMORY|memory" --type py

Length of output: 1017


Script:

#!/bin/bash
# Look for NPU-specific memory management
rg -B 3 -A 3 "torch.npu" --type py

# Check for any memory-related constants or configurations in NPU files
rg "1024.*1024.*1024" --type py

Length of output: 2176

    def patch_should_use_fp16(
        device=None, model_params=0, prioritize_performance=True, manual_cast=False
    ):
        if device is not None:
            if is_device_cpu(device):
                return False
        return True

    def patch_should_use_bf16(
        device=None, model_params=0, prioritize_performance=True, manual_cast=False
    ):
        return False

    comfy.model_management.pytorch_attention_flash_attention = (
        patch_pytorch_attention_flash_attention
    )
    comfy.model_management.get_free_memory = patch_get_free_memory
    comfy.model_management.should_use_fp16 = patch_should_use_fp16
    comfy.model_management.should_use_bf16 = patch_should_use_bf16
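
Because these assignments rebind attributes on comfy.model_management at import time, the rest of ComfyUI picks up the NPU behaviour without further changes. A minimal smoke test of the overrides (a sketch only, assuming a ComfyUI install where this prestartup script has already run; not part of this PR) might look like:

# hypothetical check, not part of this PR
import comfy.model_management as mm

# After patching, get_free_memory is the NPU override rather than the built-in,
# and in this revision it reports the hardcoded 48 GiB value regardless of device.
assert mm.get_free_memory.__name__ == "patch_get_free_memory"
assert mm.get_free_memory() == 48 * 1024 * 1024 * 1024
assert mm.should_use_bf16() is False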