Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure GPU URL before writing override configuration #156

Merged
merged 1 commit into from
Aug 16, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 103 additions & 71 deletions comfy_cli/uv.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,37 @@
from importlib import metadata
import os
from pathlib import Path
import re
import shutil
import subprocess
import sys
from importlib import metadata
from pathlib import Path
from textwrap import dedent
from typing import Any, Optional, Union, cast

from comfy_cli.constants import GPU_OPTION
from comfy_cli import ui
from comfy_cli.constants import GPU_OPTION

PathLike = Union[os.PathLike[str], str]


def _run(cmd: list[str], cwd: PathLike) -> subprocess.CompletedProcess[Any]:
return subprocess.run(
cmd,
cwd=cwd,
capture_output=True,
text=True,
check=True
)
return subprocess.run(cmd, cwd=cwd, capture_output=True, text=True, check=True)


def _check_call(cmd: list[str], cwd: Optional[PathLike] = None):
"""uses check_call to run pip, as reccomended by the pip maintainers.
see https://pip.pypa.io/en/stable/user_guide/#using-pip-from-your-program"""

subprocess.check_call(cmd, cwd=cwd)


_req_name_re: re.Pattern[str] = re.compile(r"require\s([\w-]+)")


def _req_re_closure(name: str) -> re.Pattern[str]:
return re.compile(rf"({name}\S+)")


