diff --git a/scripts/prepare.sh b/scripts/prepare.sh
index db426e3b11..9cbc8295ee 100644
--- a/scripts/prepare.sh
+++ b/scripts/prepare.sh
@@ -2,7 +2,11 @@ python scripts/download.py --repo_id meta-llama/Llama-2-7b-chat-hf
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3-8B
 python scripts/download.py --repo_id meta-llama/Meta-Llama-3.1-8B
 python scripts/download.py --repo_id meta-llama/Llama-3.2-3B
+python scripts/download.py --repo_id nm-testing/SparseLlama-3-8B-pruned_50.2of4
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-2-7b-chat-hf
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Meta-Llama-3.1-8B
 python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/meta-llama/Llama-3.2-3B
+# the neuralmagic checkpoint doesn't ship with a tokenizer, so copy it over from Meta-Llama-3-8B
+mkdir -p checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original && cp checkpoints/meta-llama/Meta-Llama-3-8B/original/tokenizer.model checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4/original/tokenizer.model
+python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/nm-testing/SparseLlama-3-8B-pruned_50.2of4
diff --git a/test/prototype/test_sparse_api.py b/test/prototype/test_sparse_api.py
index 757eb9f913..f3cdbe8386 100644
--- a/test/prototype/test_sparse_api.py
+++ b/test/prototype/test_sparse_api.py
@@ -50,6 +50,9 @@ def test_sparse(self):
         sparsify_(model, semi_sparse_weight())
+        if compile:
+            model = torch.compile(model)
+
         sparse_result = model(input)
 
         torch.testing.assert_close(dense_result, sparse_result, rtol=1e-3, atol=1e-3)
diff --git a/torchao/_models/llama/benchmarks.sh b/torchao/_models/llama/benchmarks.sh
index 63733c736d..c8cd4bf39c 100644
--- a/torchao/_models/llama/benchmarks.sh
+++ b/torchao/_models/llama/benchmarks.sh
@@ -52,7 +52,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
 
@@ -62,7 +62,7 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --wr
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization autoquant --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization fp6 --write_result benchmark_results.txt --precision float16
-python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-4-64 --write_result benchmark_results.txt
 # python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization uintx-2-8 --write_result benchmark_results.txt
 
@@ -79,3 +79,20 @@ python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --co
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 1
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 32
 python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization autoquant --write_result benchmark_results.txt --batch_size 128
+
+# TTFT benchmarks
+export MODEL_REPO=meta-llama/Meta-Llama-3.1-8B
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8wo --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int8dq --sparsity semi-structured --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8dq --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization float8wo --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization int4wo-64 --write_result benchmark_results.txt --prefill_size 8000
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --compile_prefill --quantization sparse-marlin --write_result benchmark_results.txt --prefill_size 8000 --precision float16 --sparsity semi-structured
+
+# 2:4 sparse model
+export MODEL_REPO=nm-testing/SparseLlama-3-8B-pruned_50.2of4
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
+python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --compile --quantization sparse-marlin --sparsity semi-structured --precision float16 --write_result benchmark_results.txt
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index d617ceb304..4d87a3869f 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -17,6 +17,29 @@
 from torchao.quantization.quant_primitives import MappingType
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
 
+torch.sparse.SparseSemiStructuredTensor._FORCE_CUTLASS = False
+
+class HostEvent:
+    def __init__(self):
+        self.event_time = None
+
+    def record(self):
+        self.event_time = time.perf_counter()
+
+    def elapsed_time(self, other_event):
+        if self.event_time is None:
+            raise ValueError("Event not recorded!")
+        # return ms to match cuda event
+        return abs(other_event.event_time - self.event_time) * 1000
+
+def device_timer(device):
+    if "cuda" in device:
+        return torch.cuda.Event(enable_timing=True)
+    elif ("cpu" in device) or ("mps" in device):
+        return HostEvent()
+    else:
+        print(f"device={device} is not yet supported")
+
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
@@ -98,6 +121,10 @@ def generate(
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,
+    prefill_start_event: Optional[torch.cuda.Event]=None,
+    prefill_end_event: Optional[torch.cuda.Event]=None,
+    decode_start_event: Optional[torch.cuda.Event]=None,
+    decode_end_event: Optional[torch.cuda.Event]=None,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
@@ -128,12 +155,21 @@ def generate(
         model.setup_caches(max_batch_size=batch_size, max_seq_length=cache_size, kv_cache_quantization=kv_cache_quantization, linear_causal_mask=linear_causal_mask, prompt_length=T)
 
     # execute prefill
+    if prefill_start_event is not None:
+        prefill_start_event.record()
     next_token = prefill(model, prompt.view(batch_size, -1), input_pos, **sampling_kwargs).clone()
     seq[:, T] = next_token.squeeze()
+    if prefill_end_event is not None:
+        prefill_end_event.record()
+    # execute token generation
+    if decode_start_event is not None:
+        decode_start_event.record()
 
     input_pos = torch.tensor([T], device=device, dtype=torch.int)
     generated_tokens, _ = decode_n_tokens(model, next_token.view(batch_size, -1), input_pos, new_tokens-1, callback=callback, **sampling_kwargs)
     seq = torch.cat((seq[:, :T+1], *generated_tokens), dim=-1)
+    if decode_end_event is not None:
+        decode_end_event.record()
 
     return seq
 
@@ -157,6 +193,7 @@ def _load_model(checkpoint_path, device, precision):
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def main(
+    prefill_size: Optional[int] = None,
     prompt: str = "Hello, my name is",
     interactive: bool = False,
     num_samples: int = 5,
@@ -166,6 +203,7 @@ def main(
     temperature: float = 0.8,
     checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
     quantization: Optional[str] = None,
+    sparsity: Optional[str] = None,
     kv_cache_quantization: bool = False,
     cache_size: Optional[int] = None,
     linear_causal_mask: bool=False,
@@ -181,6 +219,10 @@
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
""" + if prefill_size is not None and prefill_size > 0: + # create prompt of prefill size + prompt = "prompt " * (int(prefill_size)-3) + torchao.quantization.utils.recommended_inductor_config_setter() assert checkpoint_path.is_file(), checkpoint_path @@ -205,6 +247,14 @@ def main( torch.manual_seed(1234) + def ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and "feed_forward" in fqn + + def not_ffn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and not ffn_only(mod, fqn) + + def ffn_or_attn_only(mod, fqn): + return isinstance(mod, torch.nn.Linear) and ("feed_forward" in fqn or "attention" in fqn) if quantization: from torchao.quantization import ( @@ -228,9 +278,14 @@ def main( apply_spinquant(model) if "int8wo" in quantization: quantize_(model, int8_weight_only()) - elif "int8dq" in quantization: - quantize_(model, int8_dynamic_activation_int8_weight()) - elif "int4wo" in quantization: + if "int8dq" in quantization: + if sparsity and "semi" in sparsity: + from torchao.dtypes import SemiSparseLayout + quantize_(model, int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()), filter_fn=ffn_only) + quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=not_ffn_only) + else: + quantize_(model, int8_dynamic_activation_int8_weight()) + if "int4wo" in quantization: if "hqq" in quantization: use_hqq=True else: @@ -250,9 +305,9 @@ def main( layout=MarlinQQQLayout(), ), ) - else: + elif "semi" in sparsity: from torchao.dtypes import MarlinSparseLayout - quantize_(model, int4_weight_only(layout=MarlinSparseLayout())) + quantize_(model, int4_weight_only(layout=MarlinSparseLayout()), filter_fn=ffn_or_attn_only) if "fp6" in quantization: quantize_(model, fpx_weight_only(3, 2)) elif "embed-int8wo" in quantization: @@ -426,6 +481,13 @@ def main( if not TORCH_VERSION_AT_LEAST_2_5: unwrap_tensor_subclass(model) + # standalone sparsity + elif sparsity: + from torchao.sparsity import semi_sparse_weight, sparsify_ + if "semi" in sparsity: + #TODO there is a bug here, need to fix + sparsify_(model.to(device), semi_sparse_weight(), filter_fn=ffn_only) + model_size = get_model_size_in_bytes(model, ignore_embeddings=True) / 1e9 if save: @@ -451,6 +513,9 @@ def main( aggregate_metrics = { 'tokens_per_sec': [], + 'time': [], + 'decode_tokens_per_sec': [], + 'prefill_time': [], } start = -1 if compile else 0 @@ -485,6 +550,8 @@ def callback(x): else: callback = lambda x : x t0 = time.perf_counter() + prefill_start_event, prefill_end_event = device_timer(device), device_timer(device) + decode_start_event, decode_end_event = device_timer(device), device_timer(device) import contextlib if (i != num_samples - 1 or not profile): prof = contextlib.nullcontext() @@ -504,6 +571,10 @@ def callback(x): kv_cache_quantization=kv_cache_quantization, cache_size=cache_size, linear_causal_mask=linear_causal_mask, + prefill_start_event=prefill_start_event, + prefill_end_event=prefill_end_event, + decode_start_event=decode_start_event, + decode_end_event=decode_end_event, ) if i == -1: print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds") @@ -513,7 +584,7 @@ def callback(x): device_sync(device=device) # MKG t = time.perf_counter() - t0 - if not interactive: + if not interactive and prefill_size is None: tok_list = y[0].tolist() # truncate text after end of string token tokens = tok_list if not tokenizer.eos_id() in tok_list else tok_list[:tok_list.index(tokenizer.eos_id())] @@ -523,7 +594,14 @@ def callback(x): tokens_generated = (y.size(-1) - prompt_length) tokens_sec = 
tokens_generated / t aggregate_metrics['tokens_per_sec'].append(tokens_sec) - print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec") + aggregate_metrics['time'].append(t) + decode_time = decode_start_event.elapsed_time(decode_end_event) / 1000 + decode_tokens_sec = tokens_generated / decode_time + aggregate_metrics['decode_tokens_per_sec'].append(decode_tokens_sec) + prefill_time = prefill_start_event.elapsed_time(prefill_end_event) / 1000 + aggregate_metrics['prefill_time'].append(prefill_time) + print(f"Sample {i+1} | overall time {t:.04f} s {tokens_sec:.02f} tokens/sec", + f"| prefill time {prefill_time:.04f} s decode {decode_tokens_sec:.02f} tokens/sec") print(f"Bandwidth achieved: {model_size * tokens_sec:.02f} GB/s") if memory_profile and i==0: @@ -544,8 +622,15 @@ def callback(x): break print("==========") + #ignore first sample for warmup tokpersec = torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item() + ttft = torch.mean(torch.tensor(aggregate_metrics['prefill_time'])).item() + decode_tokpersec = torch.mean(torch.tensor(aggregate_metrics['decode_tokens_per_sec'])).item() bandwidth = model_size * tokpersec + mem = torch.cuda.max_memory_reserved() /1e9 + print(f"Average overall tokens/sec: {tokpersec:.2f}") + print(f"Average decode tokens/sec: {decode_tokens_sec:.04f} s") + print(f"Average TTFT: {ttft:.04f} s") if device == "cuda": mem = torch.cuda.max_memory_reserved() /1e9 elif device == "xpu": @@ -557,15 +642,17 @@ def callback(x): print(f"Peak Memory Usage: {mem:.02f} GB") print(f"Model Size: {model_size:.02f} GB") if write_result: - result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " - result_txt += f"quant: {quantization}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " + result_txt = f"\n{datetime.today().strftime('%Y%m%d%H%M%S')}, tok/s={tokpersec:6.2f}, tok/s_decode={decode_tokpersec:6.2f}, ttft={ttft:5.4f}, mem/s={bandwidth:7.2f} GB/s, peak_mem={mem:5.2f} GB, model_size={model_size:5.2f} GB " + result_txt += f"quant: {quantization}, sparse: {sparsity}, mod: {checkpoint_path.parent.name}, kv_quant: {kv_cache_quantization}, compile: {compile}, compile_prefill: {compile_prefill}, dtype: {precision}, device: {device} " result_txt += f"repro: python generate.py " result_txt += f"--quantization {quantization} " if quantization else "" + result_txt += f"--sparsity {sparsity} " if sparsity else "" result_txt += f"--checkpoint_path {checkpoint_path} " result_txt += f"--device {device} " result_txt += f"--precision {precision} " result_txt += f"--compile " if compile else "" result_txt += f"--compile_prefill " if compile_prefill else "" + result_txt += f"--prefill_size {prefill_size}" if prefill_size else "" result_txt += f"--profile {profile} " if profile else "" result_txt += f"--profile {memory_profile} " if memory_profile else "" result_txt += f"--interactive " if interactive else "" @@ -587,7 +674,7 @@ def callback(x): if __name__ == '__main__': import argparse parser = argparse.ArgumentParser(description='Your CLI description.') - + parser.add_argument('--prefill_size', type=int, default=0, help='Whether to run in ttft mode') parser.add_argument('--prompt', type=str, default="Hello, my name is", help='Input prompt.') parser.add_argument('--interactive', action='store_true', help='Whether to 
launch in interactive mode') parser.add_argument('--num_samples', type=int, default=5, help='Number of samples.') @@ -603,6 +690,11 @@ def callback(x): +'embed-int8wo, marlin_qqq' ) ) + parser.add_argument('-s', '--sparsity', type=str, + help=( + 'Which sparsity techniques to apply: semi-structured' + ) + ) parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache') parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size') parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)') @@ -617,6 +709,6 @@ def callback(x): args = parser.parse_args() main( - args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, - args.temperature, args.checkpoint_path, args.quantization, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result + args.prefill_size, args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.batch_size, args.top_k, + args.temperature, args.checkpoint_path, args.quantization, args.sparsity, args.kv_cache_quantization, args.cache_size, args.linear_causal_mask, args.save, args.compile, args.compile_prefill, args.profile, args.memory_profile, args.device, args.precision, args.write_result ) diff --git a/torchao/dtypes/uintx/semi_sparse_layout.py b/torchao/dtypes/uintx/semi_sparse_layout.py index e2c94a7a38..d832731657 100644 --- a/torchao/dtypes/uintx/semi_sparse_layout.py +++ b/torchao/dtypes/uintx/semi_sparse_layout.py @@ -41,13 +41,17 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl( w_vals_int8 = weight_tensor.tensor_impl.int_data w_scales = weight_tensor.tensor_impl.scale tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + # must pad + row, col = tmp.shape + from torch.sparse import SparseSemiStructuredTensorCUSPARSELT + tmp_padded = SparseSemiStructuredTensorCUSPARSELT._pad_dense_input(tmp) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( w_vals_int8, - tmp.t(), + tmp_padded.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, - ).t() + ).t()[:row, :] y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] )
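
Note on the semi_sparse_layout.py hunk above: it works around a cuSPARSELt shape constraint by padding the dense int8 activation matrix up to a supported row count before the fused sparse matmul, then slicing the extra rows off the transposed result. Below is a minimal, CPU-runnable sketch of that pad-then-slice pattern. The helper name padded_mm, the padding multiple of 8, and the plain dense matmul standing in for torch._cslt_sparse_mm are illustrative assumptions, not the code path in the patch.

import torch
import torch.nn.functional as F

def padded_mm(weight: torch.Tensor, activations: torch.Tensor, multiple: int = 8) -> torch.Tensor:
    # activations: (rows, in_features); weight: (out_features, in_features)
    rows = activations.shape[0]
    pad_rows = (-rows) % multiple  # rows needed to reach the next multiple
    # for a 2D tensor, F.pad's third/fourth values pad the row dimension (top, bottom)
    padded = F.pad(activations, (0, 0, 0, pad_rows))
    # stand-in for the fused sparse mm: (out, in) @ (in, padded_rows), then transpose
    out = (weight @ padded.t()).t()
    # slice the padding back off so downstream reshapes see the original batch size
    return out[:rows, :]

if __name__ == "__main__":
    x = torch.randn(13, 32)   # 13 rows is not a multiple of 8
    w = torch.randn(64, 32)
    y = padded_mm(w, x)
    assert y.shape == (13, 64)
    torch.testing.assert_close(y, x @ w.t())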