Skip to content

Commit

Permalink
Merge pull request #120 from OPHoperHPO/development
Browse files Browse the repository at this point in the history
Downloaders logic fixes
  • Loading branch information
OPHoperHPO authored Jan 9, 2023
2 parents d274ff5 + 3f7529c commit 2935e46
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions carvekit/utils/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ def sha512_checksum_calc(file: Path) -> str:
class CachedDownloader:
__metaclass__ = ABCMeta

@property
@abstractmethod
def name(self) -> str:
return self.__class__.__name__

@property
@abstractmethod
def fallback_downloader(self) -> Optional["CachedDownloader"]:
Expand All @@ -99,13 +104,13 @@ def download_model(self, file_name: str) -> Path:
except BaseException as e:
if self.fallback_downloader is not None:
warnings.warn(
f"Failed to download model from {self.__class__.__name__} downloader."
f" Trying to download from {self.fallback_downloader.__class__.__name__} downloader."
f"Failed to download model from {self.name} downloader."
f" Trying to download from {self.fallback_downloader.name} downloader."
)
return self.fallback_downloader.download_model(file_name)
else:
warnings.warn(
f"Failed to download model from {self.__class__.__name__} downloader."
f"Failed to download model from {self.name} downloader."
f" No fallback downloader available."
)
raise e
Expand All @@ -121,17 +126,23 @@ def __call__(self, file_name: str):
class HuggingFaceCompatibleDownloader(CachedDownloader, ABC):
def __init__(
self,
name: str = "Huggingface.co",
base_url: str = "https://huggingface.co",
fb_downloader: Optional["CachedDownloader"] = None,
):
self.cache_dir = checkpoints_dir
self.base_url = base_url
self._name = name
self._fallback_downloader = fb_downloader

@property
def fallback_downloader(self) -> Optional["CachedDownloader"]:
return self._fallback_downloader

@property
def name(self):
return self._name

def check_for_existence(self, file_name: str) -> Optional[Path]:
if file_name not in MODELS_URLS.keys():
raise FileNotFoundError("Unknown model!")
Expand Down Expand Up @@ -167,7 +178,7 @@ def download_model_base(self, file_name: str) -> Path:
hugging_face_url = f"{self.base_url}/{url['repository']}/resolve/{url['revision']}/{url['filename']}"

try:
r = requests.get(hugging_face_url, stream=True)
r = requests.get(hugging_face_url, stream=True, timeout=10)
if r.status_code < 400:
with open(cached_path, "wb") as f:
r.raw.decode_content = True
Expand All @@ -194,8 +205,10 @@ def download_model_base(self, file_name: str) -> Path:
return cached_path


fallback_downloader: CachedDownloader = HuggingFaceCompatibleDownloader()
downloader: CachedDownloader = HuggingFaceCompatibleDownloader(
base_url="https://cdn.carve.photos"
base_url="https://cdn.carve.photos",
fb_downloader=fallback_downloader,
name="Carve CDN",
)
fallback_downloader: CachedDownloader = HuggingFaceCompatibleDownloader()
downloader._fallback_downloader = fallback_downloader

0 comments on commit 2935e46

Please sign in to comment.