diff --git a/examples/_base_.yaml b/examples/_base_.yaml
index b4d33cd6c..57110fb16 100644
--- a/examples/_base_.yaml
+++ b/examples/_base_.yaml
@@ -1,3 +1,6 @@
+log_report: true
+print_report: true
+
 # hydra/cli specific settings
 hydra:
   run:
diff --git a/examples/pytorch_bert.py b/examples/pytorch_bert.py
index db40fad99..09f62b8d5 100644
--- a/examples/pytorch_bert.py
+++ b/examples/pytorch_bert.py
@@ -1,23 +1,45 @@
+import os
+
+from huggingface_hub import whoami
+
 from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
 from optimum_benchmark.logging_utils import setup_logging
 
-setup_logging(level="INFO", prefix="MAIN-PROCESS")
+try:
+    USERNAME = whoami()["name"]
+except Exception as e:
+    print(f"Failed to get username from Hugging Face Hub: {e}")
+    USERNAME = None
 
-if __name__ == "__main__":
-    BENCHMARK_NAME = "pytorch_bert"
-    REPO_ID = f"IlyasMoutawwakil/{BENCHMARK_NAME}"
+BENCHMARK_NAME = "pytorch_bert"
+
 
+def run_benchmark():
     launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn")
     backend_config = PyTorchConfig(device="cuda", device_ids="0", no_weights=True, model="bert-base-uncased")
     scenario_config = InferenceConfig(memory=True, latency=True, input_shapes={"batch_size": 1, "sequence_length": 128})
-
     benchmark_config = BenchmarkConfig(
-        name=BENCHMARK_NAME, launcher=launcher_config, backend=backend_config, scenario=scenario_config
+        name=BENCHMARK_NAME,
+        launcher=launcher_config,
+        scenario=scenario_config,
+        backend=backend_config,
+        print_report=True,
+        log_report=True,
     )
-    # benchmark_config.push_to_hub(repo_id=REPO_ID)
-
     benchmark_report = Benchmark.launch(benchmark_config)
-    # benchmark_report.push_to_hub(repo_id=REPO_ID)
+
+    return benchmark_config, benchmark_report
+
+
+if __name__ == "__main__":
+    level = os.environ.get("LOG_LEVEL", "INFO")
+    to_file = os.environ.get("LOG_TO_FILE", "0") == "1"
+    setup_logging(level=level, to_file=to_file, prefix="MAIN-PROCESS")
+
+    benchmark_config, benchmark_report = run_benchmark()
     benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
-    # benchmark.push_to_hub(repo_id=REPO_ID)
+
+    if USERNAME is not None:
+        benchmark_config.push_to_hub(repo_id=f"{USERNAME}/benchmarks", subfolder=BENCHMARK_NAME)
+        benchmark_report.push_to_hub(repo_id=f"{USERNAME}/benchmarks", subfolder=BENCHMARK_NAME)
+        benchmark.push_to_hub(repo_id=f"{USERNAME}/benchmarks", subfolder=BENCHMARK_NAME)
diff --git a/examples/pytorch_llama.py b/examples/pytorch_llama.py
index 5ecf55731..90c099317 100644
--- a/examples/pytorch_llama.py
+++ b/examples/pytorch_llama.py
@@ -1,8 +1,16 @@
 import os
 
+from huggingface_hub import whoami
+
 from optimum_benchmark import Benchmark, BenchmarkConfig, InferenceConfig, ProcessConfig, PyTorchConfig
 from optimum_benchmark.logging_utils import setup_logging
 
+try:
+    USERNAME = whoami()["name"]
+except Exception as e:
+    print(f"Failed to get username from Hugging Face Hub: {e}")
+    USERNAME = None
+
 BENCHMARK_NAME = "pytorch-llama"
 
 WEIGHTS_CONFIGS = {
@@ -11,16 +19,16 @@
         "quantization_scheme": None,
         "quantization_config": {},
     },
