Commit

update balanced offload
Signed-off-by: Vladimir Mandic <mandic00@live.com>
vladmandic committed Dec 11, 2024
1 parent f4847f1 commit 9a588d9
Showing 5 changed files with 70 additions and 48 deletions.
7 changes: 3 additions & 4 deletions CHANGELOG.md
@@ -50,10 +50,9 @@
- **Memory** improvements:
- faster and more compatible *balanced offload* mode
- balanced offload: units are now in percentage instead of bytes
- balanced offload: add both high and low watermark and pinned threshold, defaults as below
25% for low-watermark: skip offload if memory usage is below 25%
70% high-watermark: must offload if memory usage is above 70%
15% pin-watermark: any model component smaller than 15% of total memory is pinned and not offloaded
- balanced offload: add both high and low watermark, defaults as below
`0.25` for low-watermark: skip offload if memory usage is below 25%
`0.70` high-watermark: must offload if memory usage is above 70%
- change-in-behavior:
low-end systems, triggered either by `lowvram` or by detection of <=4GB VRAM, will use *sequential offload*
all other systems use *balanced offload* by default (can be changed in settings)
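
For orientation, the two watermarks behave roughly as in the sketch below. This is a simplified illustration of the semantics described above, not the project's code: the function name and signature are invented for the example, while the `0.25`/`0.70` defaults correspond to the new `diffusers_offload_min_gpu_memory` and `diffusers_offload_max_gpu_memory` settings further down in this commit.

```python
# Simplified sketch of the low/high watermark semantics (illustrative only).
def offload_decision(used_gpu_gb: float, total_gpu_gb: float,
                     low: float = 0.25, high: float = 0.70) -> str:
    usage = used_gpu_gb / total_gpu_gb     # fraction of VRAM currently in use
    if usage < low:
        return 'skip'                      # below low watermark: no offload needed
    if usage > high:
        return 'offload'                   # above high watermark: must offload
    return 'offload-if-needed'             # between watermarks: implementation decides per component

print(offload_decision(3.0, 24.0))   # 12.5% used -> 'skip'
print(offload_decision(20.0, 24.0))  # ~83% used -> 'offload'
```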
54 changes: 30 additions & 24 deletions modules/devices.py
@@ -187,53 +187,59 @@ def get_device_for(task): # pylint: disable=unused-argument


def torch_gc(force=False, fast=False):
def get_stats():
mem_dict = memstats.memory_stats()
gpu_dict = mem_dict.get('gpu', {})
ram_dict = mem_dict.get('ram', {})
oom = gpu_dict.get('oom', 0)
ram = ram_dict.get('used', 0)
if backend == "directml":
gpu = torch.cuda.memory_allocated() / (1 << 30)
else:
gpu = gpu_dict.get('used', 0)
used_gpu = round(100 * gpu / gpu_dict.get('total', 1)) if gpu_dict.get('total', 1) > 1 else 0
used_ram = round(100 * ram / ram_dict.get('total', 1)) if ram_dict.get('total', 1) > 1 else 0
return gpu, used_gpu, ram, used_ram, oom

global previous_oom # pylint: disable=global-statement
import gc
from modules import timer, memstats
from modules.shared import cmd_opts

t0 = time.time()
mem = memstats.memory_stats()
gpu = mem.get('gpu', {})
ram = mem.get('ram', {})
oom = gpu.get('oom', 0)
if backend == "directml":
used_gpu = round(100 * torch.cuda.memory_allocated() / (1 << 30) / gpu.get('total', 1)) if gpu.get('total', 1) > 1 else 0
else:
used_gpu = round(100 * gpu.get('used', 0) / gpu.get('total', 1)) if gpu.get('total', 1) > 1 else 0
used_ram = round(100 * ram.get('used', 0) / ram.get('total', 1)) if ram.get('total', 1) > 1 else 0
global previous_oom # pylint: disable=global-statement
gpu, used_gpu, ram, used_ram, oom = get_stats()
threshold = 0 if (cmd_opts.lowvram and not cmd_opts.use_zluda) else opts.torch_gc_threshold
collected = 0
if force or threshold == 0 or used_gpu >= threshold or used_ram >= threshold:
force = True
if oom > previous_oom:
previous_oom = oom
log.warning(f'Torch GPU out-of-memory error: {mem}')
log.warning(f'Torch GPU out-of-memory error: {memstats.memory_stats()}')
force = True
if force:
# actual gc
collected = gc.collect() if not fast else 0 # python gc
if cuda_ok:
try:
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.synchronize()
torch.cuda.empty_cache() # cuda gc
torch.cuda.ipc_collect()
except Exception:
pass
else:
return gpu, ram
t1 = time.time()
if 'gc' not in timer.process.records:
timer.process.records['gc'] = 0
timer.process.records['gc'] += t1 - t0
if not force or collected == 0:
return used_gpu, used_ram
mem = memstats.memory_stats()
saved = round(gpu.get('used', 0) - mem.get('gpu', {}).get('used', 0), 2)
before = { 'gpu': gpu.get('used', 0), 'ram': ram.get('used', 0) }
after = { 'gpu': mem.get('gpu', {}).get('used', 0), 'ram': mem.get('ram', {}).get('used', 0), 'retries': mem.get('retries', 0), 'oom': mem.get('oom', 0) }
utilization = { 'gpu': used_gpu, 'ram': used_ram, 'threshold': threshold }
results = { 'collected': collected, 'saved': saved }
timer.process.add('gc', t1 - t0)

new_gpu, new_used_gpu, new_ram, new_used_ram, oom = get_stats()
before = { 'gpu': gpu, 'ram': ram }
after = { 'gpu': new_gpu, 'ram': new_ram, 'oom': oom }
utilization = { 'gpu': new_used_gpu, 'ram': new_used_ram, 'threshold': threshold }
results = { 'saved': round(gpu - new_gpu, 2), 'collected': collected }
fn = f'{sys._getframe(2).f_code.co_name}:{sys._getframe(1).f_code.co_name}' # pylint: disable=protected-access
log.debug(f'GC: utilization={utilization} gc={results} before={before} after={after} device={torch.device(get_optimal_device_name())} fn={fn} time={round(t1 - t0, 2)}') # pylint: disable=protected-access
return used_gpu, used_ram
log.debug(f'GC: utilization={utilization} gc={results} before={before} after={after} device={torch.device(get_optimal_device_name())} fn={fn} time={round(t1 - t0, 2)}')
return new_gpu, new_ram


def set_cuda_sync_mode(mode):
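Net effect of the `torch_gc` refactor above: memory numbers are gathered once through the nested `get_stats()` helper, and the function now returns absolute usage in GB rather than percentages, leaving callers to derive a utilization fraction themselves. A minimal caller-side sketch, mirroring how `apply_balanced_offload` in this commit consumes the values (the module names and option key come from the diff; everything else is illustrative):

```python
from modules import devices, shared  # SD.Next modules referenced in this commit

used_gpu_gb, used_ram_gb = devices.torch_gc(fast=True)   # returns GB now, not percentages
perc_gpu = used_gpu_gb / shared.gpu_memory                # convert to a 0..1 fraction
if perc_gpu > shared.opts.diffusers_offload_min_gpu_memory:
    # above the low watermark: balanced offload starts moving components to CPU
    pass
```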
52 changes: 35 additions & 17 deletions modules/sd_models.py
@@ -18,7 +18,7 @@
from ldm.util import instantiate_from_config
from modules import paths, shared, shared_state, modelloader, devices, script_callbacks, sd_vae, sd_unet, errors, sd_models_config, sd_models_compile, sd_hijack_accelerate, sd_detect
from modules.timer import Timer, process as process_timer
from modules.memstats import memory_stats, memory_cache
from modules.memstats import memory_stats
from modules.modeldata import model_data
from modules.sd_checkpoint import CheckpointInfo, select_checkpoint, list_models, checkpoints_list, checkpoint_titles, get_closet_checkpoint_match, model_hash, update_model_hashes, setup_model, write_metadata, read_metadata_from_safetensors # pylint: disable=unused-import

@@ -35,6 +35,8 @@
diffusers_version = int(diffusers.__version__.split('.')[1])
checkpoint_tiles = checkpoint_titles # legacy compatibility
should_offload = ['sc', 'sd3', 'f1', 'hunyuandit', 'auraflow', 'omnigen']
offload_hook_instance = None
offload_component_map = {}


class NoWatermark:
@@ -415,10 +417,6 @@ def detach_hook(self, module):
return module


offload_hook_instance = None
offload_component_map = {}


def apply_balanced_offload(sd_model, exclude=[]):
global offload_hook_instance # pylint: disable=global-statement
if shared.opts.diffusers_offload_mode != "balanced":
@@ -433,6 +431,29 @@ def apply_balanced_offload(sd_model, exclude=[]):
if checkpoint_name is None:
checkpoint_name = sd_model.__class__.__name__

def get_pipe_modules(pipe):
if hasattr(pipe, "_internal_dict"):
modules_names = pipe._internal_dict.keys() # pylint: disable=protected-access
else:
modules_names = get_signature(pipe).keys()
modules_names = [m for m in modules_names if m not in exclude and not m.startswith('_')]
modules = {}
for module_name in modules_names:
module_size = offload_component_map.get(module_name, None)
if module_size is None:
module = getattr(pipe, module_name, None)
if not isinstance(module, torch.nn.Module):
continue
try:
module_size = sum(p.numel()*p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
except Exception as e:
shared.log.error(f'Balanced offload: module={module_name} {e}')
module_size = 0
offload_component_map[module_name] = module_size
modules[module_name] = module_size
modules = sorted(modules.items(), key=lambda x: x[1], reverse=True)
return modules

def apply_balanced_offload_to_module(pipe):
used_gpu, used_ram = devices.torch_gc(fast=True)
if hasattr(pipe, "pipe"):
@@ -442,24 +463,20 @@ def apply_balanced_offload_to_module(pipe):
else:
keys = get_signature(pipe).keys()
keys = [k for k in keys if k not in exclude and not k.startswith('_')]
for module_name in keys: # pylint: disable=protected-access
for module_name, module_size in get_pipe_modules(pipe): # pylint: disable=protected-access
module = getattr(pipe, module_name, None)
if not isinstance(module, torch.nn.Module):
continue
network_layer_name = getattr(module, "network_layer_name", None)
device_map = getattr(module, "balanced_offload_device_map", None)
max_memory = getattr(module, "balanced_offload_max_memory", None)
module = accelerate.hooks.remove_hook_from_module(module, recurse=True)
module_size = offload_component_map.get(module_name, None)
if module_size is None:
module_size = sum(p.numel()*p.element_size() for p in module.parameters(recurse=True)) / 1024 / 1024 / 1024
offload_component_map[module_name] = module_size
do_offload = (used_gpu > 100 * shared.opts.diffusers_offload_min_gpu_memory) and (module_size > shared.gpu_memory * shared.opts.diffusers_offload_pin_gpu_memory)
perc_gpu = used_gpu / shared.gpu_memory
try:
debug_move(f'Balanced offload: gpu={used_gpu} ram={used_ram} current={module.device} dtype={module.dtype} op={"move" if do_offload else "skip"} component={module.__class__.__name__} size={module_size:.3f}')
if do_offload and module.device != devices.cpu:
module = module.to(devices.cpu)
used_gpu, used_ram = devices.torch_gc(fast=True, force=True)
prev_gpu = used_gpu
do_offload = (perc_gpu > shared.opts.diffusers_offload_min_gpu_memory) and (module.device != devices.cpu)
if do_offload:
module = module.to(devices.cpu, non_blocking=True)
used_gpu -= module_size
debug_move(f'Balanced offload: op={"move" if do_offload else "skip"} gpu={prev_gpu:.3f}:{used_gpu:.3f} perc={perc_gpu:.2f} ram={used_ram:.3f} current={module.device} dtype={module.dtype} component={module.__class__.__name__} size={module_size:.3f}')
except Exception as e:
if 'bitsandbytes' not in str(e):
shared.log.error(f'Balanced offload: module={module_name} {e}')
@@ -473,6 +490,7 @@ def apply_balanced_offload_to_module(pipe):
if device_map and max_memory:
module.balanced_offload_device_map = device_map
module.balanced_offload_max_memory = max_memory
devices.torch_gc(fast=True, force=True)

apply_balanced_offload_to_module(sd_model)
if hasattr(sd_model, "pipe"):
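To summarize the new strategy in `apply_balanced_offload`: component sizes are computed once from parameter counts, cached in `offload_component_map`, sorted largest-first by `get_pipe_modules`, and components are moved to CPU while estimated VRAM usage stays above the low watermark. Below is a condensed, self-contained sketch of that eviction order; the size dictionary and the numbers are made up for the example, and the real code operates on live `torch.nn.Module` objects rather than returning a plan.

```python
# Condensed illustration of the eviction order used by apply_balanced_offload.
def plan_offload(sizes_gb: dict, used_gpu_gb: float, total_gpu_gb: float,
                 low_watermark: float = 0.25) -> list:
    to_cpu = []
    # largest components first, matching the sort in get_pipe_modules()
    for name, size in sorted(sizes_gb.items(), key=lambda x: x[1], reverse=True):
        if used_gpu_gb / total_gpu_gb <= low_watermark:
            break                     # usage dropped under the low watermark: keep the rest on GPU
        to_cpu.append(name)           # the real code calls module.to(cpu, non_blocking=True)
        used_gpu_gb -= size           # assume the move frees roughly the component size
    return to_cpu

# hypothetical 24 GB card with 18 GB in use
print(plan_offload({'transformer': 11.9, 'text_encoder_2': 8.9, 'vae': 0.2}, 18.0, 24.0))
# -> ['transformer', 'text_encoder_2']
```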
3 changes: 1 addition & 2 deletions modules/shared.py
@@ -483,8 +483,7 @@ def get_default_modes():
"diffusers_offload_mode": OptionInfo(startup_offload_mode, "Model offload mode", gr.Radio, {"choices": ['none', 'balanced', 'model', 'sequential']}),
"diffusers_offload_min_gpu_memory": OptionInfo(0.25, "Balanced offload GPU low watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01 }),
"diffusers_offload_max_gpu_memory": OptionInfo(0.70, "Balanced offload GPU high watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01 }),
"diffusers_offload_pin_gpu_memory": OptionInfo(0.15, "Balanced offload GPU pin watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01 }),
"diffusers_offload_max_cpu_memory": OptionInfo(0.90, "Balanced offload CPU high watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01 }),
"diffusers_offload_max_cpu_memory": OptionInfo(0.90, "Balanced offload CPU high watermark", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01, "visible": False }),

"advanced_sep": OptionInfo("<h2>Advanced Options</h2>", "", gr.HTML),
"sd_checkpoint_autoload": OptionInfo(True, "Model autoload on start"),
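Both watermarks are stored as plain 0-1 fractions in `shared.opts` (the pin watermark is removed and the CPU high watermark slider is now hidden). A tiny, hypothetical sanity check of how the two values relate; nothing like this exists in the commit:

```python
def validate_watermarks(low: float, high: float) -> None:
    # the low watermark must sit below the high watermark; both are fractions of total VRAM
    if not 0.0 <= low < high <= 1.0:
        raise ValueError(f'invalid offload watermarks: low={low} high={high}')

validate_watermarks(0.25, 0.70)  # defaults from this commit
```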
2 changes: 1 addition & 1 deletion wiki
Submodule wiki updated from 95f174 to db8288
