From cd806a7e88be5559db175929bcf08e54ca87244e Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sun, 29 Sep 2024 16:28:16 +0200 Subject: [PATCH 1/3] add llava to conversion --- convert_hf_to_gguf.py | 41 +++++++++++++++++-- gguf-py/gguf/constants.py | 73 ++++++++++++++++++++++++++++++++++ gguf-py/gguf/gguf_writer.py | 30 ++++++++++++++ gguf-py/gguf/tensor_mapping.py | 60 ++++++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 4 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 2cd5a8c11bc18..0490178864051 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -66,6 +66,11 @@ class Model: dir_model_card: Path is_lora: bool + # for vision model + vparams: dict[str, Any] | None = None + v_tensor_map: gguf.TensorNameMap + v_tensor_names: set[str] | None + # subclasses should define this! model_arch: gguf.MODEL_ARCH @@ -210,9 +215,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int | def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str: new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes) - if new_name is None: + new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) + if new_name is not None: + return new_name + elif new_name_vision is not None: + return new_name_vision + else: raise ValueError(f"Can not map tensor {name!r}") - return new_name def set_gguf_parameters(self): self.gguf_writer.add_block_count(self.block_count) @@ -452,7 +461,10 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str] @staticmethod def load_hparams(dir_model: Path): with open(dir_model / "config.json", "r", encoding="utf-8") as f: - return json.load(f) + hparams = json.load(f) + if "text_config" in hparams: + hparams = {**hparams, **hparams["text_config"]} + return hparams @classmethod def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]: @@ -1501,10 +1513,17 @@ def prepare_tensors(self): raise ValueError(f"Unprocessed norms: {norms}") -@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM") +@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration") class LlamaModel(Model): model_arch = gguf.MODEL_ARCH.LLAMA + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if "vision_config" in self.hparams: + self.vparams = self.hparams["vision_config"] + if self.vparams is not None: + self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"]) + def set_vocab(self): try: self._set_vocab_sentencepiece() @@ -1554,6 +1573,17 @@ def set_gguf_parameters(self): if self.hparams.get("vocab_size", 32000) == 49152: self.gguf_writer.add_add_bos_token(False) + # For vision model + if self.vparams is not None: + self.gguf_writer.add_vision_type("clip") + self.gguf_writer.add_vision_image_size(self.vparams["image_size"]) + self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"]) + self.gguf_writer.add_vision_clip_architecture("llava") + self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"]) + self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"]) + self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"]) + self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"]) + @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: 
int | None): if n_head_kv is not None and n_head != n_head_kv: @@ -1568,6 +1598,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams.get("num_key_value_heads") + if name.startswith("language_model"): + name = name.replace("language_model.", "") + if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = LlamaModel.permute(data_torch, n_head, n_head) if name.endswith(("k_proj.weight", "k_proj.bias")): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 2fd2e9d2be828..de4dc0cddb1cb 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -178,6 +178,26 @@ class Adapter: TYPE = "adapter.type" LORA_ALPHA = "adapter.lora.alpha" + class Vision: + # only support vision.type = "clip" for now + TYPE = "vision.type" + IMAGE_SIZE = "vision.image_size" + PATCH_SIZE = "vision.patch_size" + IMAGE_MEAN = "vision.image_mean" + IMAGE_STD = "vision.image_std" + + class Clip: + ARCHITECTURE = "vision.clip.architecture" + CONTEXT_LENGTH = "vision.clip.context_length" + EMBEDDING_LENGTH = "vision.clip.embedding_length" + BLOCK_COUNT = "vision.clip.block_count" + FEED_FORWARD_LENGTH = "vision.clip.feed_forward_length" + PROJECTION_TYPE = "vision.clip.projection_type" + PROJECTION_DIM = "vision.clip.projection_dim" + USE_GELU = "vision.clip.use_gelu" + HEAD_COUNT = "vision.clip.attention.head_count" + LAYERNORM_EPS = "vision.clip.attention.layer_norm_epsilon" + # # recommended mapping of model tensor names for storage in gguf # @@ -238,6 +258,8 @@ class MODEL_ARCH(IntEnum): GRANITE = auto() GRANITE_MOE = auto() CHAMELEON = auto() + # vision models + LLAVA_VISION = auto() class MODEL_TENSOR(IntEnum): @@ -345,6 +367,22 @@ class MODEL_TENSOR(IntEnum): ENC_FFN_DOWN = auto() ENC_FFN_UP = auto() ENC_OUTPUT_NORM = auto() + # vision + V_MMPROJ_A = auto() + V_MMPROJ_B = auto() + V_ENC_EMBD_CLS = auto() + V_ENC_EMBD_PATCH = auto() + V_ENC_EMBD_POS = auto() + V_ENC_ATTN_Q = auto() + V_ENC_ATTN_K = auto() + V_ENC_ATTN_V = auto() + V_ENC_INPUT_NORM = auto() + V_ENC_OUTPUT = auto() + V_ENC_OUTPUT_NORM = auto() + V_ENC_FFN_UP = auto() + V_ENC_FFN_DOWN = auto() + V_PRE_NORM = auto() + V_POST_NORM = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -397,6 +435,8 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GRANITE: "granite", MODEL_ARCH.GRANITE_MOE: "granitemoe", MODEL_ARCH.CHAMELEON: "chameleon", + # vision + MODEL_ARCH.LLAVA_VISION: "llava", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { @@ -504,6 +544,22 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.ENC_FFN_DOWN: "enc.blk.{bid}.ffn_down", MODEL_TENSOR.ENC_FFN_UP: "enc.blk.{bid}.ffn_up", MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", + # vision + MODEL_TENSOR.V_MMPROJ_A: "v.mmproj_a", + MODEL_TENSOR.V_MMPROJ_B: "v.mmproj_b", + MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls", + MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch", + MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos", + MODEL_TENSOR.V_ENC_ATTN_Q: "v.enc.blk.{bid}.attn_q", + MODEL_TENSOR.V_ENC_ATTN_K: "v.enc.blk.{bid}.attn_k", + MODEL_TENSOR.V_ENC_ATTN_V: "v.enc.blk.{bid}.attn_v", + MODEL_TENSOR.V_ENC_INPUT_NORM: "v.enc.blk.{bid}.input_norm", + MODEL_TENSOR.V_ENC_OUTPUT: "v.enc.blk.{bid}.output", + MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.enc.blk.{bid}.output_norm", + MODEL_TENSOR.V_ENC_FFN_UP: "v.enc.blk.{bid}.ffn_up", + MODEL_TENSOR.V_ENC_FFN_DOWN: "v.enc.blk.{bid}.ffn_down", + MODEL_TENSOR.V_PRE_NORM: "v.pre_norm", + MODEL_TENSOR.V_POST_NORM: "v.post_norm", } MODEL_TENSORS: 
dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -1279,6 +1335,23 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.LLAVA_VISION: [ + MODEL_TENSOR.V_MMPROJ_A, + MODEL_TENSOR.V_MMPROJ_B, + MODEL_TENSOR.V_ENC_EMBD_CLS, + MODEL_TENSOR.V_ENC_EMBD_PATCH, + MODEL_TENSOR.V_ENC_EMBD_POS, + MODEL_TENSOR.V_ENC_ATTN_Q, + MODEL_TENSOR.V_ENC_ATTN_K, + MODEL_TENSOR.V_ENC_ATTN_V, + MODEL_TENSOR.V_ENC_INPUT_NORM, + MODEL_TENSOR.V_ENC_OUTPUT, + MODEL_TENSOR.V_ENC_OUTPUT_NORM, + MODEL_TENSOR.V_ENC_FFN_UP, + MODEL_TENSOR.V_ENC_FFN_DOWN, + MODEL_TENSOR.V_PRE_NORM, + MODEL_TENSOR.V_POST_NORM, + ], # TODO } diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 5c460ef1bc260..b35eeee7ef620 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -814,6 +814,36 @@ def add_remove_extra_whitespaces(self, value: bool) -> None: def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None: self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap) + def add_vision_type(self, value: str) -> None: + self.add_string(Keys.Vision.TYPE, value) + + def add_vision_image_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.IMAGE_SIZE, value) + + def add_vision_patch_size(self, value: int) -> None: + self.add_uint32(Keys.Vision.PATCH_SIZE, value) + + def add_vision_clip_architecture(self, value: str) -> None: + self.add_string(Keys.Vision.Clip.ARCHITECTURE, value) + + def add_vision_clip_context_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.CONTEXT_LENGTH, value) + + def add_vision_clip_embedding_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.EMBEDDING_LENGTH, value) + + def add_vision_clip_block_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.BLOCK_COUNT, value) + + def add_vision_clip_feed_forward_length(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.FEED_FORWARD_LENGTH, value) + + def add_vision_clip_head_count(self, value: int) -> None: + self.add_uint32(Keys.Vision.Clip.HEAD_COUNT, value) + + def add_vision_clip_layer_norm_epsilon(self, value: float) -> None: + self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 5ef91f11d312f..4e73706a018a2 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -679,6 +679,66 @@ class TensorNameMap: MODEL_TENSOR.ENC_OUTPUT_NORM: ( "encoder.final_layer_norm", # t5 ), + + MODEL_TENSOR.V_MMPROJ_A: ( + "multi_modal_projector.linear_1", + ), + + MODEL_TENSOR.V_MMPROJ_B: ( + "multi_modal_projector.linear_2", + ), + + MODEL_TENSOR.V_ENC_EMBD_CLS: ( + "vision_tower.vision_model.embeddings.class_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_PATCH: ( + "vision_tower.vision_model.embeddings.patch_embedding", + ), + + MODEL_TENSOR.V_ENC_EMBD_POS: ( + "vision_tower.vision_model.embeddings.position_embedding", + ), + + MODEL_TENSOR.V_ENC_ATTN_Q: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj", + ), + + MODEL_TENSOR.V_ENC_ATTN_K: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj", + ), + + MODEL_TENSOR.V_ENC_ATTN_V: ( + "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj", + ), + + MODEL_TENSOR.V_ENC_INPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1", + ), + + MODEL_TENSOR.V_ENC_OUTPUT: ( + 
"vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj", + ), + + MODEL_TENSOR.V_ENC_OUTPUT_NORM: ( + "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2", + ), + + MODEL_TENSOR.V_ENC_FFN_UP: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1", + ), + + MODEL_TENSOR.V_ENC_FFN_DOWN: ( + "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2", + ), + + MODEL_TENSOR.V_PRE_NORM: ( + "vision_tower.vision_model.pre_layrnorm", + ), + + MODEL_TENSOR.V_POST_NORM: ( + "vision_tower.vision_model.post_layernorm", + ), } # architecture-specific block mappings From a75c5c42956baa2bb307fc14c8d456ac046052ff Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 30 Sep 2024 13:53:36 +0200 Subject: [PATCH 2/3] model is loadable --- Makefile | 12 ++ convert_hf_to_gguf.py | 3 + gguf-py/gguf/constants.py | 1 + gguf-py/gguf/gguf_writer.py | 3 + include/llama.h | 24 ++++ src/CMakeLists.txt | 1 + src/llama-vision.cpp | 5 + src/llama-vision.h | 91 +++++++++++++ src/llama.cpp | 248 ++++++++++++++++++++++++++++++++---- 9 files changed, 360 insertions(+), 28 deletions(-) create mode 100644 src/llama-vision.cpp create mode 100644 src/llama-vision.h diff --git a/Makefile b/Makefile index 8a903d7ed5914..a49ec0154bd86 100644 --- a/Makefile +++ b/Makefile @@ -1120,6 +1120,7 @@ src/llama.o: \ src/llama-vocab.h \ src/llama-grammar.h \ src/llama-sampling.h \ + src/llama-vision.h \ src/unicode.h \ include/llama.h \ ggml/include/ggml-cuda.h \ @@ -1152,6 +1153,17 @@ src/llama-sampling.o: \ include/llama.h $(CXX) $(CXXFLAGS) -c $< -o $@ +src/llama-vision.o: \ + src/llama-vision.cpp \ + src/llama-vision.h \ + include/llama.h \ + ggml/include/ggml-cuda.h \ + ggml/include/ggml-metal.h \ + ggml/include/ggml.h \ + ggml/include/ggml-alloc.h \ + ggml/include/ggml-backend.h + $(CXX) $(CXXFLAGS) -c $< -o $@ + $(LIB_LLAMA): \ $(OBJ_LLAMA) \ $(LIB_GGML) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0490178864051..e0880511a4606 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1583,6 +1583,9 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"]) self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"]) self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"]) + # TODO: should not hardcode these, but they are currently missing from config.json + self.gguf_writer.add_vision_clip_max_position_embeddings(577) + self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05) @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index de4dc0cddb1cb..f4ebd8f90beda 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -195,6 +195,7 @@ class Clip: PROJECTION_TYPE = "vision.clip.projection_type" PROJECTION_DIM = "vision.clip.projection_dim" USE_GELU = "vision.clip.use_gelu" + MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings" HEAD_COUNT = "vision.clip.attention.head_count" LAYERNORM_EPS = "vision.clip.attention.layer_norm_epsilon" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index b35eeee7ef620..2828f0a802d70 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -841,6 +841,9 @@ def add_vision_clip_feed_forward_length(self, value: int) -> None: def add_vision_clip_head_count(self, value: int) -> None: self.add_uint32(Keys.Vision.Clip.HEAD_COUNT, value) + def add_vision_clip_max_position_embeddings(self, value: 
int) -> None: + self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value) + def add_vision_clip_layer_norm_epsilon(self, value: float) -> None: self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value) diff --git a/include/llama.h b/include/llama.h index 4ea8a2c2b664b..9aa17ffd1313e 100644 --- a/include/llama.h +++ b/include/llama.h @@ -224,6 +224,20 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); + // represent an RGB image + // size of data must be equal to 3*nx*ny + typedef struct llama_img { + uint32_t nx; + uint32_t ny; + unsigned char * data; + } llama_img; + + // Input data for llama_vision_encode + typedef struct llama_img_batch { + int32_t n_imgs; + llama_img * imgs; + } llama_img_batch; + // Input data for llama_decode // A llama_batch object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens @@ -875,6 +889,16 @@ extern "C" { // shape: [n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // + // Vision + // + + // encode image into embeddings + LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch); + + // get output embeddings, to be put into language batch + LLAMA_API float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx); + // // Vocab // diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 46a6ad56202f7..2916e1366ef67 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -17,6 +17,7 @@ add_library(llama llama-vocab.cpp llama-grammar.cpp llama-sampling.cpp + llama-vision.cpp unicode.h unicode.cpp unicode-data.cpp diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp new file mode 100644 index 0000000000000..ac507babbe8c0 --- /dev/null +++ b/src/llama-vision.cpp @@ -0,0 +1,5 @@ +#include "llama.h" + +#include "llama-vision.h" + + diff --git a/src/llama-vision.h b/src/llama-vision.h new file mode 100644 index 0000000000000..5bf1673e530a4 --- /dev/null +++ b/src/llama-vision.h @@ -0,0 +1,91 @@ +#pragma once + +#include "ggml.h" + +#include <vector> + +enum vision_arch { + VISION_ARCH_LLAVA, + VISION_ARCH_UNKNOWN, +}; + +enum mm_patch_merge { + MM_PATCH_MERGE_FLAT, + MM_PATCH_MERGE_SPATIAL_UNPAD, +}; + +struct clip_hparams { + vision_arch arch = VISION_ARCH_UNKNOWN; + + uint32_t image_size; + uint32_t patch_size; + uint32_t hidden_size; + uint32_t n_intermediate; + uint32_t projection_dim; + uint32_t n_head; + uint32_t n_layer; + uint32_t max_pos_embd; + + float eps; + + mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT; + + int32_t image_grid_pinpoints[32]; + int32_t image_crop_resolution; +}; + +struct clip_layer { + // attention + struct ggml_tensor * k_w; + struct ggml_tensor * k_b; + struct ggml_tensor * q_w; + struct ggml_tensor * q_b; + struct ggml_tensor * v_w; + struct ggml_tensor * v_b; + + struct ggml_tensor * output_w; + struct ggml_tensor * output_b; + + // layernorm 1 + struct ggml_tensor * norm_in_w; + struct ggml_tensor * norm_in_b; + + // ff + struct ggml_tensor * ffn_up_w; + struct ggml_tensor * ffn_up_b; + + struct ggml_tensor * ffn_down_w; + struct ggml_tensor * ffn_down_b; + + // layernorm 2 + struct ggml_tensor * norm_out_w; + struct ggml_tensor * norm_out_b; +}; + +struct clip_vision_model { + struct clip_hparams hparams; + + // embeddings + struct ggml_tensor * class_embedding; + struct ggml_tensor * patch_embeddings; + struct ggml_tensor * patch_bias; + struct ggml_tensor * position_embeddings; + + struct 
ggml_tensor * pre_norm_w; + struct ggml_tensor * pre_norm_b; + + std::vector<clip_layer> layers; + + struct ggml_tensor * post_norm_w; + struct ggml_tensor * post_norm_b; + + struct ggml_tensor * projection; + + // LLaVA projection + struct ggml_tensor * mm_a_w = NULL; + struct ggml_tensor * mm_a_b = NULL; + struct ggml_tensor * mm_b_w = NULL; + struct ggml_tensor * mm_b_b = NULL; + + struct ggml_tensor * image_newline = NULL; +}; diff --git a/src/llama.cpp b/src/llama.cpp index 44afb31d74e53..0eac03a513f5f 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -1,6 +1,7 @@ #include "llama-impl.h" #include "llama-vocab.h" #include "llama-sampling.h" +#include "llama-vision.h" #include "unicode.h" @@ -273,6 +274,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = { { LLM_ARCH_UNKNOWN, "(unknown)" }, }; +static const std::map<vision_arch, const char *> VISION_ARCH_NAMES = { + { VISION_ARCH_LLAVA, "llava" }, + { VISION_ARCH_UNKNOWN, "(unknown)" }, +}; + enum llm_kv { LLM_KV_GENERAL_TYPE, LLM_KV_GENERAL_ARCHITECTURE, @@ -379,6 +385,24 @@ enum llm_kv { LLM_KV_ADAPTER_TYPE, LLM_KV_ADAPTER_LORA_ALPHA, + + // TODO: these are vision-related KV, probably should be moved to a new enum + LLM_KV_VISION_TYPE, + LLM_KV_VISION_IMAGE_SIZE, + LLM_KV_VISION_PATCH_SIZE, + LLM_KV_VISION_IMAGE_MEAN, + LLM_KV_VISION_IMAGE_STD, + LLM_KV_VISION_CLIP_ARCHITECTURE, + LLM_KV_VISION_CLIP_CONTEXT_LENGTH, + LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, + LLM_KV_VISION_CLIP_BLOCK_COUNT, + LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, + LLM_KV_VISION_CLIP_PROJECTION_TYPE, + LLM_KV_VISION_CLIP_PROJECTION_DIM, + LLM_KV_VISION_CLIP_USE_GELU, + LLM_KV_VISION_CLIP_HEAD_COUNT, + LLM_KV_VISION_CLIP_MAX_POS_EMBD, + LLM_KV_VISION_CLIP_LAYERNORM_EPS, }; static const std::map<llm_kv, const char *> LLM_KV_NAMES = { @@ -487,6 +511,23 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_ADAPTER_TYPE, "adapter.type" }, { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + + { LLM_KV_VISION_TYPE, "vision.type" }, + { LLM_KV_VISION_IMAGE_SIZE, "vision.image_size" }, + { LLM_KV_VISION_PATCH_SIZE, "vision.patch_size" }, + { LLM_KV_VISION_IMAGE_MEAN, "vision.image_mean" }, + { LLM_KV_VISION_IMAGE_STD, "vision.image_std" }, + { LLM_KV_VISION_CLIP_ARCHITECTURE, "vision.clip.architecture" }, + { LLM_KV_VISION_CLIP_CONTEXT_LENGTH, "vision.clip.context_length" }, + { LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, "vision.clip.embedding_length" }, + { LLM_KV_VISION_CLIP_BLOCK_COUNT, "vision.clip.block_count" }, + { LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, "vision.clip.feed_forward_length" }, + { LLM_KV_VISION_CLIP_PROJECTION_TYPE, "vision.clip.projection_type" }, + { LLM_KV_VISION_CLIP_PROJECTION_DIM, "vision.clip.projection_dim" }, + { LLM_KV_VISION_CLIP_USE_GELU, "vision.clip.use_gelu" }, + { LLM_KV_VISION_CLIP_MAX_POS_EMBD, "vision.clip.max_position_embeddings" }, + { LLM_KV_VISION_CLIP_HEAD_COUNT, "vision.clip.attention.head_count" }, + { LLM_KV_VISION_CLIP_LAYERNORM_EPS, "vision.clip.attention.layer_norm_epsilon" }, }; struct LLM_KV { @@ -608,6 +649,24 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, }; +enum vision_tensor { + VISION_TENSOR_MMPROJ_A, + VISION_TENSOR_MMPROJ_B, + VISION_TENSOR_ENC_EMBD_CLS, + VISION_TENSOR_ENC_EMBD_PATCH, + VISION_TENSOR_ENC_EMBD_POS, + VISION_TENSOR_ENC_ATTN_Q, + VISION_TENSOR_ENC_ATTN_K, + VISION_TENSOR_ENC_ATTN_V, + VISION_TENSOR_ENC_INPUT_NORM, + VISION_TENSOR_ENC_OUTPUT, + VISION_TENSOR_ENC_OUTPUT_NORM, + VISION_TENSOR_ENC_FFN_UP, + VISION_TENSOR_ENC_FFN_DOWN, + VISION_TENSOR_PRE_NORM, + VISION_TENSOR_POST_NORM, +}; + static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NAMES = { { LLM_ARCH_LLAMA, @@ -1530,6 +1589,29 @@ 
static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA }, }; +static const std::map<vision_arch, std::map<vision_tensor, std::string>> VISION_TENSOR_NAMES = { + { + VISION_ARCH_LLAVA, + { + { VISION_TENSOR_MMPROJ_A, "v.mmproj_a" }, + { VISION_TENSOR_MMPROJ_B, "v.mmproj_b" }, + { VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" }, + { VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" }, + { VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" }, + { VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" }, + { VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" }, + { VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" }, + { VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" }, + { VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" }, + { VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" }, + { VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" }, + { VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" }, + { VISION_TENSOR_PRE_NORM, "v.pre_norm" }, + { VISION_TENSOR_POST_NORM, "v.post_norm" }, + } + } +}; + static llm_arch llm_arch_from_string(const std::string & name) { for (const auto & kv : LLM_ARCH_NAMES) { // NOLINT if (kv.second == name) { @@ -1540,56 +1622,66 @@ static llm_arch llm_arch_from_string(const std::string & name) { return LLM_ARCH_UNKNOWN; } -// helper to handle gguf constants -// usage: -// -// const auto tn = LLM_TN(LLM_ARCH_LLAMA); -// -// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" -// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" -// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" -// -struct LLM_TN { - LLM_TN(llm_arch arch) : arch(arch) {} +template<typename Tname, typename Ttensor> +struct BASE_TN { + Tname arch; + std::map<Tname, std::map<Ttensor, std::string>> name_mapping; - llm_arch arch; + BASE_TN(Tname arch, std::map<Tname, std::map<Ttensor, std::string>> name_mapping) : arch(arch), name_mapping(name_mapping) {} - std::string operator()(llm_tensor tensor) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES.at(arch).at(tensor); + return name_mapping.at(arch).at(tensor); } - std::string operator()(llm_tensor tensor, const std::string & suffix) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor, const std::string & suffix) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return LLM_TENSOR_NAMES.at(arch).at(tensor) + "." + suffix; + return name_mapping.at(arch).at(tensor) + "." + suffix; } - std::string operator()(llm_tensor tensor, int bid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor, int bid) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid); + return ::format(name_mapping.at(arch).at(tensor).c_str(), bid); } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor, const std::string & suffix, int bid) const { - if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid) + "." + suffix; + return ::format(name_mapping.at(arch).at(tensor).c_str(), bid) + "." 
+ suffix; } - std::string operator()(llm_tensor tensor, const std::string & suffix, int bid, int xid) const { - if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) { + std::string operator()(Ttensor tensor, const std::string & suffix, int bid, int xid) const { + if (name_mapping.at(arch).find(tensor) == name_mapping.at(arch).end()) { return "__missing__"; } - return ::format(LLM_TENSOR_NAMES.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix; + return ::format(name_mapping.at(arch).at(tensor).c_str(), bid, xid) + "." + suffix; } }; +// helper to handle gguf constants +// usage: +// +// const auto tn = LLM_TN(LLM_ARCH_LLAMA); +// +// std::string name = tn(LLM_TENSOR_OUTPUT); -> "output" +// std::string name = tn(LLM_TENSOR_TOKEN_EMBD, "bias"); -> "token_embd.bias" +// std::string name = tn(LLM_TENSOR_ATTN_NORM, "weight", 3); -> "blk.3.attn_norm.weight" +// +struct LLM_TN : BASE_TN<llm_arch, llm_tensor> { + LLM_TN(llm_arch arch) : BASE_TN(arch, LLM_TENSOR_NAMES) {} +}; + +struct VISION_TN : BASE_TN<vision_arch, vision_tensor> { + VISION_TN(vision_arch arch) : BASE_TN(arch, VISION_TENSOR_NAMES) {} +}; + // // gguf helpers // @@ -2458,6 +2550,9 @@ struct llama_hparams { enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; enum llama_rope_scaling_type rope_scaling_type_train = LLAMA_ROPE_SCALING_TYPE_NONE; + bool has_vision = false; + clip_hparams clip; + bool operator!=(const llama_hparams & other) const { if (this->vocab_only != other.vocab_only) return true; if (this->n_vocab != other.n_vocab) return true; @@ -2908,6 +3003,8 @@ struct llama_model { std::vector<llama_layer> layers; + clip_vision_model clip; + llama_split_mode split_mode; int main_gpu; int n_gpu_layers; @@ -5476,6 +5573,30 @@ static void llm_load_hparams( hparams.n_embd_head_v = 0; } + std::string vision_type; + ml.get_key(LLM_KV_VISION_TYPE, vision_type, false); + if (vision_type == "clip") { + hparams.has_vision = true; + ml.get_key(LLM_KV_VISION_IMAGE_SIZE, hparams.clip.image_size, true); + ml.get_key(LLM_KV_VISION_PATCH_SIZE, hparams.clip.patch_size, true); + ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, hparams.clip.hidden_size, true); + ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, hparams.clip.n_layer, true); + ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true); + ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, hparams.clip.n_head, true); + ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, hparams.clip.eps, true); + // TODO: add image_std + std::string arch; + ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true); + for (auto & it : VISION_ARCH_NAMES) { + if (arch == it.second) { + hparams.clip.arch = it.first; + break; + } + } + } else if (!vision_type.empty()) { + throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); + } + // arch-specific KVs switch (model.arch) { case LLM_ARCH_LLAMA: @@ -6123,6 +6244,15 @@ default: (void)0; } + // arch-specific CLIP hparams + switch (hparams.clip.arch) { + case VISION_ARCH_LLAVA: + { + ml.get_key(LLM_KV_VISION_CLIP_MAX_POS_EMBD, hparams.clip.max_pos_embd, true); + } break; + default: (void)0; + } + model.ftype = ml.ftype; if (hparams.f_max_alibi_bias > 0.0f) { @@ -8811,7 +8941,69 @@ static bool llm_load_tensors( } } break; default: - throw std::runtime_error("unknown architecture"); + throw std::runtime_error("unknown llm architecture"); + } + } + + // load tensors for vision model + if (hparams.has_vision) { + const int64_t n_layer = hparams.clip.n_layer; + const int64_t n_embd = hparams.clip.hidden_size; + const int64_t n_ff 
= hparams.clip.n_intermediate; + const int64_t max_pos_embd = hparams.clip.max_pos_embd; + const int64_t n_channel = 3; // always RGB + const int64_t patch_size = hparams.clip.patch_size; + const auto tn = VISION_TN(hparams.clip.arch); + + ggml_context * ctx_vision = ctx_map.at(model.buft_input.buft); // TODO: make dedicated buft for vision + auto ctx_for_layer = [&](int i) { return ctx_map.at(model.buft_layer[i].buft); }; + + model.clip.layers.resize(n_layer); + + switch (hparams.clip.arch) { + case VISION_ARCH_LLAVA: + { + model.clip.mm_a_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_A, "weight"), {n_embd, n_ff}); + model.clip.mm_a_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_A, "bias" ), {n_ff}); + model.clip.mm_b_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_B, "weight"), {n_ff, n_ff}); + model.clip.mm_b_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_MMPROJ_B, "bias" ), {n_ff}); + + model.clip.class_embedding = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_CLS ), {n_embd}); + model.clip.patch_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_PATCH, "weight"), {patch_size, patch_size, n_channel, n_embd}); + model.clip.position_embeddings = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_ENC_EMBD_POS, "weight"), {n_embd, max_pos_embd}); + + model.clip.pre_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "weight"), {n_embd}); + model.clip.pre_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_PRE_NORM, "bias" ), {n_embd}); + model.clip.post_norm_w = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "weight"), {n_embd}); + model.clip.post_norm_b = ml.create_tensor(ctx_vision, tn(VISION_TENSOR_POST_NORM, "bias" ), {n_embd}); + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + auto & layer = model.clip.layers[i]; + + layer.k_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_K, "weight", i), {n_embd, n_embd}); + layer.k_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_K, "bias" , i), {n_embd}); + layer.v_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_V, "weight", i), {n_embd, n_embd}); + layer.v_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_V, "bias" , i), {n_embd}); + layer.q_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.q_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_ATTN_Q, "bias" , i), {n_embd}); + + layer.ffn_up_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_UP, "weight", i), {n_embd, n_ff}); + layer.ffn_up_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_UP, "bias" , i), {n_ff}); + layer.ffn_down_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_FFN_DOWN, "bias" , i), {n_embd}); + + layer.norm_in_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_INPUT_NORM, "weight", i), {n_embd}); + layer.norm_in_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_INPUT_NORM, "bias" , i), {n_embd}); + layer.norm_out_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "weight", i), {n_embd}); + layer.norm_out_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT_NORM, "bias" , i), {n_embd}); + + layer.output_w = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT, "weight", i), {n_embd, n_embd}); + layer.output_b = ml.create_tensor(ctx_layer, tn(VISION_TENSOR_ENC_OUTPUT, "bias" , i), {n_embd}); + } + } break; + default: 
throw std::runtime_error("unknown vision architecture"); } } From 6854ad4057e682fbcc747c75a1d2670a7110ef51 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Mon, 30 Sep 2024 17:35:04 +0200 Subject: [PATCH 3/3] img pre processing --- convert_hf_to_gguf.py | 5 + gguf-py/gguf/constants.py | 5 + gguf-py/gguf/gguf_writer.py | 10 + src/llama-vision.cpp | 491 ++++++++++++++++++++++++++++++++++++ src/llama-vision.h | 18 +- src/llama.cpp | 61 +++-- 6 files changed, 564 insertions(+), 26 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index e0880511a4606..0340c138aba4d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -1584,8 +1584,13 @@ def set_gguf_parameters(self): self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"]) self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"]) # TODO: should not hardcode these, but they are currently missing from config.json + self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP) self.gguf_writer.add_vision_clip_max_position_embeddings(577) self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05) + default_image_mean = [0.48145466, 0.4578275, 0.40821073] + default_image_std = [0.26862954, 0.26130258, 0.27577711] + self.gguf_writer.add_vision_clip_image_mean(default_image_mean) + self.gguf_writer.add_vision_clip_image_std(default_image_std) @staticmethod def permute(weights: Tensor, n_head: int, n_head_kv: int | None): diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f4ebd8f90beda..b83dc311ae2a4 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -196,6 +196,7 @@ class Clip: PROJECTION_DIM = "vision.clip.projection_dim" USE_GELU = "vision.clip.use_gelu" MAX_POS_EMBEDDING = "vision.clip.max_position_embeddings" + PROJECTOR_TYPE = "vision.clip.projector_type" HEAD_COUNT = "vision.clip.attention.head_count" LAYERNORM_EPS = "vision.clip.attention.layer_norm_epsilon" @@ -1425,6 +1426,10 @@ class PoolingType(IntEnum): CLS = 2 +class CLIPProjectorType(Enum): + MLP = 'mlp' + + class GGMLQuantizationType(IntEnum): F32 = 0 F16 = 1 diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 2828f0a802d70..e44ef9a1d55d9 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -26,6 +26,7 @@ RopeScalingType, PoolingType, TokenType, + CLIPProjectorType, ) from .quants import quant_shape_from_byte_shape @@ -844,9 +845,18 @@ def add_vision_clip_head_count(self, value: int) -> None: def add_vision_clip_max_position_embeddings(self, value: int) -> None: self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value) + def add_vision_clip_projector_type(self, value: CLIPProjectorType) -> None: + self.add_string(Keys.Vision.Clip.PROJECTOR_TYPE, value.value) + def add_vision_clip_layer_norm_epsilon(self, value: float) -> None: self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value) + def add_vision_clip_image_mean(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_MEAN, value) + + def add_vision_clip_image_std(self, value: Sequence[float]) -> None: + self.add_array(Keys.Vision.IMAGE_STD, value) + def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None: if not isinstance(value, str): template_default = None diff --git a/src/llama-vision.cpp b/src/llama-vision.cpp index ac507babbe8c0..75fdfc7032398 100644 --- a/src/llama-vision.cpp +++ b/src/llama-vision.cpp @@ -1,5 +1,496 @@ #include "llama.h" #include "llama-vision.h" +#include 
"llama-impl.h" +struct clip_image_size { + int width; + int height; +}; + +// RGB uint8 image +// Memory layout: RGBRGBRGB... +struct clip_image_u8 { + int nx; + int ny; + std::vector buf; + clip_image_u8() {} + clip_image_u8(const llama_img img) { + nx = img.nx; + ny = img.ny; + buf.resize(nx*ny*3); + memcpy(buf.data(), img.data, buf.size()); + } +}; + +struct clip_image_u8_batch { + struct clip_image_u8 * data; + size_t size; +}; + +// RGB float32 image (NHWC) +// Memory layout: RGBRGBRGB... +struct clip_image_f32 { + int nx; + int ny; + std::vector buf; +}; + +using clip_image_f32_batch = std::vector; +using clip_image_f8_batch = std::vector; + +int32_t clip_image_encode (const clip_context & ctx, const clip_image_f32 & img, std::vector & output); +int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector & output); + +/** + * Selects the best resolution from a list of possible resolutions based on the original size. + * + * @param original_size The original size of the image in the format (width, height). + * @param possible_resolutions A list of possible resolutions in the format [(width1, height1), (width2, height2), ...]. + * @return The best fit resolution in the format (width, height). + */ +static clip_image_size select_best_resolution(const clip_image_size & original_size, const std::vector& possible_resolutions) { + int original_width = original_size.width; + int original_height = original_size.height; + + clip_image_size best_fit; + int max_effective_resolution = 0; + int min_wasted_resolution = std::numeric_limits::max(); + + for (const auto& resolution : possible_resolutions) { + int width = resolution.width; + int height = resolution.height; + float scale = std::min(static_cast(width) / original_width, static_cast(height) / original_height); + int downscaled_width = static_cast(original_width * scale); + int downscaled_height = static_cast(original_height * scale); + int effective_resolution = std::min(downscaled_width * downscaled_height, original_width * original_height); + int wasted_resolution = (width * height) - effective_resolution; + // LOG_DBG("resolution: %d %d, scale: %f, downscaled: %d %d, effective: %d, wasted: %d\n", width, height, scale, downscaled_width, downscaled_height, effective_resolution, wasted_resolution); + if (effective_resolution > max_effective_resolution || (effective_resolution == max_effective_resolution && wasted_resolution < min_wasted_resolution)) { + max_effective_resolution = effective_resolution; + min_wasted_resolution = wasted_resolution; + best_fit = resolution; + } + } + + return best_fit; +} + +static bool bicubic_resize(const clip_image_u8 & img, clip_image_u8 & dst, int target_width, int target_height) { + auto clip = [](int x, int lower, int upper) -> int { + return std::max(lower, std::min(x, upper)); + }; + + const int nx = img.nx; + const int ny = img.ny; + + dst.nx = target_width; + dst.ny = target_height; + dst.buf.resize(3 * target_width * target_height); + + float Cc; + float C[5]; + float d0, d2, d3, a0, a1, a2, a3; + int i, j, k, jj; + int x, y; + float dx, dy; + float tx, ty; + + tx = (float)nx / (float)target_width; + ty = (float)ny / (float)target_height; + + // Bicubic interpolation; adapted from ViT.cpp, inspired from : + // -> https://github.com/yglukhov/bicubic-interpolation-image-processing/blob/master/libimage.c#L36 + // -> https://en.wikipedia.org/wiki/Bicubic_interpolation + + for (i = 0; i < target_height; i++) { + for (j = 0; j < target_width; j++) { + x = 
(int)(tx * j); + y = (int)(ty * i); + + dx = tx * j - x; + dy = ty * i - y; + + for (k = 0; k < 3; k++) { + for (jj = 0; jj <= 3; jj++) { + d0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x - 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d2 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 1, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + d3 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x + 2, 0, nx - 1)) * 3 + k] - img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + a0 = img.buf[(clip(y - 1 + jj, 0, ny - 1) * nx + clip(x, 0, nx - 1)) * 3 + k]; + + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + + C[jj] = a0 + a1 * dx + a2 * dx * dx + a3 * dx * dx * dx; + + d0 = C[0] - C[1]; + d2 = C[2] - C[1]; + d3 = C[3] - C[1]; + a0 = C[1]; + a1 = -1.0 / 3 * d0 + d2 - 1.0 / 6 * d3; + a2 = 1.0 / 2 * d0 + 1.0 / 2 * d2; + a3 = -1.0 / 6 * d0 - 1.0 / 2 * d2 + 1.0 / 6 * d3; + Cc = a0 + a1 * dy + a2 * dy * dy + a3 * dy * dy * dy; + + const uint8_t Cc2 = std::min(std::max(std::round(Cc), 0.0f), 255.0f); + dst.buf[(i * target_width + j) * 3 + k] = float(Cc2); + } + } + } + } + + return true; +} + +static std::vector<clip_image_u8> divide_to_patches_u8(const clip_image_u8 & image, int patch_size) { + std::vector<clip_image_u8> patches; + int width = image.nx; + int height = image.ny; + for (int i = 0; i < height; i += patch_size) { + for (int j = 0; j < width; j += patch_size) { + clip_image_u8 patch; + patch.nx = std::min(patch_size, width - j); + patch.ny = std::min(patch_size, height - i); + patch.buf.resize(3 * patch.nx * patch.ny); + for (int y = 0; y < patch.ny; ++y) { + for (int x = 0; x < patch.nx; ++x) { + for (int c = 0; c < 3; ++c) { + patch.buf[3 * (y * patch.nx + x) + c] = image.buf[3 * ((i + y) * width + (j + x)) + c]; + } + } + } + patches.push_back(patch); + } + } + return patches; +} + +// llava-1.6 type of resize_and_pad (black) +static void resize_and_pad_image(const clip_image_u8 & image, clip_image_u8 & image_output, const clip_image_size & target_resolution) { + int target_width = target_resolution.width; + int target_height = target_resolution.height; + + float scale_w = static_cast<float>(target_width) / image.nx; + float scale_h = static_cast<float>(target_height) / image.ny; + + int new_width, new_height; + + if (scale_w < scale_h) { + new_width = target_width; + new_height = std::min(static_cast<int>(std::ceil(image.ny * scale_w)), target_height); + } else { + new_height = target_height; + new_width = std::min(static_cast<int>(std::ceil(image.nx * scale_h)), target_width); + } + + clip_image_u8 resized_image; + // bilinear_resize(image, resized_image, new_width, new_height); + bicubic_resize(image, resized_image, new_width, new_height); + + clip_image_u8 padded_image; + padded_image.nx = target_width; + padded_image.ny = target_height; + padded_image.buf.resize(3 * target_width * target_height, 0); // Initialize with black + + // Calculate padding offsets + int pad_x = (target_width - new_width) / 2; + int pad_y = (target_height - new_height) / 2; + + // Copy the resized image into the center of the padded buffer + for (int y = 0; y < new_height; ++y) { + for (int x = 0; x < new_width; ++x) { + for (int c = 0; c < 3; ++c) { + padded_image.buf[3 * ((y + pad_y) * target_width + (x + pad_x)) + c] = resized_image.buf[3 * (y * new_width + x) + c]; + } + } + } + image_output = std::move(padded_image); +} + +static void 
normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 & dst, const std::array<float, 3> & mean, const std::array<float, 3> & std) { + dst.nx = src.nx; + dst.ny = src.ny; + dst.buf.resize(src.buf.size()); + + for (size_t i = 0; i < src.buf.size(); ++i) { + int c = i % 3; // rgb + dst.buf[i] = (static_cast<float>(src.buf[i]) / 255.0f - mean[c]) / std[c]; + } +} + +// returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector +// res_imgs memory is being allocated here, previous allocations will be freed if found +bool clip_image_preprocess(const clip_context & ctx, const clip_image_u8 & img, clip_image_f32_batch & output_imgs) { + bool pad_to_square = true; + auto & params = ctx.model.hparams; + // The model config actually contains all we need to decide on how to preprocess, here we automatically switch to the new llava-1.6 preprocessing + if (params.mm_patch_merge_type == MM_PATCH_MERGE_SPATIAL_UNPAD) { + pad_to_square = false; + } + + // the logic below is to pad the shorter side to the longer side with a background color: rgb(122, 116, 104) + // see https://github.com/haotian-liu/LLaVA/blob/e854a2bf85118c504f6f16bf5c3c7c92f8fa8c6b/llava/conversation.py#L113-L156 + + clip_image_u8 temp; + if (pad_to_square && img.nx != img.ny) { + int longer_side = std::max(img.nx, img.ny); + temp.nx = longer_side; + temp.ny = longer_side; + temp.buf.resize(3 * longer_side * longer_side); + const uint8_t bc[3] = {122, 116, 104}; // background color in RGB from LLaVA (this is the mean rgb color * 255) + + // fill with background color + for (size_t i = 0; i < temp.buf.size(); i++) { + temp.buf[i] = bc[i % 3]; + } + + // copy from the input image + for (int y = 0; y < img.ny; y++) { + for (int x = 0; x < img.nx; x++) { + const int i = 3 * (y * img.nx + x); + const int j = 3 * (y * temp.nx + x); + temp.buf[j] = img.buf[i]; + temp.buf[j+1] = img.buf[i+1]; + temp.buf[j+2] = img.buf[i+2]; + } + } + } else { + if (params.image_grid_pinpoints[0] != 0) { + // "spatial_unpad" with "anyres" processing for llava-1.6 + std::vector<clip_image_size> possible_resolutions; + for (int i = 0; i < 32 && params.image_grid_pinpoints[i] != 0; i += 2) { + clip_image_size s; + s.width = params.image_grid_pinpoints[i]; + s.height = params.image_grid_pinpoints[i+1]; + possible_resolutions.push_back(s); + } + clip_image_size best_resolution = select_best_resolution({img.nx, img.ny}, possible_resolutions); + // clip_image_save_to_bmp(*img, "input.bmp"); + resize_and_pad_image(img, temp, best_resolution); // we do not pad with mean-bg color anymore in llava-1.6 + // clip_image_save_to_bmp(*temp, "resized.bmp"); + + std::vector<clip_image_u8> patches = divide_to_patches_u8(temp, params.image_size); // prepare spatial sorted main patches of image_size each (336 in llava-1.6) + + clip_image_u8 image_original_resize; + // bilinear_resize(*img, *image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + bicubic_resize(img, image_original_resize, params.image_size, params.image_size); // in python this is "shortest_edge", but all CLIP are square + patches.insert(patches.begin(), image_original_resize); + // clip_image_f32_batch_init(patches.size()); + output_imgs.resize(patches.size()); + int num = 0; + for (auto & patch : patches) { + normalize_image_u8_to_f32(patch, output_imgs[num], params.image_mean, params.image_std); + num++; + } + return true; + } else { + temp.nx = img.nx; + temp.ny = img.ny; 
temp.buf.resize(img.buf.size()); + memcpy(temp.buf.data(), img.buf.data(), temp.buf.size()); + } + } + + const int nx = temp.nx; + const int ny = temp.ny; + // clip_image_save_to_bmp(*temp, "resized_vanilla.bmp"); + + const int nx2 = params.image_size; + const int ny2 = params.image_size; + clip_image_f32 res; + res.nx = nx2; + res.ny = ny2; + res.buf.resize(3 * nx2 * ny2); + + const float scale = std::max(nx, ny) / (float)params.image_size; + + const int nx3 = int(nx / scale + 0.5f); + const int ny3 = int(ny / scale + 0.5f); + + const auto & m3 = params.image_mean; // {0.48145466f, 0.4578275f, 0.40821073f}; + const auto & s3 = params.image_std; // {0.26862954f, 0.26130258f, 0.27577711f}; + + for (int y = 0; y < ny3; y++) { + for (int x = 0; x < nx3; x++) { + for (int c = 0; c < 3; c++) { + // linear interpolation + const float sx = (x + 0.5f) * scale - 0.5f; + const float sy = (y + 0.5f) * scale - 0.5f; + + const int x0 = std::max(0, (int)std::floor(sx)); + const int y0 = std::max(0, (int)std::floor(sy)); + + const int x1 = std::min(x0 + 1, nx - 1); + const int y1 = std::min(y0 + 1, ny - 1); + + const float dx = sx - x0; + const float dy = sy - y0; + + const int j00 = 3 * (y0 * nx + x0) + c; + const int j01 = 3 * (y0 * nx + x1) + c; + const int j10 = 3 * (y1 * nx + x0) + c; + const int j11 = 3 * (y1 * nx + x1) + c; + + const float v00 = temp.buf[j00]; + const float v01 = temp.buf[j01]; + const float v10 = temp.buf[j10]; + const float v11 = temp.buf[j11]; + + const float v0 = v00 * (1.0f - dx) + v01 * dx; + const float v1 = v10 * (1.0f - dx) + v11 * dx; + + const float v = v0 * (1.0f - dy) + v1 * dy; + + const uint8_t v2 = std::min(std::max(std::round(v), 0.0f), 255.0f); + + const int i = 3 * (y * nx3 + x) + c; + + res.buf[i] = ((float(v2) / 255.0f) - m3[c]) / s3[c]; + } + } + } + + output_imgs.resize(1); + output_imgs[0] = std::move(res); + + return true; +} + +int clip_n_patches(const clip_context & ctx) { + auto & hparams = ctx.model.hparams; + int n_patches = (hparams.image_size / hparams.patch_size) * (hparams.image_size / hparams.patch_size); + return n_patches; +} + +static bool encode_image_with_clip(clip_context & ctx_clip, const llama_img img) { + clip_image_u8 img_u8(img); + clip_image_f32_batch img_res_v; + std::vector<float> image_embd; // output vectors + auto & hparams = ctx_clip.model.hparams; + int n_output; + + if (!clip_image_preprocess(ctx_clip, img_u8, img_res_v)) { + LLAMA_LOG_ERROR("%s: unable to preprocess image\n", __func__); + return false; + } + + if (hparams.mm_patch_merge_type != MM_PATCH_MERGE_SPATIAL_UNPAD) { + // flat / default llava-1.5 type embedding + n_output = clip_n_patches(ctx_clip); + bool encoded = clip_image_encode(ctx_clip, img_res_v[0], image_embd); + if (!encoded) { + LLAMA_LOG_ERROR("Unable to encode image\n"); + return false; + } + } + + // TODO: handle MM_PATCH_MERGE_SPATIAL_UNPAD (llava-1.6) and store the output embeddings + return true; +} + +int32_t clip_image_encode(const clip_context & ctx, const clip_image_f32 & img, std::vector<float> & output) { + clip_image_f32_batch imgs{img}; + return clip_image_batch_encode(ctx, imgs, output); +} + +int32_t clip_image_batch_encode(const clip_context & ctx, const clip_image_f32_batch & imgs, std::vector<float> & output) { + int batch_size = imgs.size(); + (void) batch_size; + // TODO: build the ggml graph for the encoder and run it + return 0; +} + + +//////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////// +// for debugging +#ifndef NDEBUG + +#include <cstdint> +#include <fstream> +#include <string> +#include <vector> + +// export clip_image_u8 to bmp file for debugging +// 
https://codereview.stackexchange.com/questions/195121/writing-a-bitmap-image-from-c +inline int bmp_export(const clip_image_u8 &img, const std::string &location) { + const uint32_t width = img.nx; + const uint32_t height = img.ny; + const std::vector<uint8_t> &buffer = img.buf; + const bool hasAlphaChannel = false; + + std::ofstream fout(location, std::ios::out | std::ios::binary); + + if (fout.fail()) { + return 0; + } + + //Padding + const uint8_t padding = hasAlphaChannel ? 0 : (4 - (width * 3) % 4) % 4; + + //Bitmap file header. + const char signature[2] = { 'B', 'M' }; + const uint32_t fileSize = buffer.size() * sizeof(uint8_t) + padding * (height - 1) + 14 + 124; + const uint32_t offset = 14 + 124; + + //Bitmap information header file + const uint32_t DIBSize = 124; + const int32_t bitmapWidth = width; + const int32_t bitmapHeight = height; + const uint16_t numPlanes = 1; + const uint16_t bitsPerPixel = (hasAlphaChannel) ? 32 : 24; + const uint32_t compressionMethod = (hasAlphaChannel) ? 3 : 0; //BI_RGB = 0, BI_BITFIELDS = 3 + const uint32_t bitmapSize = buffer.size() * sizeof(uint8_t); + const int32_t horizontalResolution = 2834; + const int32_t verticalResolution = 2834; + const uint32_t numColors = 0; + const uint32_t impColorCount = 0; + const uint32_t redBitmask = (hasAlphaChannel) ? 0x0000FF00 : 0; //ARGB32 pixel format + const uint32_t greenBitmask = (hasAlphaChannel) ? 0x00FF0000 : 0; + const uint32_t blueBitmask = (hasAlphaChannel) ? 0xFF000000 : 0; + const uint32_t alphaBitmask = (hasAlphaChannel) ? 0x000000FF : 0; + + //Writing the file header and information header to the file + std::vector<uint8_t> header(offset, 0); + header[0] = signature[0]; + header[1] = signature[1]; + +#define BMP_HEADERS(i, variableName) header[i] = variableName; header[i+1] = variableName >> 8; header[i+2] = variableName >> 16; header[i+3] = variableName >> 24; + + BMP_HEADERS(2, fileSize); + BMP_HEADERS(6, 0); + BMP_HEADERS(10, offset); + BMP_HEADERS(14, DIBSize); + BMP_HEADERS(18, bitmapWidth); + BMP_HEADERS(22, bitmapHeight); + + header[26] = (uint8_t)numPlanes; + header[27] = (uint8_t)(numPlanes >> 8); + header[28] = (uint8_t)bitsPerPixel; + header[29] = (uint8_t)(bitsPerPixel >> 8); + + BMP_HEADERS(30, compressionMethod); + BMP_HEADERS(34, (unsigned char)bitmapSize); + BMP_HEADERS(38, (unsigned char)horizontalResolution); + BMP_HEADERS(42, (unsigned char)verticalResolution); + BMP_HEADERS(46, (unsigned char)numColors); + BMP_HEADERS(50, (unsigned char)impColorCount); + BMP_HEADERS(54, (unsigned char)redBitmask); + BMP_HEADERS(58, (unsigned char)greenBitmask); + BMP_HEADERS(62, (unsigned char)blueBitmask); + BMP_HEADERS(66, alphaBitmask); + +#undef BMP_HEADERS + + fout.write((char *)header.data(), sizeof(uint8_t) * header.size()); + + //Writing the pixel array + const uint32_t bWidth = bitsPerPixel / 8 * width; + + for (int i = height - 1; i >= 0; i--) { + std::vector<uint8_t> row(buffer.begin() + i * bWidth, buffer.begin() + i * bWidth + bWidth); + fout.write((char *)row.data(), row.size() * sizeof(uint8_t)); + fout.seekp(padding * sizeof(uint8_t), std::ios::cur); + } + + fout.close(); + return 1; +} + +#endif diff --git a/src/llama-vision.h b/src/llama-vision.h index 5bf1673e530a4..e7404ea186fe4 100644 --- a/src/llama-vision.h +++ b/src/llama-vision.h @@ -9,6 +9,10 @@ enum vision_arch { VISION_ARCH_UNKNOWN, }; +enum clip_projector_type { + CLIP_PROJECTOR_TYPE_MLP, +}; + enum mm_patch_merge { MM_PATCH_MERGE_FLAT, MM_PATCH_MERGE_SPATIAL_UNPAD, }; @@ -28,9 +32,13 @@ struct clip_hparams { float eps; + clip_projector_type 
proj_type = CLIP_PROJECTOR_TYPE_MLP; mm_patch_merge mm_patch_merge_type = MM_PATCH_MERGE_FLAT; - int32_t image_grid_pinpoints[32]; + std::array<float, 3> image_mean; + std::array<float, 3> image_std; + + std::array<int32_t, 32> image_grid_pinpoints; int32_t image_crop_resolution; }; @@ -89,3 +97,11 @@ struct clip_vision_model { struct ggml_tensor * image_newline = NULL; }; + +struct clip_context { + struct ggml_context * ctx_ggml; + clip_vision_model model; + + int32_t n_output; + float * output; +}; diff --git a/src/llama.cpp b/src/llama.cpp index 0eac03a513f5f..2860c7094f90a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -400,8 +400,9 @@ enum llm_kv { LLM_KV_VISION_CLIP_PROJECTION_TYPE, LLM_KV_VISION_CLIP_PROJECTION_DIM, LLM_KV_VISION_CLIP_USE_GELU, - LLM_KV_VISION_CLIP_HEAD_COUNT, LLM_KV_VISION_CLIP_MAX_POS_EMBD, + LLM_KV_VISION_CLIP_PROJECTOR_TYPE, + LLM_KV_VISION_CLIP_HEAD_COUNT, LLM_KV_VISION_CLIP_LAYERNORM_EPS, }; @@ -526,6 +527,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = { { LLM_KV_VISION_CLIP_PROJECTION_DIM, "vision.clip.projection_dim" }, { LLM_KV_VISION_CLIP_USE_GELU, "vision.clip.use_gelu" }, { LLM_KV_VISION_CLIP_MAX_POS_EMBD, "vision.clip.max_position_embeddings" }, + { LLM_KV_VISION_CLIP_PROJECTOR_TYPE, "vision.clip.projector_type" }, { LLM_KV_VISION_CLIP_HEAD_COUNT, "vision.clip.attention.head_count" }, { LLM_KV_VISION_CLIP_LAYERNORM_EPS, "vision.clip.attention.layer_norm_epsilon" }, }; @@ -5573,30 +5575,6 @@ static void llm_load_hparams( hparams.n_embd_head_v = 0; } - std::string vision_type; - ml.get_key(LLM_KV_VISION_TYPE, vision_type, false); - if (vision_type == "clip") { - hparams.has_vision = true; - ml.get_key(LLM_KV_VISION_IMAGE_SIZE, hparams.clip.image_size, true); - ml.get_key(LLM_KV_VISION_PATCH_SIZE, hparams.clip.patch_size, true); - ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, hparams.clip.hidden_size, true); - ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, hparams.clip.n_layer, true); - ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true); - ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, hparams.clip.n_head, true); - ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, hparams.clip.eps, true); - // TODO: add image_std - std::string arch; - ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true); - for (auto & it : VISION_ARCH_NAMES) { - if (arch == it.second) { - hparams.clip.arch = it.first; - break; - } - } - } else if (!vision_type.empty()) { - throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); - } - // arch-specific KVs switch (model.arch) { case LLM_ARCH_LLAMA: @@ -6244,6 +6222,39 @@ static void llm_load_hparams( default: (void)0; } + // vision model + std::string vision_type; + ml.get_key(LLM_KV_VISION_TYPE, vision_type, false); + if (vision_type == "clip") { + hparams.has_vision = true; + std::string proj_type; + ml.get_key(LLM_KV_VISION_IMAGE_SIZE, hparams.clip.image_size, true); + ml.get_key(LLM_KV_VISION_PATCH_SIZE, hparams.clip.patch_size, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_MEAN, hparams.clip.image_mean, 3, true); + ml.get_key_or_arr(LLM_KV_VISION_IMAGE_STD, hparams.clip.image_std, 3, true); + ml.get_key(LLM_KV_VISION_CLIP_EMBEDDING_LENGTH, hparams.clip.hidden_size, true); + ml.get_key(LLM_KV_VISION_CLIP_BLOCK_COUNT, hparams.clip.n_layer, true); + ml.get_key(LLM_KV_VISION_CLIP_FEED_FORWARD_LENGTH, hparams.clip.n_intermediate, true); + ml.get_key(LLM_KV_VISION_CLIP_HEAD_COUNT, hparams.clip.n_head, true); + ml.get_key(LLM_KV_VISION_CLIP_LAYERNORM_EPS, hparams.clip.eps, true); + 
ml.get_key(LLM_KV_VISION_CLIP_PROJECTOR_TYPE, proj_type, true); + if (proj_type == "mlp") { + hparams.clip.proj_type = CLIP_PROJECTOR_TYPE_MLP; + } else { + throw std::runtime_error(format("unsupported clip projector type: %s", proj_type.c_str())); + } + std::string arch; + ml.get_key(LLM_KV_VISION_CLIP_ARCHITECTURE, arch, true); + for (auto & it : VISION_ARCH_NAMES) { + if (arch == it.second) { + hparams.clip.arch = it.first; + break; + } + } + } else if (!vision_type.empty()) { + throw std::runtime_error(format("unsupported vision type: %s", vision_type.c_str())); + } + // arch-specific CLIP hparams switch (hparams.clip.arch) { case VISION_ARCH_LLAVA: