Skip to content

Commit

Permalink
Ensure GPU URL before writing override configuration
Browse files Browse the repository at this point in the history
  • Loading branch information
yoland68 committed Aug 16, 2024
1 parent 71cd6ef commit 72a680d
Showing 1 changed file with 103 additions and 71 deletions.
174 changes: 103 additions & 71 deletions comfy_cli/uv.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,37 @@
from importlib import metadata
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
from importlib import metadata
from pathlib import Path
from textwrap import dedent
from typing import Any, Optional, Union, cast

from comfy_cli.constants import GPU_OPTION
from comfy_cli import ui
from comfy_cli.constants import GPU_OPTION

PathLike = Union[os.PathLike[str], str]


def _run(cmd: list[str], cwd: PathLike) -> subprocess.CompletedProcess[Any]:
return subprocess.run(
cmd,
cwd=cwd,
capture_output=True,
text=True,
check=True
)
return subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, check=True)


def _check_call(cmd: list[str], cwd: Optional[PathLike] = None):
"""uses check_call to run pip, as reccomended by the pip maintainers.
see https://pip.pypa.io/en/stable/user_guide/#using-pip-from-your-program"""

subprocess.check_call(cmd, cwd=cwd)


_req_name_re: re.Pattern[str] = re.compile(r"require\s([\w-]+)")


def _req_re_closure(name: str) -> re.Pattern[str]:
return re.compile(rf"({name}\S+)")


