Skip to content

Commit

Permalink
Encode and decode execution provider in the UI
Browse files Browse the repository at this point in the history
  • Loading branch information
henryruhs committed Aug 15, 2023
1 parent 9efbc1a commit 642ffe2
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 13 deletions.
11 changes: 1 addition & 10 deletions roop/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import roop.metadata
from roop.predictor import predict_image, predict_video
from roop.processors.frame.core import get_frame_processors_modules
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path, list_module_names
from roop.utilities import has_image_extension, is_image, is_video, detect_fps, create_video, extract_frames, get_temp_frame_paths, restore_audio, create_temp, move_temp, clean_temp, normalize_output_path, list_module_names, decode_execution_providers, encode_execution_providers

warnings.filterwarnings('ignore', category=FutureWarning, module='insightface')
warnings.filterwarnings('ignore', category=UserWarning, module='torchvision')
Expand Down Expand Up @@ -86,15 +86,6 @@ def parse_args() -> None:
roop.globals.execution_queue_count = args.execution_queue_count


def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]


def decode_execution_providers(execution_providers: List[str]) -> List[str]:
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]


def suggest_execution_providers() -> List[str]:
return encode_execution_providers(onnxruntime.get_available_providers())

Expand Down
7 changes: 4 additions & 3 deletions roop/uis/__components__/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from roop.face_analyser import clear_face_analyser
from roop.processors.frame.core import clear_frame_processors_modules
from roop.uis.typing import Update
from roop.utilities import encode_execution_providers, decode_execution_providers

EXECUTION_PROVIDERS_CHECKBOX_GROUP: Optional[gradio.CheckboxGroup] = None
EXECUTION_THREAD_COUNT_SLIDER: Optional[gradio.Slider] = None
Expand All @@ -20,8 +21,8 @@ def render() -> None:
with gradio.Box():
EXECUTION_PROVIDERS_CHECKBOX_GROUP = gradio.CheckboxGroup(
label='EXECUTION PROVIDERS',
choices=onnxruntime.get_available_providers(),
value=roop.globals.execution_providers
choices=encode_execution_providers(onnxruntime.get_available_providers()),
value=encode_execution_providers(roop.globals.execution_providers)
)
EXECUTION_THREAD_COUNT_SLIDER = gradio.Slider(
label='EXECUTION THREAD COUNT',
Expand All @@ -48,7 +49,7 @@ def listen() -> None:
def update_execution_providers(execution_providers: List[str]) -> Update:
clear_face_analyser()
clear_frame_processors_modules()
roop.globals.execution_providers = execution_providers
roop.globals.execution_providers = decode_execution_providers(execution_providers)
return gradio.update(value=execution_providers)


Expand Down
11 changes: 11 additions & 0 deletions roop/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import urllib
from pathlib import Path
from typing import List, Optional

import onnxruntime
from tqdm import tqdm

import roop.globals
Expand Down Expand Up @@ -168,3 +170,12 @@ def list_module_names(path: str) -> Optional[List[str]]:
files = os.listdir(path)
return [Path(file).stem for file in files if not Path(file).stem.startswith('__')]
return None


def encode_execution_providers(execution_providers: List[str]) -> List[str]:
return [execution_provider.replace('ExecutionProvider', '').lower() for execution_provider in execution_providers]


def decode_execution_providers(execution_providers: List[str]) -> List[str]:
return [provider for provider, encoded_execution_provider in zip(onnxruntime.get_available_providers(), encode_execution_providers(onnxruntime.get_available_providers()))
if any(execution_provider in encoded_execution_provider for execution_provider in execution_providers)]

0 comments on commit 642ffe2

Please sign in to comment.