llama : first attempt to implement vision API (WIP) #9687

Draft: wants to merge 2 commits into master
12 changes: 12 additions & 0 deletions Makefile
@@ -1120,6 +1120,7 @@ src/llama.o: \
	src/llama-vocab.h \
	src/llama-grammar.h \
	src/llama-sampling.h \
	src/llama-vision.h \
	src/unicode.h \
	include/llama.h \
	ggml/include/ggml-cuda.h \
@@ -1152,6 +1153,17 @@ src/llama-sampling.o: \
	include/llama.h
	$(CXX) $(CXXFLAGS) -c $< -o $@

src/llama-vision.o: \
	src/llama-vision.cpp \
	src/llama-vision.h \
	include/llama.h \
	ggml/include/ggml-cuda.h \
	ggml/include/ggml-metal.h \
	ggml/include/ggml.h \
	ggml/include/ggml-alloc.h \
	ggml/include/ggml-backend.h
	$(CXX) $(CXXFLAGS) -c $< -o $@

$(LIB_LLAMA): \
	$(OBJ_LLAMA) \
	$(LIB_GGML)
44 changes: 40 additions & 4 deletions convert_hf_to_gguf.py
@@ -66,6 +66,11 @@ class Model:
    dir_model_card: Path
    is_lora: bool

    # for vision model
    vparams: dict[str, Any] | None = None
    v_tensor_map: gguf.TensorNameMap
    v_tensor_names: set[str] | None

    # subclasses should define this!
    model_arch: gguf.MODEL_ARCH

@@ -210,9 +215,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int |

    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
        new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
        new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
        if new_name is not None:
            return new_name
        elif new_name_vision is not None:
            return new_name_vision
        else:
            raise ValueError(f"Can not map tensor {name!r}")

    def set_gguf_parameters(self):
        self.gguf_writer.add_block_count(self.block_count)
@@ -452,7 +461,10 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
    @staticmethod
    def load_hparams(dir_model: Path):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            hparams = json.load(f)
        if "text_config" in hparams:
            hparams = {**hparams, **hparams["text_config"]}
        return hparams

    @classmethod
    def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -1501,10 +1513,17 @@ def prepare_tensors(self):
                raise ValueError(f"Unprocessed norms: {norms}")


@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if "vision_config" in self.hparams:
            self.vparams = self.hparams["vision_config"]
        if self.vparams is not None:
            self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])

    def set_vocab(self):
        try:
            self._set_vocab_sentencepiece()
@@ -1554,6 +1573,20 @@ def set_gguf_parameters(self):
        if self.hparams.get("vocab_size", 32000) == 49152:
            self.gguf_writer.add_add_bos_token(False)

        # For vision model
        if self.vparams is not None:
            self.gguf_writer.add_vision_type("clip")
            self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
            self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
            self.gguf_writer.add_vision_clip_architecture("llava")
            self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
            self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
            self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
            self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
            # TODO: should not hardcode these, but they are currently missing from config.json
            self.gguf_writer.add_vision_clip_max_position_embeddings(577)
            self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)

    @staticmethod
    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
        if n_head_kv is not None and n_head != n_head_kv:
@@ -1568,6 +1601,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        n_head = self.hparams["num_attention_heads"]
        n_kv_head = self.hparams.get("num_key_value_heads")

        if name.startswith("language_model"):
            name = name.replace("language_model.", "")

        if name.endswith(("q_proj.weight", "q_proj.bias")):
            data_torch = LlamaModel.permute(data_torch, n_head, n_head)
        if name.endswith(("k_proj.weight", "k_proj.bias")):
74 changes: 74 additions & 0 deletions gguf-py/gguf/constants.py
@@ -178,6 +178,27 @@ class Adapter:
        TYPE       = "adapter.type"
        LORA_ALPHA = "adapter.lora.alpha"

    class Vision:
        # only support vision.type = "clip" for now
        TYPE       = "vision.type"
        IMAGE_SIZE = "vision.image_size"
        PATCH_SIZE = "vision.patch_size"
        IMAGE_MEAN = "vision.image_mean"
        IMAGE_STD  = "vision.image_std"

        class Clip:
            ARCHITECTURE        = "vision.clip.architecture"
            CONTEXT_LENGTH      = "vision.clip.context_length"
            EMBEDDING_LENGTH    = "vision.clip.embedding_length"
            BLOCK_COUNT         = "vision.clip.block_count"
            FEED_FORWARD_LENGTH = "vision.clip.feed_forward_length"
            PROJECTION_TYPE     = "vision.clip.projection_type"
            PROJECTION_DIM      = "vision.clip.projection_dim"
            USE_GELU            = "vision.clip.use_gelu"
            MAX_POS_EMBEDDING   = "vision.clip.max_position_embeddings"
            HEAD_COUNT          = "vision.clip.attention.head_count"
            LAYERNORM_EPS       = "vision.clip.attention.layer_norm_epsilon"
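
For reference, an editor's sketch (not part of this diff) of reading these keys back in C through ggml's public gguf API; the key strings are the constants above. The 577 hardcoded by the converter corresponds to LLaVA's CLIP ViT-L/14-336 tower: (336/14)^2 = 576 patches plus one class embedding.

#include <stdio.h>
#include "ggml.h" // the gguf_* loader API is declared here at this point in the tree

// sketch: print two of the vision KV pairs from a converted GGUF file
static void print_vision_params(const char * fname) {
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ NULL };
    struct gguf_context * ctx = gguf_init_from_file(fname, params);
    if (ctx == NULL) {
        return;
    }
    int kid = gguf_find_key(ctx, "vision.clip.block_count");
    if (kid >= 0) {
        printf("clip block count: %u\n", (unsigned) gguf_get_val_u32(ctx, kid));
    }
    kid = gguf_find_key(ctx, "vision.clip.attention.layer_norm_epsilon");
    if (kid >= 0) {
        printf("layer norm eps  : %g\n", gguf_get_val_f32(ctx, kid));
    }
    gguf_free(ctx);
}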

#
# recommended mapping of model tensor names for storage in gguf
#
@@ -238,6 +259,8 @@ class MODEL_ARCH(IntEnum):
    GRANITE      = auto()
    GRANITE_MOE  = auto()
    CHAMELEON    = auto()
    # vision models
    LLAVA_VISION = auto()


class MODEL_TENSOR(IntEnum):
@@ -345,6 +368,22 @@ class MODEL_TENSOR(IntEnum):
    ENC_FFN_DOWN      = auto()
    ENC_FFN_UP        = auto()
    ENC_OUTPUT_NORM   = auto()
    # vision
    V_MMPROJ_A        = auto()
    V_MMPROJ_B        = auto()
    V_ENC_EMBD_CLS    = auto()
    V_ENC_EMBD_PATCH  = auto()
    V_ENC_EMBD_POS    = auto()
    V_ENC_ATTN_Q      = auto()
    V_ENC_ATTN_K      = auto()
    V_ENC_ATTN_V      = auto()
    V_ENC_INPUT_NORM  = auto()
    V_ENC_OUTPUT      = auto()
    V_ENC_OUTPUT_NORM = auto()
    V_ENC_FFN_UP      = auto()
    V_ENC_FFN_DOWN    = auto()
    V_PRE_NORM        = auto()
    V_POST_NORM       = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@@ -397,6 +436,8 @@ class MODEL_TENSOR(IntEnum):
    MODEL_ARCH.GRANITE:      "granite",
    MODEL_ARCH.GRANITE_MOE:  "granitemoe",
    MODEL_ARCH.CHAMELEON:    "chameleon",
    # vision
    MODEL_ARCH.LLAVA_VISION: "llava",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -504,6 +545,22 @@ class MODEL_TENSOR(IntEnum):
    MODEL_TENSOR.ENC_FFN_DOWN:      "enc.blk.{bid}.ffn_down",
    MODEL_TENSOR.ENC_FFN_UP:        "enc.blk.{bid}.ffn_up",
    MODEL_TENSOR.ENC_OUTPUT_NORM:   "enc.output_norm",
    # vision
    MODEL_TENSOR.V_MMPROJ_A:        "v.mmproj_a",
    MODEL_TENSOR.V_MMPROJ_B:        "v.mmproj_b",
    MODEL_TENSOR.V_ENC_EMBD_CLS:    "v.enc.embd.cls",
    MODEL_TENSOR.V_ENC_EMBD_PATCH:  "v.enc.embd.patch",
    MODEL_TENSOR.V_ENC_EMBD_POS:    "v.enc.embd.pos",
    MODEL_TENSOR.V_ENC_ATTN_Q:      "v.enc.blk.{bid}.attn_q",
    MODEL_TENSOR.V_ENC_ATTN_K:      "v.enc.blk.{bid}.attn_k",
    MODEL_TENSOR.V_ENC_ATTN_V:      "v.enc.blk.{bid}.attn_v",
    MODEL_TENSOR.V_ENC_INPUT_NORM:  "v.enc.blk.{bid}.input_norm",
    MODEL_TENSOR.V_ENC_OUTPUT:      "v.enc.blk.{bid}.output",
    MODEL_TENSOR.V_ENC_OUTPUT_NORM: "v.enc.blk.{bid}.output_norm",
    MODEL_TENSOR.V_ENC_FFN_UP:      "v.enc.blk.{bid}.ffn_up",
    MODEL_TENSOR.V_ENC_FFN_DOWN:    "v.enc.blk.{bid}.ffn_down",
    MODEL_TENSOR.V_PRE_NORM:        "v.pre_norm",
    MODEL_TENSOR.V_POST_NORM:       "v.post_norm",
}
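
With {bid} filled in, layer 0's query projection of a converted LLaVA model is therefore stored as v.enc.blk.0.attn_q.weight. An editor's sketch (not part of this diff) of probing for that tensor through ggml's gguf loader:

#include "ggml.h"

// sketch: open a converted GGUF without loading tensor data and check that
// the serialized vision tensor name from TENSOR_NAMES above is present
static bool has_vision_tensor(const char * fname) {
    struct ggml_context * meta = NULL;
    struct gguf_init_params params = { /*.no_alloc =*/ true, /*.ctx =*/ &meta };
    struct gguf_context * gctx = gguf_init_from_file(fname, params);
    if (gctx == NULL) {
        return false;
    }
    struct ggml_tensor * t = ggml_get_tensor(meta, "v.enc.blk.0.attn_q.weight");
    gguf_free(gctx);
    ggml_free(meta);
    return t != NULL;
}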

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@@ -1279,6 +1336,23 @@ class MODEL_TENSOR(IntEnum):
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.LLAVA_VISION: [
        MODEL_TENSOR.V_MMPROJ_A,
        MODEL_TENSOR.V_MMPROJ_B,
        MODEL_TENSOR.V_ENC_EMBD_CLS,
        MODEL_TENSOR.V_ENC_EMBD_PATCH,
        MODEL_TENSOR.V_ENC_EMBD_POS,
        MODEL_TENSOR.V_ENC_ATTN_Q,
        MODEL_TENSOR.V_ENC_ATTN_K,
        MODEL_TENSOR.V_ENC_ATTN_V,
        MODEL_TENSOR.V_ENC_INPUT_NORM,
        MODEL_TENSOR.V_ENC_OUTPUT,
        MODEL_TENSOR.V_ENC_OUTPUT_NORM,
        MODEL_TENSOR.V_ENC_FFN_UP,
        MODEL_TENSOR.V_ENC_FFN_DOWN,
        MODEL_TENSOR.V_PRE_NORM,
        MODEL_TENSOR.V_POST_NORM,
    ],
    # TODO
}

33 changes: 33 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -814,6 +814,39 @@ def add_remove_extra_whitespaces(self, value: bool) -> None:
    def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
        self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)

    def add_vision_type(self, value: str) -> None:
        self.add_string(Keys.Vision.TYPE, value)

    def add_vision_image_size(self, value: int) -> None:
        self.add_uint32(Keys.Vision.IMAGE_SIZE, value)

    def add_vision_patch_size(self, value: int) -> None:
        self.add_uint32(Keys.Vision.PATCH_SIZE, value)

    def add_vision_clip_architecture(self, value: str) -> None:
        self.add_string(Keys.Vision.Clip.ARCHITECTURE, value)

    def add_vision_clip_context_length(self, value: int) -> None:
        self.add_uint32(Keys.Vision.Clip.CONTEXT_LENGTH, value)

    def add_vision_clip_embedding_length(self, value: int) -> None:
        self.add_uint32(Keys.Vision.Clip.EMBEDDING_LENGTH, value)

    def add_vision_clip_block_count(self, value: int) -> None:
        self.add_uint32(Keys.Vision.Clip.BLOCK_COUNT, value)

    def add_vision_clip_feed_forward_length(self, value: int) -> None:
        self.add_uint32(Keys.Vision.Clip.FEED_FORWARD_LENGTH, value)

    def add_vision_clip_head_count(self, value: int) -> None:
        self.add_uint32(Keys.Vision.Clip.HEAD_COUNT, value)

    def add_vision_clip_max_position_embeddings(self, value: int) -> None:
        self.add_uint32(Keys.Vision.Clip.MAX_POS_EMBEDDING, value)

    def add_vision_clip_layer_norm_epsilon(self, value: float) -> None:
        self.add_float32(Keys.Vision.Clip.LAYERNORM_EPS, value)

    def add_chat_template(self, value: str | Sequence[Mapping[str, str]]) -> None:
        if not isinstance(value, str):
            template_default = None
60 changes: 60 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
@@ -679,6 +679,66 @@ class TensorNameMap:
        MODEL_TENSOR.ENC_OUTPUT_NORM: (
            "encoder.final_layer_norm",  # t5
        ),

        MODEL_TENSOR.V_MMPROJ_A: (
            "multi_modal_projector.linear_1",
        ),

        MODEL_TENSOR.V_MMPROJ_B: (
            "multi_modal_projector.linear_2",
        ),

        MODEL_TENSOR.V_ENC_EMBD_CLS: (
            "vision_tower.vision_model.embeddings.class_embedding",
        ),

        MODEL_TENSOR.V_ENC_EMBD_PATCH: (
            "vision_tower.vision_model.embeddings.patch_embedding",
        ),

        MODEL_TENSOR.V_ENC_EMBD_POS: (
            "vision_tower.vision_model.embeddings.position_embedding",
        ),

        MODEL_TENSOR.V_ENC_ATTN_Q: (
            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.q_proj",
        ),

        MODEL_TENSOR.V_ENC_ATTN_K: (
            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.k_proj",
        ),

        MODEL_TENSOR.V_ENC_ATTN_V: (
            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.v_proj",
        ),

        MODEL_TENSOR.V_ENC_INPUT_NORM: (
            "vision_tower.vision_model.encoder.layers.{bid}.layer_norm1",
        ),

        MODEL_TENSOR.V_ENC_OUTPUT: (
            "vision_tower.vision_model.encoder.layers.{bid}.self_attn.out_proj",
        ),

        MODEL_TENSOR.V_ENC_OUTPUT_NORM: (
            "vision_tower.vision_model.encoder.layers.{bid}.layer_norm2",
        ),

        MODEL_TENSOR.V_ENC_FFN_UP: (
            "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc1",
        ),

        MODEL_TENSOR.V_ENC_FFN_DOWN: (
            "vision_tower.vision_model.encoder.layers.{bid}.mlp.fc2",
        ),

        MODEL_TENSOR.V_PRE_NORM: (
            "vision_tower.vision_model.pre_layrnorm",
        ),

        MODEL_TENSOR.V_POST_NORM: (
            "vision_tower.vision_model.post_layernorm",
        ),
    }

    # architecture-specific block mappings
24 changes: 24 additions & 0 deletions include/llama.h
@@ -224,6 +224,20 @@ extern "C" {

    typedef bool (*llama_progress_callback)(float progress, void * user_data);

    // represent an RGB image
    // size of data must be equal to 3*nx*ny
    typedef struct llama_img {
        uint32_t nx;
        uint32_t ny;
        unsigned char * data;
    } llama_img;

    // Input data for llama_vision_decode
    typedef struct llama_img_batch {
        int32_t n_imgs;
        llama_img * imgs;
    } llama_img_batch;
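
As a usage illustration (editor's sketch, not part of the diff): wrapping a raw interleaved RGB buffer in the new structs. The 224x224 size and the pixel source are assumptions; per the comment above, data must point at exactly 3*nx*ny bytes.

// one 224x224 RGB image; the buffer is assumed to be filled by an image loader
unsigned char pixels[3 * 224 * 224];

llama_img img = {
    /*.nx   =*/ 224,
    /*.ny   =*/ 224,
    /*.data =*/ pixels,
};

llama_img_batch ibatch = {
    /*.n_imgs =*/ 1,
    /*.imgs   =*/ &img,
};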

    // Input data for llama_decode
    // A llama_batch object can contain input about one or many sequences
    // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
@@ -875,6 +889,16 @@ extern "C" {
    // shape: [n_embd] (1-dimensional)
    LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id);

    //
    // Vision
    //

    // encode image into embeddings
    LLAMA_API int32_t llama_vision_encode(struct llama_context * ctx, llama_img_batch * batch);

    // get output embeddings, to be put into language batch
    LLAMA_API float * llama_vision_get_embeddings(struct llama_context * ctx, int32_t idx);
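
A sketch of the call flow the two comments above imply (again an editor's illustration, not part of the diff); treating a non-zero return from llama_vision_encode as failure is an assumption borrowed from llama_decode at this WIP stage.

// hypothetical helper: encode an image batch, then fetch the projected
// embeddings of the first image so they can be spliced into a language
// llama_batch through its embd field; returns NULL on failure
static float * encode_rgb_images(struct llama_context * ctx, llama_img_batch * ibatch) {
    if (llama_vision_encode(ctx, ibatch) != 0) { // assumed: non-zero = failure
        return NULL;
    }
    return llama_vision_get_embeddings(ctx, 0); // embeddings of image 0
}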

    //
    // Vocab
    //
1 change: 1 addition & 0 deletions src/CMakeLists.txt
@@ -17,6 +17,7 @@ add_library(llama
            llama-vocab.cpp
            llama-grammar.cpp
            llama-sampling.cpp
            llama-vision.cpp
            unicode.h
            unicode.cpp
            unicode-data.cpp
5 changes: 5 additions & 0 deletions src/llama-vision.cpp
@@ -0,0 +1,5 @@
#include "llama.h"

#include "llama-vision.h"