def parse_uv_compile_error(err: str) -> tuple[str, list[str]]:
"""takes in stderr from a run of `uv pip compile` that failed due to requirement conflict and spits out
a tuple of (reqiurement_name, [requirement_spec_in_conflict_a, requirement_spec_in_conflict_b]). Will probably
Expand All @@ -47,17 +46,20 @@ def parse_uv_compile_error(err: str) -> tuple[str, list[str]]:

return reqName, cast(list[str], reqRe.findall(err))


class DependencyCompiler:
rocmPytorchUrl = "https://download.pytorch.org/whl/rocm6.0"
nvidiaPytorchUrl = "https://download.pytorch.org/whl/cu121"

overrideGpu = dedent("""
overrideGpu = dedent(
"""
# ensure usage of {gpu} version of pytorch
--extra-index-url {gpuUrl}
torch
torchsde
torchvision
""").strip()
"""
).strip()

reqNames = {
"requirements.txt",
Expand All @@ -68,26 +70,18 @@ class DependencyCompiler:

@staticmethod
def Find_Req_Files(*ders: PathLike) -> list[Path]:
return [file
return [
file
for der in ders
for file in Path(der).absolute().iterdir()
if file.name in DependencyCompiler.reqNames
]

@staticmethod
def Install_Build_Deps():
"""Use pip to install bare minimum requirements for uv to do its thing
"""
"""Use pip to install bare minimum requirements for uv to do its thing"""
if shutil.which("uv") is None:
cmd = [
sys.executable,
"-m",
"pip",
"install",
"--upgrade",
"pip",
"uv"
]
cmd = [sys.executable, "-m", "pip", "install", "--upgrade", "pip", "uv"]

_check_call(cmd=cmd)

Expand All @@ -114,22 +108,28 @@ def Compile(
# ensures that eg tqdm is latest version, even though an old tqdm is on the amd url
# see https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#packages-that-exist-on-multiple-indexes and https://github.com/astral-sh/uv/issues/171
if index_strategy is not None:
cmd.extend([
"--index-strategy",
"unsafe-best-match",
])
cmd.extend(
[
"--index-strategy",
"unsafe-best-match",
]
)

if override is not None:
cmd.extend([
"--override",
str(override),
])
cmd.extend(
[
"--override",
str(override),
]
)

if out is not None:
cmd.extend([
"-o",
str(out),
])
cmd.extend(
[
"-o",
str(out),
]
)

try:
return _run(cmd, cwd)
Expand Down Expand Up @@ -168,7 +168,7 @@ def Install(
override: Optional[PathLike] = None,
extraUrl: Optional[str] = None,
index_strategy: Optional[str] = "unsafe-best-match",
dry: bool = False
dry: bool = False,
) -> subprocess.CompletedProcess[Any]:
cmd = [
sys.executable,
Expand All @@ -181,22 +181,28 @@ def Install(
]

if index_strategy is not None:
cmd.extend([
"--index-strategy",
"unsafe-best-match",
])
cmd.extend(
[
"--index-strategy",
"unsafe-best-match",
]
)

if extraUrl is not None:
cmd.extend([
"--extra-index-url",
extraUrl,
])
cmd.extend(
[
"--extra-index-url",
extraUrl,
]
)

if override is not None:
cmd.extend([
"--override",
str(override),
])
cmd.extend(
[
"--override",
str(override),
]
)

if dry:
cmd.append("--dry-run")
Expand All @@ -209,7 +215,7 @@ def Sync(
reqFile: list[PathLike],
extraUrl: Optional[str] = None,
index_strategy: Optional[str] = "unsafe-best-match",
dry: bool = False
dry: bool = False,
) -> subprocess.CompletedProcess[Any]:
cmd = [
sys.executable,
Expand All @@ -221,16 +227,20 @@ def Sync(
]

if index_strategy is not None:
cmd.extend([
"--index-strategy",
"unsafe-best-match",
])
cmd.extend(
[
"--index-strategy",
"unsafe-best-match",
]
)

if extraUrl is not None:
cmd.extend([
"--extra-index-url",
extraUrl,
])
cmd.extend(
[
"--extra-index-url",
extraUrl,
]
)

if dry:
cmd.append("--dry-run")
Expand Down Expand Up @@ -262,46 +272,68 @@ def __init__(
outName: str = "requirements.compiled",
):
self.cwd = Path(cwd)
self.reqFiles = [Path(reqFile) for reqFile in reqFilesExt] if reqFilesExt is not None else None
self.reqFiles = (
[Path(reqFile) for reqFile in reqFilesExt]
if reqFilesExt is not None
else None
)
self.gpu = DependencyCompiler.Resolve_Gpu(gpu)

self.gpuUrl = DependencyCompiler.nvidiaPytorchUrl if self.gpu == GPU_OPTION.NVIDIA else DependencyCompiler.rocmPytorchUrl if self.gpu == GPU_OPTION.AMD else None
self.gpuUrl = (
DependencyCompiler.nvidiaPytorchUrl
if self.gpu == GPU_OPTION.NVIDIA
else DependencyCompiler.rocmPytorchUrl
if self.gpu == GPU_OPTION.AMD
else None
)
self.out = self.cwd / outName
self.override = self.cwd / "override.txt"

self.reqFilesCore = reqFilesCore if reqFilesCore is not None else self.find_core_reqs()
self.reqFilesExt = reqFilesExt if reqFilesExt is not None else self.find_ext_reqs()
self.reqFilesCore = (
reqFilesCore if reqFilesCore is not None else self.find_core_reqs()
)
self.reqFilesExt = (
reqFilesExt if reqFilesExt is not None else self.find_ext_reqs()
)

def find_core_reqs(self):
return DependencyCompiler.Find_Req_Files(self.cwd)

def find_ext_reqs(self):
extDirs = [d for d in (self.cwd / "custom_nodes").iterdir() if d.is_dir() and d.name != "__pycache__"]
extDirs = [
d
for d in (self.cwd / "custom_nodes").iterdir()
if d.is_dir() and d.name != "__pycache__"
]
return DependencyCompiler.Find_Req_Files(*extDirs)

def make_override(self):
#clean up
# clean up
self.override.unlink(missing_ok=True)

with open(self.override, "w") as f:
if self.gpu is not None:
f.write(DependencyCompiler.overrideGpu.format(gpu=self.gpu, gpuUrl=self.gpuUrl))
if self.gpu is not None and self.gpuUrl is not None:
f.write(
DependencyCompiler.overrideGpu.format(
gpu=self.gpu, gpuUrl=self.gpuUrl
)
)
f.write("\n\n")

completed = DependencyCompiler.Compile(
cwd=self.cwd,
reqFiles=self.reqFilesCore,
override=self.override
cwd=self.cwd, reqFiles=self.reqFilesCore, override=self.override
)

with open(self.override, "a") as f:
f.write("# ensure that core comfyui deps take precedence over any 3rd party extension deps\n")
f.write(
"# ensure that core comfyui deps take precedence over any 3rd party extension deps\n"
)
for line in completed.stdout:
f.write(line)
f.write("\n")

def compile_core_plus_ext(self):
#clean up
# clean up
self.out.unlink(missing_ok=True)

while True:
Expand Down

0 comments on commit 72a680d

Please sign in to comment.