def parse_uv_compile_error(err: str) -> tuple[str, list[str]]:
"""takes in stderr from a run of `uv pip compile` that failed due to requirement conflict and spits out
a tuple of (reqiurement_name, [requirement_spec_in_conflict_a, requirement_spec_in_conflict_b]). Will probably
Expand All @@ -47,17 +46,20 @@ def parse_uv_compile_error(err: str) -> tuple[str, list[str]]:

return reqName, cast(list[str], reqRe.findall(err))


class DependencyCompiler:
rocmPytorchUrl = "https://download.pytorch.org/whl/rocm6.0"
nvidiaPytorchUrl = "https://download.pytorch.org/whl/cu121"

overrideGpu = dedent("""
overrideGpu = dedent(
"""
# ensure usage of {gpu} version of pytorch
--extra-index-url {gpuUrl}
torch
torchsde
torchvision
""").strip()
"""
).strip()

reqNames = {
"requirements.txt",
Expand All @@ -68,26 +70,18 @@ class DependencyCompiler:

@staticmethod
def Find_Req_Files(*ders: PathLike) -> list[Path]:
return [file
return [
file
for der in ders
for file in Path(der).absolute().iterdir()
if file.name in DependencyCompiler.reqNames
]

@staticmethod
def Install_Build_Deps():
"""Use pip to install bare minimum requirements for uv to do its thing
"""
"""Use pip to install bare minimum requirements for uv to do its thing"""
if shutil.which("uv") is None:
cmd = [
sys.executable,
"-m",
"pip",
"install",
"--upgrade",
"pip",
"uv"
]
cmd = [sys.executable, "-m", "pip", "install", "--upgrade", "pip", "uv"]

_check_call(cmd=cmd)

Expand All @@ -114,22 +108,28 @@ def Compile(
# ensures that eg tqdm is latest version, even though an old tqdm is on the amd url
# see https://github.com/astral-sh/uv/blob/main/PIP_COMPATIBILITY.md#packages-that-exist-on-multiple-indexes and https://github.com/astral-sh/uv/issues/171
if index_strategy is not None:
cmd.extend([
"--index-strategy",
"unsafe-best-match",
])
cmd.extend(
[
"--index-strategy",
"unsafe-best-match",
]
)

if override is not None:
cmd.extend([
"--override",
str(override),
])
cmd.extend(
[
"--override",
str(override),
]
)

if out is not None:
cmd.extend([
"-o",
str(out),
])
cmd.extend(
[
"-o",
str(out),
]
)

try:
return _run(cmd, cwd)
Expand Down Expand Up @@ -168,7 +168,7 @@ def Install(
override: Optional[PathLike] = None,
extraUrl: Optional[str] = None,
index_strategy: Optional[str] = "unsafe-best-match",
dry: bool = False
dry: bool = False,
) -> subprocess.CompletedProcess[Any]:
cmd = [
sys.executable,
Expand All @@ -181,22 +181,28 @@ def Install(
]

if index_strategy is not None:
cmd.extend([
"--index-strategy",
"unsafe-best-match",
])
cmd.extend(
[
"--index-strategy",
"unsafe-best-match",
]
)

if extraUrl is not None:
cmd.extend([
"--extra-index-url",
extraUrl,
])
cmd.extend(
[
"--extra-index-url",
extraUrl,
]
)

if override is not None:
cmd.extend([
"--override",
str(override),
])
cmd.extend(
[
"--override",
str(override),
]
)

if dry:
cmd.append("--dry-run")
Expand All @@ -209,7 +215,7 @@ def Sync(
reqFile: list[PathLike],
extraUrl: Optional[str] = None,
index_strategy: Optional[str] = "unsafe-best-match",
dry: bool = False
dry: bool = False,
) -> subprocess.CompletedProcess[Any]:
cmd = [
sys.executable,
Expand All @@ -221,16 +227,20 @@ def Sync(
]

if index_strategy is not None:
cmd.extend([
"--index-strategy",
"unsafe-best-match",
])
cmd.extend(
[
"--index-strategy",
"unsafe-best-match",
]
)

if extraUrl is not None:
cmd.extend([
"--extra-index-url",
extraUrl,
])
cmd.extend(
[
"--extra-index-url",
extraUrl,
]
)

if dry:
cmd.append("--dry-run")
Expand Down Expand Up @@ -262,46 +272,68 @@ def __init__(
outName: str = "requirements.compiled",
):
self.cwd = Path(cwd)
self.reqFiles = [Path(reqFile) for reqFile in reqFilesExt] if reqFilesExt is not None else None
self.reqFiles = (
[Path(reqFile) for reqFile in reqFilesExt]
if reqFilesExt is not None
else None
)
self.gpu = DependencyCompiler.Resolve_Gpu(gpu)

self.gpuUrl = DependencyCompiler.nvidiaPytorchUrl if self.gpu == GPU_OPTION.NVIDIA else DependencyCompiler.rocmPytorchUrl if self.gpu == GPU_OPTION.AMD else None
self.gpuUrl = (
DependencyCompiler.nvidiaPytorchUrl
if self.gpu == GPU_OPTION.NVIDIA
else DependencyCompiler.rocmPytorchUrl
if self.gpu == GPU_OPTION.AMD
else None
)
self.out = self.cwd / outName
self.override = self.cwd / "override.txt"

self.reqFilesCore = reqFilesCore if reqFilesCore is not None else self.find_core_reqs()
self.reqFilesExt = reqFilesExt if reqFilesExt is not None else self.find_ext_reqs()
self.reqFilesCore = (
reqFilesCore if reqFilesCore is not None else self.find_core_reqs()
)
self.reqFilesExt = (
reqFilesExt if reqFilesExt is not None else self.find_ext_reqs()
)

def find_core_reqs(self):
return DependencyCompiler.Find_Req_Files(self.cwd)

def find_ext_reqs(self):
extDirs = [d for d in (self.cwd / "custom_nodes").iterdir() if d.is_dir() and d.name != "__pycache__"]
extDirs = [
d
for d in (self.cwd / "custom_nodes").iterdir()
if d.is_dir() and d.name != "__pycache__"
]
return DependencyCompiler.Find_Req_Files(*extDirs)

def make_override(self):
#clean up
# clean up
self.override.unlink(missing_ok=True)

with open(self.override, "w") as f:
if self.gpu is not None:
f.write(DependencyCompiler.overrideGpu.format(gpu=self.gpu, gpuUrl=self.gpuUrl))
if self.gpu is not None and self.gpuUrl is not None:
f.write(
DependencyCompiler.overrideGpu.format(
gpu=self.gpu, gpuUrl=self.gpuUrl
)
)
f.write("\n\n")

completed = DependencyCompiler.Compile(
cwd=self.cwd,
reqFiles=self.reqFilesCore,
override=self.override
cwd=self.cwd, reqFiles=self.reqFilesCore, override=self.override
)

with open(self.override, "a") as f:
f.write("# ensure that core comfyui deps take precedence over any 3rd party extension deps\n")
f.write(
"# ensure that core comfyui deps take precedence over any 3rd party extension deps\n"
)
for line in completed.stdout:
f.write(line)
f.write("\n")

def compile_core_plus_ext(self):
#clean up
# clean up
self.out.unlink(missing_ok=True)

while True:
Expand Down
Loading