-    # "4bit-awq-gemm": {
-    #     "torch_dtype": "float16",
-    #     "quantization_scheme": "awq",
-    #     "quantization_config": {"bits": 4, "version": "gemm"},
-    # },
-    # "4bit-gptq-exllama-v2": {
-    #     "torch_dtype": "float16",
-    #     "quantization_scheme": "gptq",
-    #     "quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256},
-    # },
+    "4bit-awq-gemm": {
+        "torch_dtype": "float16",
+        "quantization_scheme": "awq",
+ "quantization_config": {"bits": 4, "version": "gemm"}, + }, + "4bit-gptq-exllama-v2": { + "torch_dtype": "float16", + "quantization_scheme": "gptq", + "quantization_config": {"bits": 4, "use_exllama ": True, "version": 2, "model_seqlen": 256}, + }, } @@ -42,16 +50,17 @@ def run_benchmark(weight_config: str): input_shapes={"batch_size": 1, "sequence_length": 128}, generate_kwargs={"max_new_tokens": 32, "min_new_tokens": 32}, ) - benchmark_config = BenchmarkConfig( - name=BENCHMARK_NAME, launcher=launcher_config, scenario=scenario_config, backend=backend_config + name=BENCHMARK_NAME, + launcher=launcher_config, + scenario=scenario_config, + backend=backend_config, + print_report=True, + log_report=True, ) benchmark_report = Benchmark.launch(benchmark_config) - benchmark = Benchmark(config=benchmark_config, report=benchmark_report) - filename = f"{BENCHMARK_NAME}-{backend_config.version}-{weight_config}.json" - benchmark.push_to_hub(repo_id="optimum-benchmark/pytorch-llama", filename=filename) - benchmark.save_json(path=f"benchmarks/{filename}") + return benchmark_config, benchmark_report if __name__ == "__main__": @@ -60,4 +69,10 @@ def run_benchmark(weight_config: str): setup_logging(level=level, to_file=to_file, prefix="MAIN-PROCESS") for weight_config in WEIGHTS_CONFIGS: - run_benchmark(weight_config) + benchmark_config, benchmark_report = run_benchmark(weight_config) + benchmark = Benchmark(config=benchmark_config, report=benchmark_report) + + if USERNAME is not None: + benchmark.push_to_hub( + repo_id=f"{USERNAME}/benchmarks", filename=f"{weight_config}.json", subfolder=BENCHMARK_NAME + ) diff --git a/optimum_benchmark/backends/diffusers_utils.py b/optimum_benchmark/backends/diffusers_utils.py index fa40c36bf..43f0757bf 100644 --- a/optimum_benchmark/backends/diffusers_utils.py +++ b/optimum_benchmark/backends/diffusers_utils.py @@ -39,12 +39,18 @@ def get_diffusers_pretrained_config(model: str, **kwargs) -> Dict[str, int]: + if not is_diffusers_available(): + raise ImportError("diffusers is not available. Please, pip install diffusers.") + config = DiffusionPipeline.load_config(model, **kwargs) pipeline_config = config[0] if isinstance(config, tuple) else config return pipeline_config def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]: + if not is_diffusers_available(): + raise ImportError("diffusers is not available. Please, pip install diffusers.") + model_config = get_diffusers_pretrained_config(model, **kwargs) shapes = {} @@ -56,6 +62,14 @@ def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]: shapes["height"] = vae_config["sample_size"] shapes["width"] = vae_config["sample_size"] + elif "vae_decoder" in model_config: + vae_import_path = model_config["vae_decoder"] + vae_class = get_class(f"{vae_import_path[0]}.{vae_import_path[1]}") + vae_config = vae_class.load_config(model, subfolder="vae_decoder", **kwargs) + shapes["num_channels"] = vae_config["out_channels"] + shapes["height"] = vae_config["sample_size"] + shapes["width"] = vae_config["sample_size"] + elif "vae_encoder" in model_config: vae_import_path = model_config["vae_encoder"] vae_class = get_class(f"{vae_import_path[0]}.{vae_import_path[1]}") @@ -74,6 +88,9 @@ def extract_diffusers_shapes_from_model(model: str, **kwargs) -> Dict[str, int]: def get_diffusers_automodel_loader_for_task(task: str): + if not is_diffusers_available(): + raise ImportError("diffusers is not available. 
Please, pip install diffusers.") + model_loader_name = TASKS_TO_MODEL_LOADERS[task] model_loader_class = getattr(diffusers, model_loader_name) return model_loader_class diff --git a/optimum_benchmark/backends/onnxruntime/backend.py b/optimum_benchmark/backends/onnxruntime/backend.py index 8fb69254f..223da6dc7 100644 --- a/optimum_benchmark/backends/onnxruntime/backend.py +++ b/optimum_benchmark/backends/onnxruntime/backend.py @@ -297,11 +297,9 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs - if self.config.library == "transformers": - for key, value in list(inputs.items()): - if key in ["position_ids", "token_type_ids"]: - if key not in self.pretrained_model.input_names: - inputs.pop(key) + for key in list(inputs.keys()): + if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: + inputs.pop(key) for key, value in inputs.items(): if isinstance(value, torch.Tensor): diff --git a/optimum_benchmark/backends/openvino/backend.py b/optimum_benchmark/backends/openvino/backend.py index cd2a57afe..9db49fb2a 100644 --- a/optimum_benchmark/backends/openvino/backend.py +++ b/optimum_benchmark/backends/openvino/backend.py @@ -201,6 +201,10 @@ def prepare_inputs(self, inputs: Dict[str, Any]) -> Dict[str, Any]: with Accelerator().split_between_processes(inputs=inputs, apply_padding=False) as process_inputs: inputs = process_inputs + for key in list(inputs.keys()): + if hasattr(self.pretrained_model, "input_names") and key not in self.pretrained_model.input_names: + inputs.pop(key) + return inputs def forward(self, inputs: Dict[str, Any], kwargs: Dict[str, Any]) -> OrderedDict: diff --git a/optimum_benchmark/backends/peft_utils.py b/optimum_benchmark/backends/peft_utils.py index 037d4e872..92e71039a 100644 --- a/optimum_benchmark/backends/peft_utils.py +++ b/optimum_benchmark/backends/peft_utils.py @@ -9,5 +9,8 @@ def apply_peft(model: PreTrainedModel, peft_type: str, peft_config: Dict[str, Any]) -> PreTrainedModel: + if not is_peft_available(): + raise ImportError("peft is not available. Please, pip install peft.") + peft_config = PEFT_TYPE_TO_CONFIG_MAPPING[peft_type](**peft_config) return get_peft_model(model=model, peft_config=peft_config) diff --git a/optimum_benchmark/backends/timm_utils.py b/optimum_benchmark/backends/timm_utils.py index 77ed30007..941e09917 100644 --- a/optimum_benchmark/backends/timm_utils.py +++ b/optimum_benchmark/backends/timm_utils.py @@ -10,6 +10,9 @@ def get_timm_pretrained_config(model_name: str) -> PretrainedConfig: + if not is_timm_available(): + raise ImportError("timm is not available. Please, pip install timm.") + model_source, model_name = parse_model_name(model_name) if model_source == "hf-hub": # For model names specified in the form `hf-hub:path/architecture_name@revision`, @@ -21,6 +24,9 @@ def get_timm_pretrained_config(model_name: str) -> PretrainedConfig: def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]: + if not is_timm_available(): + raise ImportError("timm is not available. Please, pip install timm.") + artifacts_dict = {} config_dict = {k: v for k, v in config.to_dict().items() if v is not None} @@ -74,4 +80,7 @@ def extract_timm_shapes_from_config(config: PretrainedConfig) -> Dict[str, Any]: def get_timm_automodel_loader(): + if not is_timm_available(): + raise ImportError("timm is not available. 
Please, pip install timm.") + return create_model diff --git a/optimum_benchmark/benchmark/base.py b/optimum_benchmark/benchmark/base.py index 08f81cecd..8ae7f34dc 100644 --- a/optimum_benchmark/benchmark/base.py +++ b/optimum_benchmark/benchmark/base.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from logging import getLogger from typing import TYPE_CHECKING, Type from hydra.utils import get_class @@ -16,6 +17,9 @@ from ..scenarios.base import Scenario +LOGGER = getLogger("benchmark") + + @dataclass class Benchmark(PushToHubMixin): config: BenchmarkConfig @@ -32,8 +36,8 @@ def __post_init__(self): elif not isinstance(self.report, BenchmarkReport): raise ValueError("report must be either a dict or a BenchmarkReport instance") - @classmethod - def launch(cls, config: BenchmarkConfig): + @staticmethod + def launch(config: BenchmarkConfig): """ Runs an benchmark using specified launcher configuration/logic """ @@ -44,12 +48,18 @@ def launch(cls, config: BenchmarkConfig): launcher: Launcher = launcher_factory(launcher_config) # Launch the benchmark using the launcher - report = launcher.launch(worker=cls.run, worker_args=[config]) + report = launcher.launch(worker=Benchmark.run, worker_args=[config]) + + if config.log_report: + report.log() + + if config.print_report: + report.print() return report - @classmethod - def run(cls, config: BenchmarkConfig): + @staticmethod + def run(config: BenchmarkConfig): """ Runs a scenario using specified backend configuration/logic """ diff --git a/optimum_benchmark/benchmark/config.py b/optimum_benchmark/benchmark/config.py index 4edde0b00..f27021b48 100644 --- a/optimum_benchmark/benchmark/config.py +++ b/optimum_benchmark/benchmark/config.py @@ -20,6 +20,9 @@ class BenchmarkConfig(PushToHubMixin): # ENVIRONMENT CONFIGURATION environment: Dict[str, Any] = field(default_factory=lambda: {**get_system_info(), **get_hf_libs_info()}) + print_report: bool = False + log_report: bool = True + @classproperty def default_filename(cls) -> str: return "benchmark_config.json" diff --git a/optimum_benchmark/benchmark/report.py b/optimum_benchmark/benchmark/report.py index 812f4abbb..c4b0602db 100644 --- a/optimum_benchmark/benchmark/report.py +++ b/optimum_benchmark/benchmark/report.py @@ -1,14 +1,21 @@ from dataclasses import dataclass, make_dataclass +from logging import getLogger from typing import Any, Dict, List, Optional +from rich.console import Console +from rich.markdown import Markdown + from ..hub_utils import PushToHubMixin, classproperty from ..trackers.energy import Efficiency, Energy from ..trackers.latency import Latency, Throughput from ..trackers.memory import Memory +CONSOLE = Console() +LOGGER = getLogger("report") + @dataclass -class BenchmarkMeasurements: +class TargetMeasurements: memory: Optional[Memory] = None latency: Optional[Latency] = None throughput: Optional[Throughput] = None @@ -28,7 +35,7 @@ def __post_init__(self): self.efficiency = Efficiency(**self.efficiency) @staticmethod - def aggregate(measurements: List["BenchmarkMeasurements"]) -> "BenchmarkMeasurements": + def aggregate(measurements: List["TargetMeasurements"]) -> "TargetMeasurements": assert len(measurements) > 0, "No measurements to aggregate" m0 = measurements[0] @@ -39,7 +46,39 @@ def aggregate(measurements: List["BenchmarkMeasurements"]) -> "BenchmarkMeasurem energy = Energy.aggregate([m.energy for m in measurements]) if m0.energy is not None else None efficiency = Efficiency.aggregate([m.efficiency for m in measurements]) if m0.efficiency is not None else None - 
-        return BenchmarkMeasurements(memory, latency, throughput, energy, efficiency)
+        return TargetMeasurements(
+            memory=memory, latency=latency, throughput=throughput, energy=energy, efficiency=efficiency
+        )
+
+    def to_plain_text(self) -> str:
+        plain_text = ""
+
+        for key in ["memory", "latency", "throughput", "energy", "efficiency"]:
+            measurement = getattr(self, key)
+            if measurement is not None:
+                plain_text += f"\t+ {key}:\n"
+                plain_text += measurement.to_plain_text()
+
+        return plain_text
+
+    def log(self):
+        for line in self.to_plain_text().split("\n"):
+            if line:
+                LOGGER.info(line)
+
+    def to_markdown_text(self) -> str:
+        markdown_text = ""
+
+        for key in ["memory", "latency", "throughput", "energy", "efficiency"]:
+            measurement = getattr(self, key)
+            if measurement is not None:
+                markdown_text += f"## {key}:\n\n"
+                markdown_text += measurement.to_markdown_text()
+
+        return markdown_text
+
+    def print(self):
+        CONSOLE.print(Markdown(self.to_markdown_text()))
 
 
 @dataclass
@@ -55,62 +94,52 @@ def from_dict(cls, data: Dict[str, Any]) -> "BenchmarkReport":
 
     def __post_init__(self):
         for target in self.to_dict().keys():
             if getattr(self, target) is None:
-                setattr(self, target, BenchmarkMeasurements())
+                setattr(self, target, TargetMeasurements())
             elif isinstance(getattr(self, target), dict):
-                setattr(self, target, BenchmarkMeasurements(**getattr(self, target)))
+                setattr(self, target, TargetMeasurements(**getattr(self, target)))
 
-    def log_memory(self):
-        for target in self.to_dict().keys():
-            measurements: BenchmarkMeasurements = getattr(self, target)
-            if measurements.memory is not None:
-                measurements.memory.log(prefix=target)
+    @classmethod
+    def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport":
+        aggregated_measurements = {}
+        for target in reports[0].to_dict().keys():
+            measurements = [getattr(report, target) for report in reports]
+            aggregated_measurements[target] = TargetMeasurements.aggregate(measurements)
 
-    def log_latency(self):
-        for target in self.to_dict().keys():
-            measurements: BenchmarkMeasurements = getattr(self, target)
-            if measurements.latency is not None:
-                measurements.latency.log(prefix=target)
+        return cls.from_dict(aggregated_measurements)
 
-    def log_throughput(self):
-        for target in self.to_dict().keys():
-            measurements: BenchmarkMeasurements = getattr(self, target)
-            if measurements.throughput is not None:
-                measurements.throughput.log(prefix=target)
+    def to_plain_text(self) -> str:
+        plain_text = ""
 
-    def log_energy(self):
         for target in self.to_dict().keys():
-            measurements: BenchmarkMeasurements = getattr(self, target)
-            if measurements.energy is not None:
-                measurements.energy.log(prefix=target)
+            plain_text += f"+ {target}:\n"
+            plain_text += getattr(self, target).to_plain_text()
 
-    def log_efficiency(self):
-        for target in self.to_dict().keys():
-            measurements: BenchmarkMeasurements = getattr(self, target)
-            if measurements.efficiency is not None:
-                measurements.efficiency.log(prefix=target)
+        return plain_text
+
+    def to_markdown_text(self) -> str:
+        markdown_text = ""
 
-    def log(self):
         for target in self.to_dict().keys():
-            measurements: BenchmarkMeasurements = getattr(self, target)
-            if measurements.memory is not None:
-                measurements.memory.log(prefix=target)
-            if measurements.latency is not None:
-                measurements.latency.log(prefix=target)
-            if measurements.throughput is not None:
-                measurements.throughput.log(prefix=target)
-            if measurements.energy is not None:
-                measurements.energy.log(prefix=target)
-            if measurements.efficiency is not None:
-                measurements.efficiency.log(prefix=target)
+            markdown_text += f"# {target}:\n\n"
+            markdown_text += getattr(self, target).to_markdown_text()
 
-    @classmethod
-    def aggregate(cls, reports: List["BenchmarkReport"]) -> "BenchmarkReport":
-        aggregated_measurements = {}
-        for target in reports[0].to_dict().keys():
-            measurements = [getattr(report, target) for report in reports]
-            aggregated_measurements[target] = BenchmarkMeasurements.aggregate(measurements)
+        return markdown_text
 
-        return cls.from_dict(aggregated_measurements)
+    def save_text(self, filename: str):
+        with open(filename, mode="w") as f:
+            f.write(self.to_plain_text())
+
+    def save_markdown(self, filename: str):
+        with open(filename, mode="w") as f:
+            f.write(self.to_markdown_text())
+
+    def log(self):
+        for line in self.to_plain_text().split("\n"):
+            if line:
+                LOGGER.info(line)
+
+    def print(self):
+        CONSOLE.print(Markdown(self.to_markdown_text()))
 
     @classproperty
     def default_filename(self) -> str:
diff --git a/optimum_benchmark/cli.py b/optimum_benchmark/cli.py
index 29abc0747..4b26266b3 100644
--- a/optimum_benchmark/cli.py
+++ b/optimum_benchmark/cli.py
@@ -78,7 +78,9 @@ def main(config: DictConfig) -> None:
     benchmark_config.save_json("benchmark_config.json")
 
     benchmark_report = Benchmark.launch(benchmark_config)
+    benchmark_report.save_markdown("benchmark_report.md")
     benchmark_report.save_json("benchmark_report.json")
+    benchmark_report.save_text("benchmark_report.txt")
 
     benchmark = Benchmark(config=benchmark_config, report=benchmark_report)
     benchmark.save_json("benchmark.json")
diff --git a/optimum_benchmark/launchers/inline/launcher.py b/optimum_benchmark/launchers/inline/launcher.py
index 05b0448c0..eb6c788fe 100644
--- a/optimum_benchmark/launchers/inline/launcher.py
+++ b/optimum_benchmark/launchers/inline/launcher.py
@@ -13,5 +13,7 @@ def __init__(self, config: InlineConfig):
 
     def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any]) -> BenchmarkReport:
         self.logger.warning("The inline launcher is only recommended for debugging purposes and not for benchmarking")
+
         report = worker(*worker_args)
+
         return report
diff --git a/optimum_benchmark/launchers/process/launcher.py b/optimum_benchmark/launchers/process/launcher.py
index 7e702a83c..8375586fb 100644
--- a/optimum_benchmark/launchers/process/launcher.py
+++ b/optimum_benchmark/launchers/process/launcher.py
@@ -70,7 +70,6 @@ def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any])
         elif "report" in response:
             self.logger.info("\t+ Received report from isolated process")
             report = BenchmarkReport.from_dict(response["report"])
-            report.log()
         else:
             raise RuntimeError(f"Received an unexpected response from isolated process: {response}")
diff --git a/optimum_benchmark/launchers/torchrun/launcher.py b/optimum_benchmark/launchers/torchrun/launcher.py
index 2c51c521e..98c076ee0 100644
--- a/optimum_benchmark/launchers/torchrun/launcher.py
+++ b/optimum_benchmark/launchers/torchrun/launcher.py
@@ -102,8 +102,6 @@ def launch(self, worker: Callable[..., BenchmarkReport], worker_args: List[Any])
         self.logger.info("\t+ Aggregating reports from all rank processes")
         report = BenchmarkReport.aggregate(reports)
 
-        report.log()
-
         return report
diff --git a/optimum_benchmark/scenarios/inference/scenario.py b/optimum_benchmark/scenarios/inference/scenario.py
index c0d9475e8..f2f18e0b1 100644
--- a/optimum_benchmark/scenarios/inference/scenario.py
+++ b/optimum_benchmark/scenarios/inference/scenario.py
@@ -13,7 +13,7 @@
 from ..base import Scenario
 from .config import InferenceConfig
 
-PER_TOKEN_BACKENDS = ["pytorch", "onnxruntime", "openvino", "neural-compressor"]
+PER_TOKEN_BACKENDS = ["pytorch", "onnxruntime", "openvino", "neural-compressor", "ipex"]
 
 TEXT_GENERATION_DEFAULT_KWARGS = {
     "num_return_sequences": 1,
@@ -91,24 +91,15 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
         self.logger.info("\t+ Preparing inputs for Inference")
         self.inputs = backend.prepare_inputs(inputs=self.inputs)
 
-        if self.config.memory:
-            if backend.config.task in TEXT_GENERATION_TASKS:
-                self.run_text_generation_memory_tracking(backend)
-            elif backend.config.task in IMAGE_DIFFUSION_TASKS:
-                self.run_image_diffusion_memory_tracking(backend)
-            else:
-                self.run_inference_memory_tracking(backend)
-
-            self.report.log_memory()
-
         if self.config.latency or self.config.energy:  # latency and energy are metrics that require some warmup
-            if backend.config.task in TEXT_GENERATION_TASKS:
-                self.warmup_text_generation(backend)
-            elif backend.config.task in IMAGE_DIFFUSION_TASKS:
-                self.warmup_image_diffusion(backend)
-            else:
-                self.warmup_inference(backend)
+            if self.config.warmup_runs > 0:
+                if backend.config.task in TEXT_GENERATION_TASKS:
+                    self.warmup_text_generation(backend)
+                elif backend.config.task in IMAGE_DIFFUSION_TASKS:
+                    self.warmup_image_diffusion(backend)
+                else:
+                    self.warmup_inference(backend)
 
         if self.config.latency:
             if backend.config.task in TEXT_GENERATION_TASKS:
@@ -121,8 +112,13 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
             else:
                 self.run_latency_inference_tracking(backend)
 
-        self.report.log_latency()
-        self.report.log_throughput()
+        if self.config.memory:
+            if backend.config.task in TEXT_GENERATION_TASKS:
+                self.run_text_generation_memory_tracking(backend)
+            elif backend.config.task in IMAGE_DIFFUSION_TASKS:
+                self.run_image_diffusion_memory_tracking(backend)
+            else:
+                self.run_inference_memory_tracking(backend)
 
         if self.config.energy:
             if backend.config.task in TEXT_GENERATION_TASKS:
@@ -132,11 +128,9 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
             else:
                 self.run_inference_energy_tracking(backend)
 
-            self.report.log_energy()
-            self.report.log_efficiency()
-
         return self.report
 
+    # Warmup
     def warmup_text_generation(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Warming up backend for Text Generation")
         _ = backend.generate(self.inputs, self.config.generate_kwargs)
@@ -158,35 +152,25 @@ def warmup_inference(self, backend: Backend[BackendConfigT]):
 
     def run_model_loading_tracking(self, backend: Backend[BackendConfigT]):
         self.logger.info("\t+ Running model loading tracking")
 
-        if self.config.latency:
-            latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
         if self.config.memory:
             memory_tracker = MemoryTracker(
                 backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
             )
-        if self.config.energy:
-            energy_tracker = EnergyTracker(
-                backend=backend.config.name, device=backend.config.device, device_ids=backend.config.device_ids
-            )
-
-        context_stack = ExitStack()
         if self.config.latency:
-            context_stack.enter_context(latency_tracker.track())
-        if self.config.memory:
-            context_stack.enter_context(memory_tracker.track())
-        if self.config.energy:
-            context_stack.enter_context(energy_tracker.track())
+            latency_tracker = LatencyTracker(backend=backend.config.name, device=backend.config.device)
+
+        with ExitStack() as context_stack:
+            if self.config.memory:
+                context_stack.enter_context(memory_tracker.track())
+            if self.config.latency:
+                context_stack.enter_context(latency_tracker.track())
 
-        with context_stack:
-            self.logger.info("\t+ Loading model for Inference")
             backend.load()
 
         if self.config.latency:
             self.report.load.latency = latency_tracker.get_latency()
         if self.config.memory:
             self.report.load.memory = memory_tracker.get_max_memory()
-        if self.config.energy:
-            self.report.load.energy = energy_tracker.get_energy()
 
     ## Memory tracking
     def run_text_generation_memory_tracking(self, backend: Backend[BackendConfigT]):
diff --git a/optimum_benchmark/scenarios/training/scenario.py b/optimum_benchmark/scenarios/training/scenario.py
index 34fdc1b8e..0b80135cc 100644
--- a/optimum_benchmark/scenarios/training/scenario.py
+++ b/optimum_benchmark/scenarios/training/scenario.py
@@ -38,35 +38,21 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
 
         training_callbackes = []
 
-        if self.config.latency:
-            self.logger.info("\t+ Creating latency tracking callback")
-            latency_callback = StepLatencyTrainerCallback(device=backend.config.device, backend=backend.config.name)
-            self.logger.info("\t+ Adding latency measuring callback")
-            training_callbackes.append(latency_callback)
-
-        context_stack = ExitStack()
-
-        if self.config.memory:
-            self.logger.info("\t+ Creating memory tracking context manager")
-            memory_tracker = MemoryTracker(
-                device=backend.config.device, backend=backend.config.name, device_ids=backend.config.device_ids
-            )
-
-        if self.config.energy:
-            self.logger.info("\t+ Creating energy tracking context manager")
-            energy_tracker = EnergyTracker(
-                device=backend.config.device, backend=backend.config.name, device_ids=backend.config.device_ids
-            )
+        with ExitStack() as context_stack:
+            if self.config.latency:
+                latency_callback = StepLatencyTrainerCallback(device=backend.config.device, backend=backend.config.name)
+                training_callbackes.append(latency_callback)
+            if self.config.memory:
+                memory_tracker = MemoryTracker(
+                    device=backend.config.device, backend=backend.config.name, device_ids=backend.config.device_ids
+                )
+                context_stack.enter_context(memory_tracker.track())
+            if self.config.energy:
+                energy_tracker = EnergyTracker(
+                    device=backend.config.device, backend=backend.config.name, device_ids=backend.config.device_ids
+                )
+                context_stack.enter_context(energy_tracker.track(file_prefix="train"))
 
-        if self.config.memory:
-            self.logger.info("\t+ Entering memory tracking context manager")
-            context_stack.enter_context(memory_tracker.track())
-
-        if self.config.energy:
-            self.logger.info("\t+ Entering energy tracking context manager")
-            context_stack.enter_context(energy_tracker.track())
-
-        with context_stack:
             backend.train(
                 training_dataset=training_dataset,
                 training_callbacks=training_callbackes,
@@ -89,12 +75,13 @@ def run(self, backend: Backend[BackendConfigT]) -> BenchmarkReport:
         )
 
         if self.config.memory:
+            # we're supposing that it's the same memory usage for all steps
            self.report.overall.memory = memory_tracker.get_max_memory()
            self.report.warmup.memory = memory_tracker.get_max_memory()
            self.report.train.memory = memory_tracker.get_max_memory()
 
         if self.config.energy:
-            # can only get overall energy consumption
+            # we can only get overall energy consumption
             self.report.overall.energy = energy_tracker.get_energy()
             self.report.overall.efficiency = Efficiency.from_energy(
                 self.report.overall.energy, volume=self.overall_volume, unit=TRAIN_EFFICIENCY_UNIT
@@ -120,4 +107,8 @@ def warmup_volume(self) -> int:
 
     @property
     def train_volume(self) -> int:
-        return self.overall_volume - self.warmup_volume
+        return (
+            (self.config.max_steps - self.config.warmup_steps)
+            * self.config.training_arguments["per_device_train_batch_size"]
+            * self.config.training_arguments["gradient_accumulation_steps"]
+        )
diff --git a/optimum_benchmark/trackers/energy.py b/optimum_benchmark/trackers/energy.py
index 6f904deb9..3586809f4 100644
--- a/optimum_benchmark/trackers/energy.py
+++ b/optimum_benchmark/trackers/energy.py
@@ -5,6 +5,9 @@
 from logging import getLogger
 from typing import List, Literal, Optional, Union
 
+from rich.console import Console
+from rich.markdown import Markdown
+
 from ..import_utils import is_codecarbon_available, is_torch_available
 
 if is_torch_available():
@@ -14,15 +17,16 @@
     from codecarbon import EmissionsTracker, OfflineEmissionsTracker
     from codecarbon.output import EmissionsData
 
+CONSOLE = Console()
 LOGGER = getLogger("energy")
 
 POWER_UNIT = "W"
 ENERGY_UNIT = "kWh"
 
+POWER_CONSUMPTION_SAMPLING_RATE = 1  # in seconds
+
 Energy_Unit_Literal = Literal["kWh"]
 Efficiency_Unit_Literal = Literal["samples/kWh", "tokens/kWh", "images/kWh"]
 
-POWER_CONSUMPTION_SAMPLING_RATE = 1  # in seconds
-
 
 @dataclass
 class Energy:
@@ -33,28 +37,6 @@ class Energy:
     gpu: float
     total: float
 
-    @staticmethod
-    def aggregate(energies: List["Energy"]) -> "Energy":
-        if len(energies) == 0 or all(energy is None for energy in energies):
-            return None
-        elif any(energy is None for energy in energies):
-            raise ValueError("Some energy measurements are missing")
-
-        # since measurements are machine-level, we just take the average
-        cpu = sum(energy.cpu for energy in energies) / len(energies)
-        gpu = sum(energy.gpu for energy in energies) / len(energies)
-        ram = sum(energy.ram for energy in energies) / len(energies)
-        total = sum(energy.total for energy in energies) / len(energies)
-
-        return Energy(cpu=cpu, gpu=gpu, ram=ram, total=total, unit=ENERGY_UNIT)
-
-    def log(self, prefix: str = "forward"):
-        LOGGER.info(f"\t\t+ {prefix} energy consumption:")
-        LOGGER.info(f"\t\t\t+ CPU: {self.cpu:f} ({self.unit})")
-        LOGGER.info(f"\t\t\t+ GPU: {self.gpu:f} ({self.unit})")
-        LOGGER.info(f"\t\t\t+ RAM: {self.ram:f} ({self.unit})")
-        LOGGER.info(f"\t\t\t+ total: {self.total:f} ({self.unit})")
-
     def __sub__(self, other: "Energy") -> "Energy":
         """Enables subtraction of two Energy instances using the '-' operator."""
 
@@ -78,6 +60,47 @@ def __truediv__(self, scalar: float) -> "Energy":
             total=self.total / scalar,
         )
 
+    @staticmethod
+    def aggregate(energies: List["Energy"]) -> "Energy":
+        if len(energies) == 0 or all(energy is None for energy in energies):
+            return None
+        elif any(energy is None for energy in energies):
+            raise ValueError("Some energy measurements are missing")
+
+        # since measurements are machine-level, we just take the average
+        cpu = sum(energy.cpu for energy in energies) / len(energies)
+        gpu = sum(energy.gpu for energy in energies) / len(energies)
+        ram = sum(energy.ram for energy in energies) / len(energies)
+        total = sum(energy.total for energy in energies) / len(energies)
+
+        return Energy(cpu=cpu, gpu=gpu, ram=ram, total=total, unit=ENERGY_UNIT)
+
+    def to_plain_text(self) -> str:
+        plain_text = ""
+        plain_text += "\t\t+ cpu: {cpu:f} ({unit})\n"
+        plain_text += "\t\t+ gpu: {gpu:f} ({unit})\n"
+        plain_text += "\t\t+ ram: {ram:f} ({unit})\n"
+        plain_text += "\t\t+ total: {total:f} ({unit})\n"
+        return plain_text.format(**asdict(self))
+
+    def log(self):
+        for line in self.to_plain_text().split("\n"):
+            if line:
+                LOGGER.info(line)
+
+    def to_markdown_text(self) -> str:
+        markdown_text = ""
+        markdown_text += "| metric | value | unit |\n"
+        markdown_text += "| :--------- | --------: | -----: |\n"
+        markdown_text += "| cpu | {cpu:f} | {unit} |\n"
+        markdown_text += "| gpu | {gpu:f} | {unit} |\n"
+        markdown_text += "| ram | {ram:f} | {unit} |\n"
+        markdown_text += "| total | {total:f} | {unit} |\n"
+        return markdown_text.format(**asdict(self))
+
+    def print(self):
+        CONSOLE.print(Markdown(self.to_markdown_text()))
+
 
 @dataclass
 class Efficiency:
@@ -101,8 +124,25 @@ def aggregate(efficiencies: List["Efficiency"]) -> "Efficiency":
     def from_energy(energy: "Energy", volume: int, unit: str) -> "Efficiency":
         return Efficiency(value=volume / energy.total if energy.total > 0 else 0, unit=unit)
 
-    def log(self, prefix: str = ""):
-        LOGGER.info(f"\t\t+ {prefix} energy efficiency: {self.value:f} ({self.unit})")
+    def to_plain_text(self) -> str:
+        plain_text = ""
+        plain_text += "\t\t+ efficiency: {value:f} ({unit})\n"
+        return plain_text.format(**asdict(self))
+
+    def log(self):
+        for line in self.to_plain_text().split("\n"):
+            if line:
+                LOGGER.info(line)
+
+    def to_markdown_text(self) -> str:
+        markdown_text = ""
+        markdown_text += "| metric | value | unit |\n"
+        markdown_text += "| :--------- | --------: | -----: |\n"
+        markdown_text += "| efficiency | {value:f} | {unit} |\n"
+        return markdown_text.format(**asdict(self))
+
+    def print(self):
+        CONSOLE.print(Markdown(self.to_markdown_text()))
 
 
 class EnergyTracker:
@@ -114,7 +154,7 @@ def __init__(self, backend: str, device: str, device_ids: Optional[Union[str, in
         self.is_gpu = self.device == "cuda"
         self.is_pytorch_cuda = (self.backend, self.device) == ("pytorch", "cuda")
 
-        LOGGER.info("\t+ Tracking CPU and RAM energy")
+        LOGGER.info("\t\t+ Tracking RAM and CPU energy consumption")
 
         if self.is_gpu:
             if isinstance(self.device_ids, str):
@@ -128,7 +168,7 @@ def __init__(self, backend: str, device: str, device_ids: Optional[Union[str, in
             else:
                 raise ValueError("GPU device IDs must be a string, an integer, or a list of integers")
 
-            LOGGER.info(f"\t+ Tracking GPU energy on devices {self.device_ids}")
+            LOGGER.info(f"\t\t+ Tracking GPU energy consumption on devices {self.device_ids}")
 
         if not is_codecarbon_available():
             raise ValueError(
@@ -152,11 +192,11 @@ def __init__(self, backend: str, device: str, device_ids: Optional[Union[str, in
                 measure_power_secs=POWER_CONSUMPTION_SAMPLING_RATE,
             )
         except Exception:
-            LOGGER.warning("\t+ Falling back to Offline Emissions Tracker")
+            LOGGER.warning("\t\t+ Falling back to Offline Emissions Tracker")
 
             if os.environ.get("COUNTRY_ISO_CODE", None) is None:
                 LOGGER.warning(
-                    "\t+ Offline Emissions Tracker requires COUNTRY_ISO_CODE to be set. "
+                    "\t\t+ Offline Emissions Tracker requires COUNTRY_ISO_CODE to be set. "
                     "We will set it to USA but the carbon footprint might be inaccurate."
                 )
@@ -196,7 +236,7 @@ def track(self, file_prefix: str = "task"):
         emission_data: EmissionsData = self.emission_tracker.stop_task()
 
         with open(f"{file_prefix}_codecarbon.json", "w") as f:
-            LOGGER.info(f"\t+ Saving codecarbon emission data to {file_prefix}_codecarbon.json")
+            LOGGER.info(f"\t\t+ Saving codecarbon emission data to {file_prefix}_codecarbon.json")
             dump(asdict(emission_data), f, indent=4)
 
         self.total_energy = emission_data.energy_consumed
diff --git a/optimum_benchmark/trackers/latency.py b/optimum_benchmark/trackers/latency.py
index 343a04d7c..908108cb1 100644
--- a/optimum_benchmark/trackers/latency.py
+++ b/optimum_benchmark/trackers/latency.py
@@ -1,16 +1,20 @@
 import time
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from logging import getLogger
 from typing import List, Literal, Optional, Union
 
 import numpy as np
 import torch
+from rich.console import Console
+from rich.markdown import Markdown
 from transformers import LogitsProcessor, TrainerCallback
 
+CONSOLE = Console()
 LOGGER = getLogger("latency")
 
 LATENCY_UNIT = "s"
+
 Latency_Unit_Literal = Literal["s"]
 Throughput_Unit_Literal = Literal["samples/s", "tokens/s", "images/s", "steps/s"]
 
@@ -19,16 +23,17 @@
 class Latency:
     unit: Latency_Unit_Literal
 
+    values: List[float]
+
     count: int
     total: float
     mean: float
-    stdev: float
     p50: float
     p90: float
     p95: float
     p99: float
-
-    values: List[float]
+    stdev: float
+    stdev_: float
 
     def __getitem__(self, index) -> float:
         if isinstance(index, slice):
@@ -62,28 +67,53 @@ def aggregate(latencies: List["Latency"]) -> "Latency":
     def from_values(values: List[float], unit: str) -> "Latency":
         return Latency(
             unit=unit,
+            values=values,
             count=len(values),
             total=sum(values),
             mean=np.mean(values),
-            stdev=np.std(values),
             p50=np.percentile(values, 50),
             p90=np.percentile(values, 90),
             p95=np.percentile(values, 95),
             p99=np.percentile(values, 99),
-            values=values,
+            stdev=np.std(values) if len(values) > 1 else 0,
+            stdev_=(np.std(values) / np.abs(np.mean(values))) * 100 if len(values) > 1 else 0,
         )
 
-    def log(self, prefix: str = ""):
-        stdev_percentage = 100 * self.stdev / self.mean if self.mean > 0 else 0
-        LOGGER.info(f"\t\t+ {prefix} latency:")
-        LOGGER.info(f"\t\t\t- count: {self.count}")
-        LOGGER.info(f"\t\t\t- total: {self.total:f} {self.unit}")
-        LOGGER.info(f"\t\t\t- mean: {self.mean:f} {self.unit}")
-        LOGGER.info(f"\t\t\t- stdev: {self.stdev:f} {self.unit} ({stdev_percentage:.2f}%)")
-        LOGGER.info(f"\t\t\t- p50: {self.p50:f} {self.unit}")
-        LOGGER.info(f"\t\t\t- p90: {self.p90:f} {self.unit}")
-        LOGGER.info(f"\t\t\t- p95: {self.p95:f} {self.unit}")
-        LOGGER.info(f"\t\t\t- p99: {self.p99:f} {self.unit}")
+    def to_plain_text(self) -> str:
+        plain_text = ""
+        plain_text += "\t\t+ count: {count}\n"
+        plain_text += "\t\t+ total: {total:.6f} ({unit})\n"
+        plain_text += "\t\t+ mean: {mean:.6f} ({unit})\n"
+        plain_text += "\t\t+ p50: {p50:.6f} ({unit})\n"
+        plain_text += "\t\t+ p90: {p90:.6f} ({unit})\n"
+        plain_text += "\t\t+ p95: {p95:.6f} ({unit})\n"
+        plain_text += "\t\t+ p99: {p99:.6f} ({unit})\n"
+        plain_text += "\t\t+ stdev: {stdev:.6f} ({unit})\n"
+        plain_text += "\t\t+ stdev_: {stdev_:.2f} (%)\n"
+        return plain_text.format(**asdict(self))
+
+    def log(self):
+        for line in self.to_plain_text().split("\n"):
+            if line:
+                LOGGER.info(line)
+
+    def to_markdown_text(self) -> str:
+        markdown_text = ""
+        markdown_text += "| metric | value | unit |\n"
+        markdown_text += "| :----- | -----------: |------: |\n"
+        markdown_text += "| count | {count} | - |\n"
+        markdown_text += "| total | {total:f} | {unit} |\n"
+        markdown_text += "| mean | {mean:f} | {unit} |\n"
+        markdown_text += "| p50 | {p50:f} | {unit} |\n"
+        markdown_text += "| p90 | {p90:f} | {unit} |\n"
+        markdown_text += "| p95 | {p95:f} | {unit} |\n"
+        markdown_text += "| p99 | {p99:f} | {unit} |\n"
+        markdown_text += "| stdev | {stdev:f} | {unit} |\n"
+        markdown_text += "| stdev_ | {stdev_:.2f} | % |\n"
+        return markdown_text.format(**asdict(self))
+
+    def print(self):
+        CONSOLE.print(Markdown(self.to_markdown_text()))
 
 
 @dataclass
@@ -109,8 +139,25 @@ def from_latency(latency: Latency, volume: int, unit: str) -> "Throughput":
         value = volume / latency.mean if latency.mean > 0 else 0
         return Throughput(value=value, unit=unit)
 
-    def log(self, prefix: str = "method"):
-        LOGGER.info(f"\t\t+ {prefix} throughput: {self.value:f} {self.unit}")
+    def to_plain_text(self) -> str:
+        plain_text = ""
+        plain_text += "\t\t+ throughput: {value:.2f} ({unit})\n"
+        return plain_text.format(**asdict(self))
+
+    def log(self):
+        for line in self.to_plain_text().split("\n"):
+            if line:
+                LOGGER.info(line)
+
+    def to_markdown_text(self) -> str:
+        markdown_text = ""
+        markdown_text += "| metric | value | unit |\n"
+        markdown_text += "| :--------- | --------: | -----: |\n"
+        markdown_text += "| throughput | {value:.2f} | {unit} |\n"
+        return markdown_text.format(**asdict(self))
+
+    def print(self):
+        CONSOLE.print(Markdown(self.to_markdown_text()))
 
 
 class LatencyTracker:
@@ -121,9 +168,9 @@ def __init__(self, device: str, backend: str):
         self.is_pytorch_cuda = (self.backend, self.device) == ("pytorch", "cuda")
 
         if self.is_pytorch_cuda:
-            LOGGER.info("\t+ Tracking latency using Pytorch CUDA events")
+            LOGGER.info("\t\t+ Tracking latency using Pytorch CUDA events")
         else:
-            LOGGER.info("\t+ Tracking latency using CPU performance counter")
+            LOGGER.info("\t\t+ Tracking latency using CPU performance counter")
 
         self.start_time: Optional[float] = None
         self.start_events: List[Union[float, torch.cuda.Event]] = []
@@ -199,9 +246,9 @@ def __init__(self, device: str, backend: str) -> None:
         self.is_pytorch_cuda = (self.backend, self.device) == ("pytorch", "cuda")
 
         if self.is_pytorch_cuda:
-            LOGGER.info("\t+ Tracking latency using Pytorch CUDA events")
+            LOGGER.info("\t\t+ Tracking latency using Pytorch CUDA events")
         else:
-            LOGGER.info("\t+ Tracking latency using CPU performance counter")
+            LOGGER.info("\t\t+ Tracking latency using CPU performance counter")
 
         self.start_events: List[Union[float, torch.cuda.Event]] = []
         self.end_events: List[Union[float, torch.cuda.Event]] = []
@@ -249,9 +296,9 @@ def __init__(self, device: str, backend: str):
         self.is_pytorch_cuda = (self.backend, self.device) == ("pytorch", "cuda")
 
         if self.is_pytorch_cuda:
-            LOGGER.info("\t+ Tracking latency using Pytorch CUDA events")
+            LOGGER.info("\t\t+ Tracking latency using Pytorch CUDA events")
         else:
-            LOGGER.info("\t+ Tracking latency using CPU performance counter")
+            LOGGER.info("\t\t+ Tracking latency using CPU performance counter")
 
         self.start_time: Optional[float] = None
         self.prefilled: Optional[bool] = None
diff --git a/optimum_benchmark/trackers/memory.py b/optimum_benchmark/trackers/memory.py
index ba515d516..5e9359b1c 100644
--- a/optimum_benchmark/trackers/memory.py
+++ b/optimum_benchmark/trackers/memory.py
@@ -1,11 +1,14 @@
 import os
 from contextlib import contextmanager
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from logging import getLogger
 from multiprocessing import Pipe, Process
 from multiprocessing.connection import Connection
 from typing import List, Literal, Optional, Union
 
+from rich.console import Console
+from rich.markdown import Markdown
+
 from ..import_utils import (
     is_amdsmi_available,
     is_pynvml_available,
@@ -29,9 +32,12 @@
 
     import psutil
 
+CONSOLE = Console()
 LOGGER = getLogger("memory")
 
 MEMORY_UNIT = "MB"
 
+MEMORY_CONSUMPTION_SAMPLING_RATE = 0.01  # in seconds
+
 Memory_Unit_Literal = Literal["MB"]
 
 
@@ -77,18 +83,43 @@ def aggregate(memories: List["Memory"]) -> "Memory":
             max_allocated=max_allocated,
         )
 
-    def log(self, prefix: str = ""):
-        LOGGER.info(f"\t\t+ {prefix} memory:")
+    def to_plain_text(self) -> str:
+        plain_text = ""
+        if self.max_ram is not None:
+            plain_text += "\t\t+ max_ram: {max_ram:.2f} ({unit})\n"
+        if self.max_global_vram is not None:
+            plain_text += "\t\t+ max_global_vram: {max_global_vram:.2f} ({unit})\n"
+        if self.max_process_vram is not None:
+            plain_text += "\t\t+ max_process_vram: {max_process_vram:.2f} ({unit})\n"
+        if self.max_reserved is not None:
+            plain_text += "\t\t+ max_reserved: {max_reserved:.2f} ({unit})\n"
+        if self.max_allocated is not None:
+            plain_text += "\t\t+ max_allocated: {max_allocated:.2f} ({unit})\n"
+        return plain_text.format(**asdict(self))
+
+    def log(self):
+        for line in self.to_plain_text().split("\n"):
+            if line:
+                LOGGER.info(line)
+
+    def to_markdown_text(self) -> str:
+        markdown_text = ""
+        markdown_text += "| metric | value | unit |\n"
+        markdown_text += "| ------ | ----: | ---: |\n"
         if self.max_ram is not None:
-            LOGGER.info(f"\t\t\t- max RAM: {self.max_ram:f} ({self.unit})")
+            markdown_text += "| max_ram | {max_ram:.2f} | {unit} |\n"
         if self.max_global_vram is not None:
-            LOGGER.info(f"\t\t\t- max global VRAM: {self.max_global_vram:f} ({self.unit})")
+            markdown_text += "| max_global_vram | {max_global_vram:.2f} | {unit} |\n"
         if self.max_process_vram is not None:
-            LOGGER.info(f"\t\t\t- max process VRAM: {self.max_process_vram:f} ({self.unit})")
+            markdown_text += "| max_process_vram | {max_process_vram:.2f} | {unit} |\n"
         if self.max_reserved is not None:
-            LOGGER.info(f"\t\t\t- max reserved memory: {self.max_reserved:f} ({self.unit})")
+            markdown_text += "| max_reserved | {max_reserved:.2f} | {unit} |\n"
         if self.max_allocated is not None:
-            LOGGER.info(f"\t\t\t- max allocated memory: {self.max_allocated:f} ({self.unit})")
+            markdown_text += "| max_allocated | {max_allocated:.2f} | {unit} |\n"
+        return markdown_text.format(**asdict(self))
+
+    def print(self):
+        CONSOLE.print(Markdown(self.to_markdown_text()))
 
 
 class MemoryTracker:
@@ -101,7 +132,7 @@ def __init__(self, device: str, backend: str, device_ids: Optional[Union[str, in
         self.is_gpu = device == "cuda"
         self.is_pytorch_cuda = (self.backend, self.device) == ("pytorch", "cuda")
 
-        LOGGER.info(f"\t+ Tracking RAM memory of process [{self.monitored_pid}]")
+        LOGGER.info(f"\t\t+ Tracking RAM memory of process {self.monitored_pid}")
 
         if self.is_gpu:
             if isinstance(self.device_ids, str):
@@ -115,7 +146,7 @@ def __init__(self, device: str, backend: str, device_ids: Optional[Union[str, in
             else:
                 raise ValueError("GPU device IDs must be a string, an integer, or a list of integers")
 
-            LOGGER.info(f"\t+ Tracking GPU memory of devices {self.device_ids}")
+            LOGGER.info(f"\t\t+ Tracking GPU memory of devices {self.device_ids}")
 
             if self.is_pytorch_cuda:
                 self.num_pytorch_devices = torch.cuda.device_count()
@@ -125,7 +156,7 @@ def __init__(self, device: str, backend: str, device_ids: Optional[Union[str, in
                     f"Got {len(self.device_ids)} and {self.num_pytorch_devices} respectively."
                 )
 
-            LOGGER.info(f"\t+ Tracking Allocated/Reserved memory of {self.num_pytorch_devices} Pytorch CUDA devices")
+            LOGGER.info(f"\t\t+ Tracking Allocated/Reserved memory of {self.num_pytorch_devices} Pytorch CUDA devices")
 
         self.max_ram_memory = None
         self.max_global_vram_memory = None
@@ -236,7 +267,7 @@ def get_max_memory(self):
         )
 
 
-def monitor_cpu_ram_memory(monitored_pid: int, connection: Connection, interval: float = 0.001):
+def monitor_cpu_ram_memory(monitored_pid: int, connection: Connection):
     stop = False
     max_used_memory = 0
     monitored_process = psutil.Process(monitored_pid)
@@ -248,7 +279,7 @@ def monitor_cpu_ram_memory(monitored_pid: int, connection: Connection, interval:
         meminfo_attr = "memory_info" if hasattr(monitored_process, "memory_info") else "get_memory_info"
         used_memory = getattr(monitored_process, meminfo_attr)()[0]
         max_used_memory = max(max_used_memory, used_memory)
-        stop = connection.poll(interval)
+        stop = connection.poll(MEMORY_CONSUMPTION_SAMPLING_RATE)
 
     if monitored_process.is_running():
         connection.send(max_used_memory / 1e6)  # convert to MB
@@ -256,7 +287,7 @@ def monitor_cpu_ram_memory(monitored_pid: int, connection: Connection, interval:
     connection.close()
 
 
-def monitor_gpu_vram_memory(monitored_pid: int, device_ids: List[int], connection: Connection, interval: float = 0.01):
+def monitor_gpu_vram_memory(monitored_pid: int, device_ids: List[int], connection: Connection):
     stop = False
     max_used_global_memory = 0
     max_used_process_memory = 0
@@ -302,7 +333,7 @@ def monitor_gpu_vram_memory(monitored_pid: int, device_ids: List[int], connectio
             max_used_global_memory = max(max_used_global_memory, used_global_memory)
             max_used_process_memory = max(max_used_process_memory, used_process_memory)
 
-            stop = connection.poll(interval)
+            stop = connection.poll(MEMORY_CONSUMPTION_SAMPLING_RATE)
 
         pynvml.nvmlShutdown()
@@ -365,7 +396,7 @@ def monitor_gpu_vram_memory(monitored_pid: int, device_ids: List[int], connectio
             max_used_global_memory = max(max_used_global_memory, used_global_memory)
             max_used_process_memory = max(max_used_process_memory, used_process_memory)
 
-            stop = connection.poll(interval)
+            stop = connection.poll(MEMORY_CONSUMPTION_SAMPLING_RATE)
 
     amdsmi.amdsmi_shut_down()
     rocml.smi_shutdown()
diff --git a/setup.py b/setup.py
index cff0d1970..03bbdf073 100644
--- a/setup.py
+++ b/setup.py
@@ -29,6 +29,7 @@
     "flatten_dict",
     "colorlog",
     "pandas",
+    "rich",
 ]
 
 try:
diff --git a/tests/configs/_base_.yaml b/tests/configs/_base_.yaml
index b27b8c7a5..6cc9a5fc8 100644
--- a/tests/configs/_base_.yaml
+++ b/tests/configs/_base_.yaml
@@ -5,6 +5,9 @@ defaults:
   - scenario: inference # default scenario
   - _self_
 
+print_report: true
+log_report: true
+
 # hydra/cli specific settings
 hydra:
   run:
diff --git a/tests/test_api.py b/tests/test_api.py
index 28c5e7385..66ee16f95 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -102,6 +102,8 @@ def test_api_launch(device, scenario, library, task, model):
         scenario=scenario_config,
         launcher=launcher_config,
         backend=backend_config,
+        print_report=True,
+        log_report=True,
     )
 
     benchmark_report = Benchmark.launch(benchmark_config)
@@ -123,6 +125,8 @@ def test_api_push_to_hub_mixin():
         scenario=scenario_config,
         launcher=launcher_config,
         backend=backend_config,
+        print_report=True,
+        log_report=True,
     )
     benchmark_report = Benchmark.launch(benchmark_config)
     benchmark = Benchmark(config=benchmark_config, report=benchmark_report)