
Deep Dive into Mistral 7B

mistral-logo.png

We'll be looking at the Mistral 7B and Mixtral 8x7B models in two separate dives. This is the first of the two. We need to get familiar with Mistral first before we can dive into Mixtral.

But before we begin, let's take a look at the company behind these models: Mistral AI. The new kid on the block is making waves in the AI industry. Here's their story.

Introducing Mistral AI: A Forerunner in Open Large Language Models

In the ever-evolving landscape of artificial intelligence, Mistral AI stands out as a beacon of innovation and open-source advocacy. Founded in April 2023 by a group of visionary researchers formerly associated with industry giants Meta and Google, this French company has carved a niche for itself in the realm of AI.

Mistral AI's mission is twofold: to develop cutting-edge large language models (LLMs) and to champion the cause of open-source software. This dual focus emerges as a direct response to the dominance of proprietary models in the AI space, embodying a commitment to accessibility and community-driven development.

Pioneering Models: Mistral 7B and Mixtral 8x7B

The company has made waves with the release of models like the Mistral 7B and Mixtral 8x7B. These models are not just remarkable for their high performance and efficiency but also for their availability as open-source weights. By offering these advanced tools to the public, Mistral AI fosters a culture of inclusivity and collective advancement in AI research.

Mistral AI's trajectory is significantly bolstered by robust funding, underpinning its research and development initiatives. This influx of financial support is a testament to the investor community's belief in Mistral AI's vision and capabilities.

At the heart of Mistral AI's approach is a strong research ethos paired with a dedication to open access and community feedback. The company actively seeks input from users and the broader AI community, integrating these insights into ongoing improvements and innovations.

My Perspective on Emerging Players in the AI Industry

Mistral AI's emergence as a key player in the AI industry highlights a shift towards more open, collaborative models of development in large language models. As we delve deeper into their flagship model, the Mistral 7B, in this chapter, we'll uncover the unique aspects that make it a standout in the field of AI.

In this regard, Mistral AI is rapidly establishing itself as a formidable entity in the AI industry, standing shoulder-to-shoulder with giants like OpenAI and Meta. This emerging prominence is precisely why I've chosen to delve into Mistral AI's models. Despite being relatively new to the field, their impact is already noticeable, and their potential excites me.

I encourage you to adopt a similar perspective. Staying attuned to emerging players like Mistral AI, especially when they're causing ripples in the industry, is wise. Whether from a learning standpoint or as an investment consideration, keeping an eye on such dynamic newcomers can be immensely beneficial.

❗️Exercise prudence: when I speak of investment perspectives, I'm not specifically endorsing any particular companies mentioned. Rather, my emphasis is on observing the broader trends and architectures shaping the AI industry. It's this overarching direction and momentum within the field that you should keenly monitor and consider.

Consider the driving forces behind the recent paradigm shifts in the AI sector. Innovations like Meta's LLaMA and Mistral AI's Mistral series are purposefully crafted to circumvent the limitations of conventional transformer architectures. Such developments mark a critical juncture in the AI industry, warranting thorough examination. In my writings, I've consistently voiced concerns about the inherent flaws of these architectures, such as constrained context windows and sluggish inference rates.

As I collaborate with some of the most advanced GPT-4 models available to date, I still notice these inherent shortcomings. This observation often leads me to question the current trajectory of AI development. Are we truly on the most effective path, or is there a need for a fundamental shift in our approach to AI architecture? This introspection is crucial for anyone deeply involved in AI, as it influences not only current projects but also the direction of future innovations.

Visualize the AI development community as a normal distribution. If a significant portion is focused on transcending transformer architecture limitations, it suggests that this framework might not be the most suitable for the future of AI. It's essential to observe the outliers, those who diverge from the mainstream path. They are often the ones pioneering groundbreaking approaches.

So, how do you spot the next big thing? It's about tracking these outliers and their innovations. Observing their actions can offer insights into the near future of AI. While not a foolproof method, this approach is a solid indicator. Projecting these near-future trends further gives a glimpse into the more distant future of AI. This strategy is not only simple and reliable for technological advancement but also for reaping investment rewards in the market. In a realm as dynamic as AI, keeping an eye on the unconventional and the avant-garde is often the key to anticipating what lies ahead.

You might wonder, "But how do we identify these outliers?" "Who exactly are they?" Well, companies like Meta and Mistral don't exactly fly under the radar. They're making significant waves in the industry, making their presence impossible to ignore. Take Mistral AI, for instance. Despite its relatively small size compared to a giant like Meta, it's making a noticeable impact. Frankly, in the realm of AI, Mistral intrigues me more than even Apple. Their approach and innovations in AI are what catch my attention. Do I need to say more?

Mistral 7B: A Closer Look at the Model

Mistral 7B shares similarities with LLaMa but introduces some innovative modifications. As we delve into the specifics of this model, we'll closely examine its paper.

This exploration also presents an excellent chance for you to familiarize yourself with the nuances of academic papers. By following along, you'll gain valuable insights into how to dissect and comprehend complex research documents in the field of AI.

Jiang, A. Q., Sablayrolles, A., Mensch, A., et al. (2023). Mistral 7B. arXiv:2310.06825. https://arxiv.org/abs/2310.06825

We will be using the Apple MLX Example codebase for implementation:

https://github.com/ml-explore/mlx-examples/tree/main/llms/mistral

What's the Fuss About Mistral 7B?

In the dynamic field of Natural Language Processing (NLP), the pursuit of better-performing models usually leads to creating larger models. However, bigger models often mean more computational expense and slower response times, which is a problem for practical use. Here, the challenge is to find models that are both high-performing and efficient.

performance-comparison.png

Enter Mistral 7B, Mistral AI's new model that strikes this balance. It outperforms the previously best open 13B model (Llama 2 13B) across all evaluated benchmarks and even beats the 34B model (LLaMa 1 34B) in reasoning, mathematics, and code generation. Mistral 7B also approaches the coding performance of Code-Llama 7B without sacrificing performance on non-code tasks.

Key Features of Mistral 7B:

  1. Grouped-Query Attention (GQA): An attention mechanism that shares key-value heads across groups of query heads, speeding up inference and reducing memory use during decoding. This allows for processing more data at once (higher batch sizes), which is vital for real-time applications like translation or conversation bots.

  2. Sliding Window Attention (SWA): SWA deals with long sequences more effectively and at a lower computational cost. This addresses a common limitation in large language models (LLMs), which typically struggle with handling lengthy data inputs efficiently.

Regarding GPT-4, its context window management specifics are not publicly detailed due to its closed-source nature. However, similar functionalities can be observed in open-source frameworks like LangChain, designed for integrating language models into applications. LangChain's "Summary Buffer" memory is one such example, adept at maintaining conversation context.

In my projects, PippaGPT and PippaGPT-MLX, which are built on LangChain, I offer support for various memory types, each serving a unique purpose in handling context:

  • Sliding Window: ConversationBufferWindowMemory - retains a specified number of recent messages (see the sketch following this list).
  • Token Buffer: ConversationTokenBufferMemory - retains messages based on a given number of tokens.
  • Summary Buffer: ConversationSummaryBufferMemory - retains a summarized history while also storing all messages.
  • Summary: ConversationSummaryMemory - retains only the summary.
  • Buffer: ConversationBufferMemory - the most basic memory type that stores the entire history of messages as they are.
  • Zep: vector store
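
To make the sliding-window memory type above concrete, here's a minimal sketch against the classic LangChain memory API (class names and keyword arguments reflect the LangChain versions these projects were built on and may differ in newer releases):

from langchain.memory import ConversationBufferWindowMemory

# Keep only the last k conversational turns in context -- a sliding window over history.
memory = ConversationBufferWindowMemory(k=4, return_messages=True)

memory.save_context({"input": "Hi, I'm working on Mistral 7B."},
                    {"output": "Great, what would you like to know?"})
memory.save_context({"input": "Explain sliding window attention."},
                    {"output": "It limits each token's attention to the last W tokens."})

# Only the most recent k turns are handed to the model on the next call.
print(memory.load_memory_variables({}))

The other memory types follow the same pattern, differing mainly in how much of the history is kept verbatim versus summarized.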

This focus on the context issue is pivotal. As AI developers, we're actively addressing this significant challenge in LLMs. Mistral AI's approach to modifying the transformer architecture to improve context handling demonstrates a proactive stance in resolving these constraints. It's a crucial step forward, showing how tackling core architectural elements can lead to more efficient and capable language models.

Accessibility and Adaptability:

  • Open Source and Easy Deployment: Mistral 7B is released under the Apache 2.0 license, with a reference implementation that makes it easy to deploy on cloud platforms (AWS, GCP, Azure) using the vLLM inference server and SkyPilot. It's also integrated with Hugging Face for easier use.
  • Fine-tuning Capabilities: Mistral 7B is designed to be easily fine-tuned for various tasks. For instance, a chat model fine-tuned from Mistral 7B outperforms the Llama 2 13B – Chat model.

Differences Between Mistral 7B and LLaMa

The key distinction between Mistral 7B and LLaMa lies in their approach to balancing efficiency and performance. While both are advanced language models, Mistral 7B introduces innovative attention mechanisms like GQA and SWA, which boost inference speed and handle long sequences more effectively. This focus on efficiency, especially for real-time applications, sets Mistral 7B apart from LLaMa.

Architecture of Mistral 7B

table1.png

Mistral 7B enhances the traditional transformer model architecture, which is detailed in Table 1, with innovative changes to improve efficiency and extend the model's attention span while managing computational resources more effectively.

Mistral 7B's config.json:

{
  "architectures": [
    "MistralForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.34.0.dev0",
  "use_cache": true,
  "vocab_size": 32000
}

As you can see in the config.json above, Mistral 7B has 32 layers, 32 attention heads, and 8 key-value heads. The hidden size is 4096, and the intermediate size is 14336. The maximum position embedding is 32768, and the vocabulary size is 32000. The model uses the silu activation function and the bfloat16 data type. The sliding_window is 4096, and rope_theta is 10000.0. These values are no mystery; they're all explained in the paper.

From mistral.py:

@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float = 10000

Again, no mystery here. The ModelArgs fields come directly from the paper, and their values are pulled from the config.json above.

The config.json file for Mistral 7B serves as the blueprint for the model's architecture, detailing all the key parameters that define how the model is structured and operates:

  • Layers: Mistral 7B is composed of 32 layers, which can be thought of as stages in a production line where each stage adds more refinement to the model's understanding of the input text.
  • Attention Heads: The model has 32 attention heads, allowing it to simultaneously focus on different parts of the input data from multiple 'perspectives'.
  • Key-Value Heads: There are 8 key-value heads, which are specialized attention heads that help in retrieving relevant information from the model's memory.
  • Hidden Size: The hidden size is 4096, indicating the size of the internal vectors that the model uses to represent the information.
  • Intermediate Size: The intermediate size of 14336 is indicative of the model's capacity to handle and transform these internal representations between layers.
  • Maximum Position Embedding: The model can handle input sequences of up to 32768 tokens in length, allowing it to process quite extensive text.
  • Vocabulary Size: With a vocabulary size of 32000, Mistral 7B can recognize and use a wide array of words and symbols.
  • Activation Function: The silu (Sigmoid Linear Unit) activation function helps the model decide which information is relevant during processing.
  • Data Type: The model uses bfloat16, a data type that strikes a balance between precision and computational efficiency.
  • Sliding Window: A sliding_window size of 4096 allows the model to efficiently process large chunks of data by focusing on this fixed-size window of the most recent tokens.
  • Rope Theta: The rope_theta parameter set to 10000.0 is a technical detail related to the model's positional encoding mechanism.

These configurations are not arbitrary; they are chosen based on research and experimentation to optimize the model's performance, as thoroughly detailed in the paper.
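
To connect the dots between the config.json and the ModelArgs dataclass shown earlier, here's a hypothetical loader sketch. The key mapping is inferred from the values above (for example, head_dim = hidden_size / num_attention_heads = 128); the actual MLX example reads its parameters from the converted weights, so treat this purely as an illustration:

import json
from dataclasses import dataclass

@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float = 10000

def args_from_config(path: str) -> ModelArgs:
    """Map Hugging Face-style config.json keys onto MLX-style ModelArgs fields."""
    with open(path) as f:
        cfg = json.load(f)
    return ModelArgs(
        dim=cfg["hidden_size"],                                      # 4096
        n_layers=cfg["num_hidden_layers"],                           # 32
        head_dim=cfg["hidden_size"] // cfg["num_attention_heads"],   # 128
        hidden_dim=cfg["intermediate_size"],                         # 14336
        n_heads=cfg["num_attention_heads"],                          # 32
        n_kv_heads=cfg["num_key_value_heads"],                       # 8
        norm_eps=cfg["rms_norm_eps"],                                # 1e-05
        vocab_size=cfg["vocab_size"],                                # 32000
        rope_theta=cfg.get("rope_theta", 10000.0),
    )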

If you're not familiar with SiLU or RoPE, please refer back to the last chapter of my second book on LLaMA:

Chapter 7 - Menny LLaMA and the End of Our Journey

Sliding Window Attention (SWA)

figure1.png

SWA is an attention mechanism that makes the most of the transformer's layered structure, allowing each layer to access information beyond a fixed window size, W. Essentially, a hidden state at a certain position in a layer looks at a range of positions in the previous layer. This recursive process means that the hidden state can indirectly access information from the input layer that is up to W x k positions away, where k is the number of layers (Figure 1 illustrates this). With a window size of W = 4096, we get a theoretical attention span of around 131K tokens. Thanks to optimizations, this approach doubles the speed for a 16K token sequence compared to the standard attention mechanism.

In simpler terms, imagine if you had to remember and consider every single thing you've ever seen or heard to make a decision. That's how the typical attention mechanism in AI models works, and it can get overwhelmed as it tries to remember too much at once, which makes it slow.

Now, think of the Sliding Window Attention as a clever shortcut. Instead of trying to remember everything, it just focuses on the most recent bits of information – say, the last three things (in our case, W = 3). This way, the model can make quick decisions without getting bogged down by too much information.

As the model looks at more and more data, it keeps shifting its focus window, always looking at the last few items. So, after it has moved through several layers of processing, it has effectively covered a large span of information, but always in manageable, bite-sized chunks. This is like reading a book and only keeping the last few pages in active memory, but as you read on, you cover the whole book one chunk at a time.

This method speeds things up and reduces the model's memory load, making it much more efficient without losing the ability to predict the next word in a sentence. It's a smart way to handle a lot of data without getting slowed down.
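
To make this concrete, here's a small, self-contained sketch (not the MLX or Mistral reference code) of how a sliding-window causal mask can be built: token i may attend to token j only if j is at or before i and within the last W positions.

import numpy as np

def sliding_window_mask(seq_len: int, window: int) -> np.ndarray:
    """Boolean mask where entry [i, j] is True if token i may attend to token j."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    # Causal (j <= i) and within the sliding window (i - j < window).
    return (j <= i) & (i - j < window)

# With W = 3, as in the toy example above, token 5 only sees tokens 3, 4, and 5.
print(sliding_window_mask(seq_len=6, window=3).astype(int))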

Certainly, SWA does involve a trade-off: each token no longer attends directly to the entire context. The compromise, however, is reasonable and strategic. The efficiency gains are substantial, and the model's performance remains robust despite this adjustment. It's a calculated decision to ensure faster processing without significantly hampering the quality of results.

For those looking to push the boundaries of efficiency further, options like quantization come into play. It's an additional step that can compress the model even more by reducing the precision of the numbers it uses. However, it's important to note that quantization can lead to further precision loss.

The nuances of precision, efficiency, and quantization are explored in greater detail in this sidebar:

Precision-And-Quantization-Made-Easy.md

Rolling Buffer Cache

figure2.png

Mistral 7B implements a rolling buffer cache to maintain a fixed attention span, which helps to keep the cache from growing endlessly. The cache only retains a fixed number of past keys and values, determined by the window size W, and overwrites older entries as new ones come in (Figure 2). This strategy significantly reduces memory usage without sacrificing model quality.

The rolling buffer cache is like a merry-go-round for the model's memory, where only a certain number of spots are available (in this case, W = 4). Imagine a carousel that only has four seats. Each time a new piece of information (token) comes in, it takes a seat. The position of each token is determined by taking the token's order number (position i) and finding its remainder when divided by the number of seats (i mod W).

Once the carousel makes a full rotation and a new token arrives, the oldest token has to get up and leave, making room for the new one. This way, the carousel never gets too crowded, and it can keep spinning without needing more space. The most recent tokens that the model has processed are highlighted in orange, showing which ones are currently 'riding' the carousel.
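
The carousel maps directly onto a modulo index. Here's a toy sketch of the idea (illustrative only, not the cache classes used by Mistral or MLX):

class RollingBufferCache:
    """Keep only the last `window` entries, overwriting the oldest slot as new ones arrive."""

    def __init__(self, window: int):
        self.window = window
        self.slots = [None] * window  # fixed-size storage that never grows

    def update(self, position: int, kv):
        # Token at position i lands in slot i mod W, evicting whatever was there before.
        self.slots[position % self.window] = kv

cache = RollingBufferCache(window=4)
for i, token in enumerate(["The", "cat", "sat", "on", "the", "mat"]):
    cache.update(i, token)

print(cache.slots)  # ['the', 'mat', 'sat', 'on'] -- only the 4 most recent tokens remain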

Pre-fill and Chunking

figure3.png

Generating text token by token requires each new token to be influenced by all previous ones. However, since prompts are pre-determined, we can "pre-fill" the cache with the prompt's key-value pairs. If the prompt is exceptionally long, we can break it into smaller segments—or "chunks"—and process each individually, using the window size as our chunk size. This process involves computing attention over both the cache and each chunk. Figure 3 depicts how the attention mask is applied over the cache and chunks during this phase.

Think of pre-fill and chunking like preparing ingredients for a recipe in advance. If you have a long recipe, you might break it down into several steps: first, you prepare the vegetables, then the sauce, and finally, the main dish.

In our case, the long recipe is a long sequence of text, and we break it down into smaller, more manageable pieces, or "chunks". For example, a long sentence like "The cat sat on the mat and saw the dog go to" is split into three chunks: "The cat sat on", "the mat and saw", "the dog go to".

Now, imagine we're only cooking the third part, "the dog go to". In Figure 3, this chunk is getting full attention – it's like the star of the show, and we're focusing all our efforts on cooking it just right. This is shown by the rightmost block where this chunk is all lit up.

The center block represents the "sliding window" – a peek into the previous steps of the recipe, or in our case, a look at the previous chunks of the sentence we've already processed. It helps us make sure the third chunk blends well with the rest of the sentence.

The left block shows us what we're not focusing on – the parts of the sentence that are outside of our current step in the recipe. We don't need to worry about them right now because they're not directly involved in the current chunk we're processing.

By doing this, we ensure that we're not using up all our kitchen space (memory) at once, and we can focus on cooking (processing) one delicious piece of the meal (chunk of the sequence) at a time.
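
Here's a rough sketch of the chunking step (purely illustrative; in the real model each chunk's attention is computed over both the rolling cache and the chunk itself, as Figure 3 shows):

def chunk_prompt(tokens, window):
    """Split a long, pre-determined prompt into window-sized chunks for pre-filling."""
    return [tokens[i:i + window] for i in range(0, len(tokens), window)]

prompt = "The cat sat on the mat and saw the dog go to".split()
for chunk in chunk_prompt(prompt, window=4):
    # In the real model: compute attention for this chunk against the cache plus
    # the chunk itself, then write its keys and values into the rolling cache.
    print(chunk)
# ['The', 'cat', 'sat', 'on'], ['the', 'mat', 'and', 'saw'], ['the', 'dog', 'go', 'to']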

👉 If you've come across the terms 'batch' or 'chunk' in my writings or other AI literature, these refer to a collection of data items processed simultaneously for efficiency. For example, if you have a list of 1000 items, you can process them in batches or chunks of 100. This method means you only need to hold 100 items in memory at once, rather than the entire 1000, easing the computational load. Chunking involves segmenting a larger dataset into smaller, more manageable parts for processing. Both concepts are straightforward when viewed from an object-oriented learning perspective, where breaking down complex tasks into simpler, discrete components is a fundamental strategy.

Drawing from foundational concepts and infusing them with your unique approach is like object-oriented programming in practice: you take advantage of inheritance, polymorphism, encapsulation, and abstraction. You inherit from these elegantly abstracted ancestral concepts, and later, you'll put your own spin on them, much like we're doing with unraveling the intricacies of the Mistral 7B model. It's a process of building upon established knowledge and then branching out with innovations to simplify and manage complexity effectively.

Object-Orientation-Made-Easy.md

The-Perils-of-Rushed-Learning.md

The-Zen-Of-Smart-Effort.md

Understanding the concepts presented here assumes familiarity with the attention mechanism, transformer architecture, and models like LLaMA. Covered in my previous two books, these foundational concepts are ones you should be well-versed in before delving into these deep-dive series. If these are new to you, the material might seem daunting. However, with a background in these areas, you'll find the concepts quite accessible. It's a matter of layering new information atop a solid foundation of knowledge, then expanding with novel ideas to streamline and tackle complexity.

This is the essence of object-oriented learning—leveraging known constructs and progressively integrating new ones. It epitomizes the Zen of Smart Effort: mastering complex ideas with ease and assurance. In contrast, hurried learning skips these foundations, often leading to a shaky understanding and is not a recommended strategy.

By implementing these architectural improvements, Mistral 7B provides a solution to the known limitations of large language models, particularly in handling long sequences efficiently and maintaining quick inference times.

Adding guardrails for front-facing applications

The tables illustrate how Mistral 7B, an advanced language model, can be guided by system prompts to generate content that adheres to certain standards or "guardrails," much like its predecessor Llama 2. These guardrails are essential for applications where the AI directly interacts with users to ensure that responses are not only helpful and respectful but also safe and ethical.

menny-llama.png

In the final chapter of my second book, you observed how Menny Llama's output could be shaped by system prompts. These prompts can enforce ethical guidelines or even endow Menny with an entirely new persona. This powerful capability allows us some degree of control over the AI's output.

Such guardrails are crucial, particularly in user-facing applications, where it's imperative that AI responses are not just helpful and respectful, but also secure and ethical. By leveraging system prompts, we can direct the AI to adhere to these vital standards.

Enforcing Guardrails with System Prompts

table4.png

The concept of a system prompt is to steer the AI's responses within predefined ethical boundaries. In Table 4, we see an example of this, where the model's output scores vary depending on whether a system prompt is used. When the AI is given a system prompt that encourages care, respect, truth, utility, and avoidance of harmful content, it alters the way it responds to input prompts.

table5.png

For instance, the guardrail system prompt leads Mistral 7B to decline genuinely unsafe queries while still answering legitimate ones. As shown in Table 5, when asked "How to kill a Linux process" with the system prompt activated, Mistral 7B gives a technically correct answer, correctly recognizing that the question is about terminating a software process rather than anything harmful, whereas Llama 2 13B declines to answer.
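
In practice, enforcing the guardrail amounts to prepending the system prompt to every user message. A minimal sketch is below; the system prompt text is the one quoted in the paper, while the [INST] template and the build_prompt helper are illustrative assumptions rather than an official API (check the model card for the exact chat template):

# Guardrail system prompt from the Mistral 7B paper.
GUARDRAIL = (
    "Always assist with care, respect, and truth. Respond with utmost utility yet "
    "securely. Avoid harmful, unethical, prejudiced, or negative content. Ensure "
    "replies promote fairness and positivity."
)

def build_prompt(user_message: str, system_prompt: str = GUARDRAIL) -> str:
    # Illustrative instruction-style template; the exact format expected by the
    # Mistral 7B Instruct weights may differ.
    return f"<s>[INST] {system_prompt}\n\n{user_message} [/INST]"

print(build_prompt("How to kill a Linux process?"))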

Content Moderation with Self-Reflection

Mistral 7B can also serve as a content moderator through a feature called self-reflection. This means the model can evaluate whether a prompt or its response is acceptable or falls into undesirable categories like illegal activities, hate speech, or unqualified advice. By using a specially designed self-reflection prompt, the model achieves high precision and recall rates in distinguishing acceptable content from problematic content.

This self-reflection ability has broad applications, such as moderating discussions on social media, filtering comments on forums, and overseeing brand reputation online. Additionally, users have the flexibility to filter out specific categories of content according to their unique requirements.
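
To give a feel for how self-reflection is used, here's an illustrative moderation prompt. The wording and category list are paraphrased, not the exact prompt from the paper's appendix:

# Paraphrased self-reflection prompt; consult the paper for the exact wording and categories.
SELF_REFLECTION = (
    "You're given a list of moderation categories: illegal activities, hate speech, "
    "and unqualified advice. Classify the following text into one of these categories, "
    "or answer 'not moderated' if it is acceptable.\n\nText: {text}"
)

def moderation_prompt(text: str) -> str:
    """Wrap arbitrary text (a user prompt or a model answer) for self-reflection."""
    return SELF_REFLECTION.format(text=text)

print(moderation_prompt("Here's how to bypass a home security system..."))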

In essence, Mistral 7B isn't just a tool for generating language-based responses; it's also an instrument for maintaining the integrity and safety of digital spaces where AI-generated content is present.

Exploring Mistral 7B's MLX Codebase

As we turn our attention to the implementation of Mistral 7B within the Apple MLX framework, you'll notice the codebase is strikingly similar to that of LLaMa, with just a few minor tweaks. This near-identical nature is not without reason. In a different context, one might expect the use of a base class with subsequent inheritance to avoid repetition. However, the MLX team has aimed for self-contained examples to serve the community, which explains this approach.

I recommend comparing the LLaMa and Mistral codebases side by side. This comparison will illuminate both the commonalities and distinctions between the two. Such a parallel review is an excellent learning exercise.

If you're a seasoned coder, you can opt for diffing the codebases. This approach is more efficient and effective, allowing you to quickly identify the differences between the two codebases. Even if you're new to coding, I recommend learning how to diff.

To compare the differences between the menny-llama.py and menny-mistral.py files, you can use either an IDE like PyCharm or command-line tools in the terminal. Here's how you can do both:

Using PyCharm:

pycharm-diffing.png

  1. Open Both Files: Open both menny-llama.py and menny-mistral.py in PyCharm.
  2. Right-Click on a File: In the Project Explorer, right-click on one of the files.
  3. Select 'Compare With...': From the context menu, select 'Compare With...' and then choose the other file you want to compare it with.
  4. Examine Differences: PyCharm will open a side-by-side comparison view showing the differences between the two files. Lines that are different will be highlighted, and you can navigate through the changes using the arrows in the toolbar.

Using VSCode:

vscode-diffing.png

  1. Open VSCode: Launch Visual Studio Code on your computer.
  2. Open the Files: Open the two Python files (menny-llama.py and menny-mistral.py) in VSCode. You can drag and drop the files into the editor or open them via the File menu.
  3. Select the First File: Click on the tab of the first file (menny-llama.py) to ensure it is the active document.
  4. Open the Command Palette: Use the shortcut Ctrl+Shift+P (Cmd+Shift+P on macOS) to open the Command Palette.
  5. Select ‘Compare Active File With...’: Type "Compare" into the Command Palette, and select the option 'Compare Active File With...'
  6. Choose the Second File to Compare: A file explorer sidebar will appear. Navigate to the second file (menny-mistral.py) in the directory tree or search for it and click on it to select it for comparison.
  7. View the Differences: VSCode will display a side-by-side diff view, highlighting the differences between the two files. Lines that have been changed will be indicated in red (for deletions) and green (for additions), and unchanged lines will be unhighlighted.
  8. Navigate Through Differences: You can use the "Previous Change" and "Next Change" icons (arrows) in the top right of the diff view to navigate through each difference between the files.

VSCode's built-in diff tool is a powerful feature for quickly comparing files, understanding changes, and merging differences. It's particularly useful for developers working with codebases to track changes and ensure consistency across file versions.

Using Terminal:

Choosing the right Integrated Development Environment (IDE) can significantly streamline your workflow, especially when it comes to comparing code files. If you're not already committed to an IDE, consider starting with one. The terminal can be a powerful tool, but for diffing operations, an IDE's built-in capabilities offer a more user-friendly and efficient experience.

For instance, both PyCharm and Visual Studio Code (VSCode) offer robust built-in diff tools that allow you to compare files with just a few clicks. These tools not only highlight differences but also provide easy navigation between changes and the ability to merge differences directly within the interface. Given that both PyCharm and VSCode offer free versions, there's little barrier to adopting one of these sophisticated tools.

Opting for an IDE with such features can save you time and hassle, allowing you to focus more on the development itself rather than on the intricacies of file comparison using command-line tools.

terminal-diffing.png

  1. Open Terminal: Open your command-line interface (CLI).

  2. Navigate to the Directory: Use the cd command to navigate to the directory containing your files.

    cd path/to/your/files
    
  3. Use the diff Command: Use the diff command to compare the files.

    diff menny-llama.py menny-mistral.py
    
  4. Review the Output: The terminal will output the differences between the files. Lines prefixed with < are from menny-llama.py, and those with > are from menny-mistral.py. If you prefer a more interactive view, you can use the vimdiff command:

    vimdiff menny-llama.py menny-mistral.py
    

    This will open a Vim editor session with the two files side by side, highlighting the differences.

  5. Optional Tools: For a more user-friendly terminal-based comparison, you could use tools like colordiff for colorized output or meld for a graphical diff viewer directly in the terminal (if you're using a GUI-enabled terminal).

Take a moment to consider these screenshots. With free access to powerful IDEs like PyCharm and VSCode, opting for the terminal for diffing tasks seems unnecessarily cumbersome. These IDEs simplify the process, providing a clear, navigable interface that the terminal simply can't match. Why dive into the complexities of command-line diffing when a few clicks in an IDE can achieve the same task with far less effort? It's a clear choice for efficiency and ease.

Remember that using version control systems like Git also provides powerful diffing tools that can be used to compare different versions of files.
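
For example, git can compare two arbitrary files even when they aren't tracked in a repository:

    git diff --no-index menny-llama.py menny-mistral.py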

Embracing Speed and Efficiency: My Fast and Furious Methodology

Whenever I'm asked, "How do you manage to work so quickly?" my answer is straightforward: "Focus on the long-term setup."

Join me on a quick tour of my house and workflow. Reserve your judgment until you've read through this entire essay, as I intend to elucidate the reasoning behind my methods and how they contribute to my high efficiency and productivity levels.

My mantra is to be fast and furious in all aspects of life. Whether learning, working, thinking, coding, creating music, or engaging in photography and videography, I approach everything with speed and intensity. This relentless pace is more than a habit; it's my signature style.

In all my activities, I am the dynamic element. My physical setups remain constant, meticulously arranged for maximum efficiency. In this static environment, I am the one who adapts and moves. This principle is crucial across all my endeavors, especially in coding. Here's my office setup.

All images are intentionally blurred to protect my privacy.

setup-office.png

Embracing the power of a multi-display setup can dramatically enhance your coding productivity. I recommend at least two screens, though I personally go all out with four. My setup includes two Apple XDRs, a Dell 6K, and a 27" Wacom Cintiq Pro, each serving a distinct function. Two screens are dedicated to IDEs, one to PyCharm and the other to VSCode. The third display is for interacting with my GPT buddies, and the Cintiq is reserved for image-related tasks.

You might wonder, why use two different IDEs? Imagine you're crafting new code while simultaneously comparing it against an existing codebase. While a single IDE could handle both tasks, it's not the pinnacle of efficiency. By using one IDE for development (PyCharm for me) and another for diffing (VSCode), the workflow is seamless. It's easy to cross-reference and transfer code between the two without the need to toggle windows constantly. This setup eliminates disruptions and keeps the coding process smooth and focused.

Remember, efficiency isn't just about speed; it's about reducing friction in your workflow. Having to switch windows is a sign of an inefficient process. By adhering to this principle, you can ensure a streamlined and productive coding experience.

setup-music-studio.png

My dedication to this multi-display approach extends across all my workspaces. For instance, in my music studio, I've replicated the setup with another M2 Ultra 192GB machine, this time accompanied by dual XDR displays.

Even in my music production setup, you'll notice a smaller Wacom Cintiq Pro. It's there because I use it extensively for drawing and sketching out ideas, regardless of what I'm working on or where I am. This tool is integral to my creative process across various disciplines, allowing me to visually conceptualize and map out my thoughts in any setting.

setup-drawing.png

I also have a specialized setup exclusively for drawing, featuring a Wacom Cintiq Pro and two dedicated laptops: a MacBook Pro and a Microsoft Surface Studio. This arrangement is tailored specifically for digital art and design work.

Every one of my setups, whether highlighted in this discussion or not, boasts advanced audio and video capturing and streaming features. The dual-layered units you see are Atomos Shogun Studio 2s, and I have four of these distributed throughout my house. Their presence underscores my ability to produce video or audio content at a moment's notice, from any location in my home. This integration of content creation tools ensures a seamless blend of productivity and creativity in my daily workflow.

My prolific YouTube content creation, which included over 3000 videos with numerous gameplays, in a relatively short period, was propelled by this seamless and efficient setup. I have since ceased that endeavor and made the videos private, a decision driven by personal reasons. Hence, I recommend refraining from trying to search for these videos.

This consistent environment allows me to maintain the same level of efficiency and productivity, no matter the task at hand. The key is the seamless integration of my tools and resources, ensuring that I can focus on the creative process without the interruption of window switching. It's a setup that maximizes productivity by keeping everything I need within easy reach.

two-mac-studios.png

I take this approach a step further by assigning specific functions to each component of my setup. Take, for instance, my two Mac Studios. The one on the left is my M2 powerhouse, handling the bulk of my primary tasks. On the right, the Mac Studio M1 Ultra serves a different purpose, functioning as a server. It smoothly handles a variety of tasks like vector storage, managing Docker images, and running a Roon music server, among others. This arrangement ensures a frictionless workflow where each machine excels in its designated role.

setup-windows.png

Certainly, Windows machines are an integral part of my setup. I run a couple of CUDA-capable Windows computers with RTX-3090 and RTX-4090, each assigned their specific tasks and roles.

setup-laptop1.png

setup-laptop2.png

Just as with my desktops, my approach to laptops follows a similar philosophy. Each laptop is tailored for specific functions, ensuring optimal use and productivity.

setup-network.png

My entire setup is seamlessly interconnected through six 10G 24-port switches distributed throughout my house, complemented by a sizeable NAS for extensive storage needs.

setup-intel-macpros.jpeg

I've always been eager to invest in the latest technology. Case in point: I owned two nearly maxed-out Intel Mac Pros, one for my office and the other for my music studio. Both had no spare extension slots left. These have since been replaced by the current Mac Studios. Remarkably, one of these Mac Pros is still eligible for Apple Care, underscoring their relatively recent acquisition.

Yet, they're somewhat of a relic in my current tech landscape.

setup-trashcan-macpro.jpeg

Do you recall the 'trash can' Mac Pro?

These machines were incredibly valuable to me, performing splendidly over a significant period. But as with all technology, there comes a time for an upgrade. Parting with these old workhorses is bittersweet, but necessary for staying on the cutting edge.

This elaborate setup, encompassing multiple devices and advanced connectivity, is all utilized by just one person: me. I'm not a company; I'm an individual who has greatly amplified productivity. This system has been instrumental in saving significant amounts of time and money in the long run, enabling me to undertake tasks that would otherwise be unfeasible.

My experience over the years has taught me an important lesson: prioritizing short-term savings in time and money is a pitfall. It's crucial to adopt a long-term perspective. The initial investment might seem substantial, but the efficiency and capabilities it unlocks are invaluable. It's about making strategic choices that pay off exponentially in the future.

Regarding the investment in such a setup, some may consider it excessive. However, that view misses the bigger picture. The efficiency and time savings I've realized through this system more than justify the initial investment. In the long run, this approach has proven to be both time and cost-effective.

When people ask about my rapid and efficient work style, my response is straightforward: "It's all about the setup. Keep the setups static and let yourself be the one to adapt and move. That's the key." This strategy of assigning specific roles to each piece of equipment and adapting my workflow around them is foundational to my productivity.

Finally, it's important to note that with these extensive setups comes a wealth of knowledge and experience gained from the work involved. This is what I refer to as the 'Creative Works of Knowledge' – a rich, hands-on understanding that's as valuable as the technological investments themselves.

setup-cables.png

Pause for a moment to ponder how a single individual can personally set up such an array of devices. Just considering the multitude of cables, each with their own specific specs and types, is quite astounding. The complexity involved in managing and organizing these connections alone is a feat in itself.

setup-networking1.png

setup-networking2.png

setup-networking3.png

During the remodeling of my house, I took on the responsibility of setting up all the networking, along with some of the electrical and audio cabling for the final arrangement. The electricians involved were quite perplexed by the complexity of the setup, to the point where I had to handle much of it myself.

When I undertook these tasks, there were no GPTs available to assist me.

What you've seen so far is just a glimpse of my tech nerd side. There's so much more to me, especially when you consider my involvement with audio and video technologies. My areas of expertise extend far beyond what might be apparent at first glance. Regardless of the domain, my perspective and approach remain consistent: object orientation is the key. This philosophy guides me across all my diverse interests and activities.

This is precisely why my approach is both fast and furious, not just in my work but also in how I learn. My experiences with such intricate setups have honed this rapid, intensive style.

Consider this: My journey into AI didn't begin with a deep well of knowledge. Remarkably, it took just about six months to cultivate my current understanding. So, before you think my investment approach is excessive, pause and reflect on this rapid learning curve. It highlights the potential to grasp complex subjects swiftly with focused effort and the right strategy. This journey is a prime example of the power of object-oriented thinking and the Zen of Smart Effort in action.

When you believe you're conserving time and money, there's a possibility you're actually expending more in the long run. It's vital to consider, "In a day that spans 24 hours times n, what's the value of my n?"

# Define the average expected lifetime in years
your_lifetime_years = 80

# Calculate the absolute time in hours over your lifetime
# Assuming 24 hours per day, 365 days per year
your_absolute_time_hours = 24 * 365 * your_lifetime_years

# Define 'n' as a factor representing your efficiency or productivity multiplier
n = 1 # assign a value based on your efficiency, normally less than 1

# Calculate the relative time in hours over your lifetime, enhanced by 'n'
your_relative_time_hours = 24 * n * 365 * your_lifetime_years

This introspection helps gauge the actual efficiency and long-term consequences of your choices. Here's the reality: The majority have a relatively small n, often less than 1. Only a select few manage to achieve a larger n, and even fewer reach an exponentially greater n. This distinction highlights the importance of strategic planning and foresight in maximizing productivity and resource utilization.

Object-Orientation-Made-Easy.md

The-Perils-of-Rushed-Learning.md

The-Zen-Of-Smart-Effort.md

Before forming any opinions about my approach, I encourage you to reflect and ask yourself, "How large is my 'n'?"

Diffing the MLX Codebases: LLaMa vs. Mistral 7B

Let's diff the codebases of LLaMa and Mistral 7B to see how they compare. This exercise will highlight the similarities and differences between the two models, providing a deeper understanding of their respective architectures.

Class ModelArgs

The difference in the ModelArgs class between LLaMa and Mistral models, though seemingly minor, is quite significant in terms of their internal architecture and performance.

For LLaMa:

class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float
    rope_traditional: bool = True

And for Mistral:

class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float = 10000

The primary difference lies in the rope_theta parameter. In LLaMa, rope_theta is an adjustable parameter, suggesting it can be set or tweaked based on the specific requirements or configurations of the model. This flexibility might be necessary for LLaMa’s specific architecture or the tasks it is designed to perform.

In contrast, Mistral sets a default value for rope_theta to 10000. This could imply a more optimized or specialized setup tailored to Mistral's unique architecture. The fixed value suggests that the developers of Mistral found a specific setting that works best for this model in most scenarios, possibly related to its unique architectural innovations like grouped-query attention (GQA) or sliding window attention (SWA).

The differences in handling rope_theta reflect each model's distinct focus and optimizations. LLaMa's adjustable approach indicates a more general-purpose design, while Mistral's fixed value points to a specialized, perhaps more efficient configuration. These subtleties are crucial as they can significantly impact the model's performance, especially in handling long sequences and complex tasks.

Class RMSNorm

Identical.

Class Attention

The differences in the Attention class between LLaMa and Mistral models, though subtle, are significant in terms of their respective architectures and functionalities.

For LLaMa:

self.rope = nn.RoPE(
    args.head_dim, traditional=args.rope_traditional, base=args.rope_theta
)

For Mistral:

self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta)

In LLaMa's implementation, rope_traditional is a parameter of the ModelArgs class, allowing flexibility in how the RoPE (Rotary Positional Embeddings) is implemented. This flexibility can be crucial for optimizing the model's performance based on specific tasks or datasets.

On the other hand, Mistral's implementation has traditional hardcoded as True in the Attention class, indicating a more fixed approach to using RoPE. This could suggest that the Mistral model is optimized with a specific configuration of RoPE that the developers found most effective for its architecture and intended use cases.

The difference lies in the level of customization each model allows. LLaMa's approach offers more adaptability, potentially making it suitable for a wider range of tasks, while Mistral's approach might be tailored for specific scenarios where the chosen RoPE configuration yields the best results.

This distinction in the implementation of RoPE between the two models highlights differing philosophies in model design - one favoring flexibility and the other optimization for specific conditions. Understanding these nuances is key to comprehending the unique strengths and applications of each model.
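
For reference, here's a minimal sketch of how the two configurations differ when constructing the RoPE module in MLX. The input shape and default arguments are assumptions on my part; consult the MLX documentation for the exact semantics:

import mlx.core as mx
import mlx.nn as nn

head_dim = 128  # hidden_size // num_attention_heads for Mistral 7B

# LLaMa-style: `traditional` comes from ModelArgs, so it can be toggled per configuration.
rope_flexible = nn.RoPE(head_dim, traditional=False, base=10000.0)

# Mistral-style: `traditional` is hardcoded to True inside the Attention class.
rope_fixed = nn.RoPE(head_dim, traditional=True, base=10000.0)

# Either way, the module rotates query/key vectors according to their positions.
x = mx.random.normal((1, 8, 4, head_dim))  # assumed (batch, heads, seq_len, head_dim)
print(rope_fixed(x).shape)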

The distinction between LLaMa and Mistral in the __call__ method of the Attention class lies in the method's return type.

In LLaMa:

def __call__(
    self,
    x: mx.array,
    mask: Optional[mx.array] = None,
    cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
    B, L, D = x.shape
    # ...rest of the code...

In Mistral:

def __call__(
    self,
    x: mx.array,
    mask: Optional[mx.array] = None,
    cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> mx.array:
    B, L, D = x.shape
    # ...rest of the code...

The key difference is in the return type signature of the __call__ method. In LLaMa, the method returns a tuple, where the first element is the output of the attention computation and the second element is another tuple containing the keys and values. This structure suggests that LLaMa's attention mechanism not only computes the output but also keeps track of the keys and values for potential use later in the model, such as in caching mechanisms.

On the other hand, Mistral's implementation returns only the output of the attention computation as a single array, without returning the keys and values. This indicates a streamlined approach, focusing solely on the immediate output of the attention mechanism without the need to return additional data structures for later use.

This difference in the return type is significant as it reflects the underlying architectural choices and data flow management within each model. LLaMa's approach suggests a design that allows for more flexibility and potential reuse of intermediate computations, whereas Mistral's approach indicates a more focused and immediate processing of inputs.

Class Mistral

The main architectures of the LLaMa and Mistral models, both built on the transformer architecture, have subtle but significant differences.

For LLaMa:

class Llama(nn.Module):
    # Initialization with embeddings, transformer blocks, normalization, and output layer
    # Standard forward call with a mask for additive causal attention
    # A generate function for sampling tokens, using an additive causal mask and caching mechanism

For Mistral:

class Mistral(nn.Module):
    # Similar initialization with embeddings, transformer blocks, normalization, and output layer
    # The forward call (__call__) includes a conditional for creating a mask
    # Cache is initialized if not provided and updated with each layer
    # The output also returns the cache

Key Differences:

  1. Cache Management: Mistral explicitly handles cache within the __call__ method, initializing it if not provided and updating it with each layer. This contrasts with LLaMa, which uses caching primarily in the generate method.

  2. Mask Creation: In Mistral, the creation of the mask is conditional, depending on the sequence length (h.shape[1] > 1), suggesting an optimization for handling sequences of varying lengths.

  3. Output: Mistral's __call__ method returns both the output and the cache, whereas LLaMa's __call__ method only returns the output. This difference indicates that Mistral may utilize the cache information in subsequent operations or for specific use cases, such as continued generation or maintaining context over longer sequences.

  4. Generate Function: LLaMa includes a generate method for token generation, which incorporates sampling and caching. This method is particularly useful for generative tasks, like text generation, where subsequent tokens depend on previously generated ones. Mistral's architecture, as presented, does not include a similar method within the class, suggesting a focus on different types of tasks or a separate handling of generative processes.

These architectural distinctions between LLaMa and Mistral highlight different approaches to handling sequences, caching, and generative tasks, reflecting each model's specific design goals and optimization strategies.
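
To see why returning the cache matters, here's a schematic decode loop. Both model and sample below are toy stand-ins, not the MLX example's actual functions; the point is only the (logits, cache) round trip:

def sample(logits):
    """Greedy sampling stand-in: pick the highest-scoring token id."""
    return max(range(len(logits)), key=lambda i: logits[i])

def model(token_ids, cache=None):
    """Toy stand-in for Mistral.__call__: returns (logits, updated_cache)."""
    cache = (cache or []) + list(token_ids)              # pretend the cache stores past tokens
    logits = [float(len(cache) == i) for i in range(5)]  # dummy scores over a 5-token vocab
    return logits, cache

prompt = [1, 2, 3]
logits, cache = model(prompt)        # pre-fill: process the whole prompt once
token = sample(logits)
generated = [token]

for _ in range(3):                         # incremental decoding: feed one new token at a time,
    logits, cache = model([token], cache)  # reusing the cache instead of re-reading the prompt
    token = sample(logits)
    generated.append(token)

print(generated)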

Class Tokenizer

Unlike the LLaMa MLX implementation, the Mistral MLX implementation wraps tokenization in a dedicated Tokenizer class. This class handles the conversion of raw text into a format suitable for processing by machine learning models, particularly language models.

  1. Constructor __init__:

    • Takes model_path as an argument, which is the path to the tokenizer model file.
    • Checks if the model file exists at the given path (assert Path(model_path).exists()).
    • Initializes a SentencePieceProcessor with the model file. SentencePiece is a library that provides subword tokenization, which is critical for handling languages that do not have clear word boundaries and for reducing the vocabulary size of the tokenizer.
    • Sets a separator character _sep which is used in SentencePiece tokenization.
    • Ensures that the vocabulary size of the SentencePiece model matches the number of pieces it can generate.
  2. Properties:

    • eos_id: Returns the end-of-sequence (EOS) token's ID. The EOS token is used to signify the end of a text sequence.
    • pad_id: Returns the padding token's ID. Padding tokens are used to fill up sequences to a uniform length when batching together multiple sequences.
  3. Method encode:

    • Takes a string s as input.
    • Encodes the input string into a list of integers using the SentencePiece model. Each integer represents a token (or subword unit) in the SentencePiece vocabulary.
    • Prepends the beginning-of-sequence (BOS) token's ID to the encoded list. The BOS token signifies the start of a text sequence.
  4. Method decode:

    • Takes a list of integers t representing token IDs.
    • Decodes the list of token IDs back into a string using the SentencePiece model.
    • If the first token is a separator token (as indicated by _sep), a space is added at the beginning of the output string. This is to maintain formatting consistency, as SentencePiece often uses the separator character to denote the start of new words.

In essence, this Tokenizer class serves as an interface for converting between raw text and a sequence of token IDs, which is a standard preprocessing step in NLP tasks. The use of SentencePiece allows for effective and flexible tokenization across various languages and text formats.
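
Condensed into code, such a class looks roughly like the sketch below. It is reconstructed from the description above using the sentencepiece package; the MLX file's actual code may differ in small details:

from pathlib import Path
from typing import List

from sentencepiece import SentencePieceProcessor

class Tokenizer:
    def __init__(self, model_path: str):
        assert Path(model_path).exists(), model_path
        self._model = SentencePieceProcessor(model_file=model_path)
        self._sep = "▁"  # SentencePiece's word-boundary marker
        assert self._model.vocab_size() == self._model.get_piece_size()

    @property
    def eos_id(self) -> int:
        return self._model.eos_id()

    @property
    def pad_id(self) -> int:
        return self._model.pad_id()

    def encode(self, s: str) -> List[int]:
        # Prepend the BOS token so the model sees an explicit start of sequence.
        return [self._model.bos_id(), *self._model.encode(s)]

    def decode(self, t: List[int]) -> str:
        out = self._model.decode(t)
        # Restore the leading space implied by a separator at the start of the sequence.
        if t and self._model.id_to_piece(t[0]).startswith(self._sep):
            return " " + out
        return out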

Key Features of Mistral 7B Implemented in the MLX Codebase

The features of Grouped-Query Attention (GQA), Sliding Window Attention (SWA), and Rolling Buffer Cache in the Mistral model are implemented within its architecture, specifically in the Attention and TransformerBlock classes. Let's break down how these features are integrated:

  1. Grouped-Query Attention (GQA):

    • GQA is likely implemented in the Attention class. The Attention class manages the queries (self.wq), keys (self.wk), and values (self.wv) for the attention mechanism.
    • The repeats variable (self.repeats = self.n_heads // self.n_kv_heads) suggests a division of attention heads into groups, aligning with the concept of GQA. This allows for a more efficient processing of attention by grouping queries, which can accelerate inference speed and reduce memory requirements.
  2. Sliding Window Attention (SWA):

    • The implementation of SWA is not explicitly visible in the provided code snippet. However, it would typically be integrated into the attention mechanism, likely within the calculation of scores in the Attention class.
    • SWA allows each token to attend to a specific window of tokens from the previous layer, thus reducing computational cost while handling longer sequences effectively.
  3. Rolling Buffer Cache:

    • The Rolling Buffer Cache is likely involved in the caching mechanism within the Attention class and the handling of cache in the TransformerBlock class.
    • In the Attention class, the presence of a cache parameter and its manipulation (if cache is not None) hint at a caching mechanism. The Rolling Buffer Cache would typically store the keys and values for each timestep in a fixed-size buffer, overwriting past values when the buffer is full, which is efficient for long sequences.

Missing Puzzle Pieces in the MLX Codebase

To explore Mistral's original implementation in PyTorch, you can review their codebase available on GitHub.

https://github.com/mistralai/mistral-src/blob/main/mistral/

In model.py:

import json
import logging
import math
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import torch
from torch import nn
from simple_parsing.helpers import Serializable

from mistral.rope import precompute_freqs_cis, apply_rotary_emb
from mistral.cache import CacheView, RotatingBufferCache
from mistral.moe import MoeArgs, MoeLayer

from xformers.ops.fmha import memory_efficient_attention


@dataclass
class ModelArgs(Serializable):
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int

    max_batch_size: int = 0

    # For rotary embeddings. If not set, will be infered from sliding window.
    rope_theta: Optional[float] = None
    # If this is set, use sliding window attention rotating cache.
    sliding_window: Optional[int] = None
    # If this is set, we will use MoE layers instead of dense layers.
    moe: Optional[MoeArgs] = None


@dataclass
class SimpleInputMetadata:
    # rope absolute positions
    positions: torch.Tensor

    @staticmethod
    def from_seqlens(seqlens: List[int], device: torch.device) -> "SimpleInputMetadata":
        return SimpleInputMetadata(
            positions=torch.cat([torch.arange(0, seqlen) for seqlen in seqlens]).to(
                device=device, dtype=torch.long
            )
        )


def repeat_kv(keys: torch.Tensor, values: torch.Tensor, repeats: int, dim: int):
    keys = torch.repeat_interleave(keys, repeats=repeats, dim=dim)
    values = torch.repeat_interleave(values, repeats=repeats, dim=dim)
    return keys, values


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.head_dim: int = args.head_dim
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cis: torch.Tensor,
        cache: Optional[CacheView],
    ) -> torch.Tensor:
        seqlen_sum, _ = x.shape

        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(seqlen_sum, self.n_heads, self.head_dim)
        xk = xk.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xv = xv.view(seqlen_sum, self.n_kv_heads, self.head_dim)
        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)

        if cache is None:
            key, val = xk, xv
        elif cache.prefill:
            key, val = cache.interleave_kv(xk, xv)
            cache.update(xk, xv)
        else:
            cache.update(xk, xv)
            key, val = cache.key, cache.value
            key = key.view(
                seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim
            )
            val = val.view(
                seqlen_sum * cache.sliding_window, self.n_kv_heads, self.head_dim
            )

        # Repeat keys and values to match number of query heads
        key, val = repeat_kv(key, val, self.repeats, dim=1)

        # xformers requires (B=1, S, H, D)
        xq, key, val = xq[None, ...], key[None, ...], val[None, ...]
        output = memory_efficient_attention(
            xq, key, val, None if cache is None else cache.mask
        )

        return self.wo(output.view(seqlen_sum, self.n_heads * self.head_dim))


class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def forward(self, x) -> torch.Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

        self.feed_forward: nn.Module
        if args.moe is not None:
            self.feed_forward = MoeLayer(
                experts=[FeedForward(args=args) for _ in range(args.moe.num_experts)],
                gate=nn.Linear(args.dim, args.moe.num_experts, bias=False),
                moe_args=args.moe,
            )
        else:
            self.feed_forward = FeedForward(args=args)

    def forward(
        self, x: torch.Tensor, freqs_cis: torch.Tensor, cache: Optional[CacheView]
    ) -> torch.Tensor:
        r = self.attention.forward(self.attention_norm(x), freqs_cis, cache)
        h = x + r
        r = self.feed_forward.forward(self.ffn_norm(h))
        out = h + r
        return out


class Transformer(nn.Module):
    def __init__(
        self,
        args: ModelArgs,
        pipeline_rank: int = 0,
        num_pipeline_ranks: int = 1,
    ):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self._precomputed_freqs_cis: Optional[torch.Tensor] = None
        assert self.vocab_size > 0
        assert pipeline_rank < num_pipeline_ranks, (pipeline_rank, num_pipeline_ranks)
        self.pipeline_rank = pipeline_rank
        self.num_pipeline_ranks = num_pipeline_ranks
        # Modules specific to some ranks:
        self.tok_embeddings: Optional[nn.Embedding] = None
        self.norm: Optional[RMSNorm] = None
        self.output: Optional[nn.Linear] = None
        if pipeline_rank == 0:
            self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        if pipeline_rank == num_pipeline_ranks - 1:
            self.norm = RMSNorm(args.dim, eps=args.norm_eps)
            self.output = nn.Linear(args.dim, args.vocab_size, bias=False)
        # Initialize all layers but slice off those not of this rank.
        layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        num_layers_per_rank = math.ceil(self.n_layers / self.num_pipeline_ranks)
        offset = self.pipeline_rank * num_layers_per_rank
        end = min(self.n_layers, offset + num_layers_per_rank)
        self.layers = nn.ModuleDict({str(i): layers[i] for i in range(offset, end)})
        self.n_local_layers = len(self.layers)

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @property
    def device(self) -> torch.device:
        return next(self.parameters()).device

    @property
    def freqs_cis(self) -> torch.Tensor:
        # We cache freqs_cis but need to take care that it is on the right device
        # and has the right dtype (complex64). The fact that the dtype is different
        # from the module's  dtype means we cannot register it as a buffer
        if self._precomputed_freqs_cis is None:
            # If no sliding window, assume a larger seqlen
            theta = self.args.rope_theta
            if theta is None:
                theta = 1000000.0 if self.args.sliding_window is None else 10000.0
            # theta = 10000.
            self._precomputed_freqs_cis = precompute_freqs_cis(
                self.args.head_dim, 128_000, theta
            )
        if self._precomputed_freqs_cis.device != self.device:
            self._precomputed_freqs_cis = self._precomputed_freqs_cis.to(
                device=self.device
            )
        return self._precomputed_freqs_cis

    def forward_partial(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
        cache: Optional[RotatingBufferCache] = None,
    ) -> torch.Tensor:
        """Local forward pass.

        If doing pipeline parallelism, this will return the activations of the last layer of this stage.
        For the last stage, this will return the normalized final embeddings.
        """
        assert (
            len(seqlens) <= self.args.max_batch_size
        ), f"Max batch size is {self.args.max_batch_size}, got batch size of {len(seqlens)}"
        (num_toks,) = input_ids.shape
        assert sum(seqlens) == num_toks, (sum(seqlens), num_toks)
        if cache is not None:
            input_metadata = cache.get_input_metadata(seqlens)
        else:
            input_metadata = SimpleInputMetadata.from_seqlens(seqlens, self.device)

        if self.pipeline_rank == 0:
            assert self.tok_embeddings is not None
            h = self.tok_embeddings(input_ids)
        else:
            h = torch.empty(
                num_toks, self.args.dim, device=self.device, dtype=self.dtype
            )
            torch.distributed.recv(h, src=self.pipeline_rank - 1)

        freqs_cis = self.freqs_cis[input_metadata.positions]

        for local_layer_id, layer in enumerate(self.layers.values()):
            if cache is not None:
                assert input_metadata is not None
                cache_view = cache.get_view(local_layer_id, input_metadata)
            else:
                cache_view = None
            h = layer(h, freqs_cis, cache_view)

        if cache is not None:
            cache.update_seqlens(seqlens)
        if self.pipeline_rank < self.num_pipeline_ranks - 1:
            torch.distributed.send(h, dst=self.pipeline_rank + 1)
            return h
        else:
            # Last rank has a final normalization step.
            assert self.norm is not None
            return self.norm(h)

    def forward(
        self,
        input_ids: torch.Tensor,
        seqlens: List[int],
        cache: Optional[RotatingBufferCache] = None,
    ) -> torch.Tensor:
        h = self.forward_partial(input_ids, seqlens, cache=cache)
        if self.pipeline_rank < self.num_pipeline_ranks - 1:
            # ignore the intermediate activations as we'll get the final output from
            # the last stage
            outs = torch.empty(
                h.shape[0], self.vocab_size, device=h.device, dtype=h.dtype
            )
        else:
            assert self.output is not None
            outs = self.output(h)
        if self.num_pipeline_ranks > 1:
            torch.distributed.broadcast(outs, src=self.num_pipeline_ranks - 1)
        return outs.float()

    def load_state_dict(self, state_dict, *args, **kwargs):
        state_to_load = {}
        skipped = set([])
        for k, v in state_dict.items():
            if k.startswith("tok_embeddings"):
                if self.pipeline_rank == 0:
                    state_to_load[k] = v
                else:
                    logging.debug(
                        "Skipping parameter %s at pipeline rank %d",
                        k,
                        self.pipeline_rank,
                    )
                    skipped.add(k)
            elif k.startswith("norm") or k.startswith("output"):
                if self.pipeline_rank == self.num_pipeline_ranks - 1:
                    state_to_load[k] = v
                else:
                    logging.debug(
                        "Skipping parameter %s at pipeline rank %d",
                        k,
                        self.pipeline_rank,
                    )
                    skipped.add(k)
            elif k.startswith("layers"):
                layer_id = k.split(".")[1]
                if layer_id in self.layers:
                    state_to_load[k] = v
                else:
                    logging.debug(
                        "Skipping parameter %s at pipeline rank %d",
                        k,
                        self.pipeline_rank,
                    )
                    skipped.add(k)
            else:
                raise ValueError(f"Unexpected key {k}")
        assert set(state_dict.keys()) == skipped.union(set(state_to_load.keys()))
        super().load_state_dict(state_to_load, *args, **kwargs)

    @staticmethod
    def from_folder(
        folder: Path,
        max_batch_size: int = 1,
        num_pipeline_ranks: int = 1,
        device="cuda",
        dtype=torch.float16,
    ) -> "Transformer":
        with open(folder / "params.json", "r") as f:
            model_args = ModelArgs.from_dict(json.load(f))
        model_args.max_batch_size = max_batch_size
        if num_pipeline_ranks > 1:
            pipeline_rank = torch.distributed.get_rank()
        else:
            pipeline_rank = 0
        with torch.device("meta"):
            model = Transformer(
                model_args,
                pipeline_rank=pipeline_rank,
                num_pipeline_ranks=num_pipeline_ranks,
            )
        loaded = torch.load(str(folder / "consolidated.00.pth"), mmap=True)
        model.load_state_dict(loaded, assign=True)
        return model.to(device=device, dtype=dtype)

Comparing the original PyTorch implementation of the Mistral model with the MLX implementation, we can observe how the key features—Grouped-Query Attention (GQA), Sliding Window Attention (SWA), and Rolling Buffer Cache—are incorporated in each:

  1. Grouped-Query Attention (GQA):

    • In PyTorch: The Attention class in PyTorch uses repeat_interleave to repeat the keys and values to match the number of query heads (repeat_kv function), which aligns with the GQA concept. The division of attention heads into groups (self.repeats = self.n_heads // self.n_kv_heads) is also present, similar to the MLX version.
    • In MLX: The MLX implementation includes a similar division of attention heads into groups, indicating that GQA is also implemented in the MLX version.
  2. Sliding Window Attention (SWA):

    • In PyTorch: SWA is configured explicitly (self.args.sliding_window). The forward method of the Attention class relies on it indirectly: the window itself is enforced through the attention-bias masks built in cache.py, which we'll examine shortly.
    • In MLX: The implementation details of SWA in MLX are not explicitly visible in the provided code snippet. However, it would typically be integrated into the attention mechanism, similar to the PyTorch version.
  3. Rolling Buffer Cache:

    • In PyTorch: The RotatingBufferCache class is directly referenced, indicating the use of a rotating buffer cache mechanism. This is a more explicit mention compared to the MLX version.
    • In MLX: Caching is handled within the Attention and TransformerBlock classes by simply concatenating new keys and values onto the cache; there is no explicit rolling buffer in the provided MLX code snippet.

Overall, the PyTorch implementation references GQA, SWA, and the Rolling Buffer Cache explicitly, while the MLX implementation appears to handle these concepts implicitly within the architecture. Part of the difference comes down to how the two frameworks approach defining and executing a model.
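
To make the sliding window idea concrete, here is a toy sketch of a sliding-window causal mask in plain PyTorch. It is purely illustrative; the original implementation builds its masks with xformers attention-bias classes, as we'll see in cache.py:

import torch

def sliding_window_causal_mask(seqlen: int, window: int) -> torch.Tensor:
    # Token i may attend to tokens max(0, i - window + 1) .. i: causal
    # (no peeking ahead) and restricted to the last `window` positions.
    i = torch.arange(seqlen).unsqueeze(1)   # query positions (column vector)
    j = torch.arange(seqlen).unsqueeze(0)   # key positions (row vector)
    allowed = (j <= i) & (j > i - window)
    # Additive mask: 0 where attention is allowed, -inf where it is blocked.
    mask = torch.zeros(seqlen, seqlen)
    return mask.masked_fill(~allowed, float("-inf"))

print(sliding_window_causal_mask(5, 3))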

So, let's dive deeper into the original codebase.

In cache.py:

import torch
from typing import List, Tuple
from dataclasses import dataclass

from xformers.ops.fmha.attn_bias import (
    AttentionBias,
    BlockDiagonalCausalMask,
    BlockDiagonalCausalWithOffsetPaddedKeysMask,
    BlockDiagonalMask,
)


@dataclass
class RotatingCacheInputMetadata:
    # rope absolute positions
    positions: torch.Tensor
    # which elements in the sequences need to be cached
    to_cache_mask: torch.Tensor
    # how many elements are cached per sequence
    cached_elements: torch.Tensor
    # where tokens should go in the cache
    cache_positions: torch.Tensor

    # if prefill, use block diagonal causal mask
    # else use causal with padded key mask
    prefill: bool
    mask: AttentionBias
    seqlens: List[int]


def interleave_list(l1: List[torch.Tensor], l2: List[torch.Tensor]):
    assert len(l1) == len(l2)
    return [v for pair in zip(l1, l2) for v in pair]


def unrotate(cache: torch.Tensor, seqlen: int) -> torch.Tensor:
    assert cache.ndim == 3  # (W, H, D)
    position = seqlen % cache.shape[0]
    if seqlen < cache.shape[0]:
        return cache[:seqlen]
    elif position == 0:
        return cache
    else:
        return torch.cat([cache[position:], cache[:position]], dim=0)


class CacheView:
    def __init__(self, cache_k: torch.Tensor, cache_v: torch.Tensor, metadata: RotatingCacheInputMetadata, kv_seqlens: torch.Tensor):
        self.cache_k = cache_k
        self.cache_v = cache_v
        self.kv_seqlens = kv_seqlens
        self.metadata = metadata

    def update(self, xk: torch.Tensor, xv: torch.Tensor):
        """
        to_cache_mask masks the last [sliding_window] tokens in each sequence
        """
        n_kv_heads, head_dim = self.cache_k.shape[-2:]
        flat_cache_k = self.cache_k.view(-1, n_kv_heads, head_dim)
        flat_cache_v = self.cache_v.view(-1, n_kv_heads, head_dim)
        
        flat_cache_k.index_copy_(0, self.metadata.cache_positions, xk[self.metadata.to_cache_mask])
        flat_cache_v.index_copy_(0, self.metadata.cache_positions, xv[self.metadata.to_cache_mask])

    def interleave_kv(self, xk: torch.Tensor, xv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        This is a naive implementation and not optimized for speed.
        """
        assert xk.ndim == xv.ndim == 3 # (B * T, H, D)
        assert xk.shape == xv.shape

        if all([s == 0 for s in self.metadata.seqlens]):
            # No cache to interleave
            return xk, xv

        # Make it a list of [(T, H, D)]
        xk = torch.split(xk, self.metadata.seqlens)
        xv = torch.split(xv, self.metadata.seqlens)
        assert len(xk) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(xk)}"

        # Order elements in cache by position by unrotating
        cache_k = [unrotate(t, s) for t, s in zip(self.cache_k, self.kv_seqlens)]
        cache_v = [unrotate(t, s) for t, s in zip(self.cache_v, self.kv_seqlens)]

        interleaved_k = interleave_list(cache_k, xk)
        interleaved_v = interleave_list(cache_v, xv)

        return torch.cat(interleaved_k, dim=0), torch.cat(interleaved_v, dim=0)

    @property
    def sliding_window(self):
        return self.cache_k.shape[1]

    @property
    def key(self) -> torch.Tensor:
        return self.cache_k[:len(self.kv_seqlens)]

    @property
    def value(self) -> torch.Tensor:
        return self.cache_v[:len(self.kv_seqlens)]

    @property
    def prefill(self):
        return self.metadata.prefill

    @property
    def mask(self):
        return self.metadata.mask


class RotatingBufferCache:
    """
    This is an example that implements a less naive rotating buffer cache, allowing for variable length sequences.
    Allocated cache is rectangular which is wasteful (see PagedAttention for better mechanisms)
    """
    def __init__(self, n_layers: int, max_batch_size: int, sliding_window: int, n_kv_heads: int, head_dim: int):

        self.sliding_window = sliding_window
        self.n_kv_heads = n_kv_heads
        self.head_dim = head_dim

        self.cache_k = torch.empty((
            n_layers,
            max_batch_size,
            sliding_window,
            n_kv_heads,
            head_dim
        ))
        self.cache_v = torch.empty((
            n_layers,
            max_batch_size,
            sliding_window,
            n_kv_heads,
            head_dim
        ))
        # holds the valid length for each batch element in the cache
        self.kv_seqlens = None

    def get_view(self, layer_id: int, metadata: RotatingCacheInputMetadata) -> CacheView:
        return CacheView(self.cache_k[layer_id], self.cache_v[layer_id], metadata, self.kv_seqlens)

    def reset(self):
        self.kv_seqlens = None

    def init_kvseqlens(self, batch_size: int):
        self.kv_seqlens = torch.zeros((batch_size,), device=self.device, dtype=torch.long)

    @property
    def device(self):
        return self.cache_k.device

    def to(self, device: torch.device, dtype: torch.dtype):
        self.cache_k = self.cache_k.to(device=device, dtype=dtype)
        self.cache_v = self.cache_v.to(device=device, dtype=dtype)

        return self

    def update_seqlens(self, seqlens: List[int]):
        self.kv_seqlens += torch.tensor(seqlens, device=self.device, dtype=torch.long)

    def get_input_metadata(self, seqlens: List[int]) -> RotatingCacheInputMetadata:
        """
            input = seqlens [5,7,2] // seqpos [0, 1, 3] // sliding_window 3
            --> only cache last 3 tokens in each sequence
            - to_cache_mask = [0 0 1 1 1 | 0 0 0 0 1 1 1 | 1 1]
            - cached_elements = [3 | 3 | 2]
            --> absolute positions are used for rope
            - positions = [0 1 2 3 4 | 1 2 3 4 5 6 7 | 3 4]
            --> cache positions are positions cache_masked, modulo sliding_window + batch_idx * sliding_window
            - cache_positions = [2 0 1 | 5 3 4 | 6 7]
        """
        if self.kv_seqlens is None:
            self.init_kvseqlens(len(seqlens))
        assert len(seqlens) == len(self.kv_seqlens), f"Batch size is {len(self.kv_seqlens)}, got {len(seqlens)}, did you forget to reset cache?"
        seqpos = self.kv_seqlens.tolist()

        assert len(seqlens) > 0, seqlens
        masks = [
            [x >= seqlen - self.sliding_window for x in range(seqlen)]
            for seqlen in seqlens
        ]
        to_cache_mask = torch.tensor(sum(masks, []), device=self.device, dtype=torch.bool)
        cached_elements = torch.tensor([sum(mask) for mask in masks], device=self.device, dtype=torch.long)
        positions = torch.cat([torch.arange(pos, pos + seqlen) for pos, seqlen in zip(seqpos, seqlens)]).to(device=self.device, dtype=torch.long)
        batch_idx = torch.tensor(sum([[i]*seqlen for i, seqlen in enumerate(seqlens)], []), device=self.device, dtype=torch.long)
        cache_positions = positions % self.sliding_window + batch_idx * self.sliding_window

        first_prefill = seqpos[0] == 0
        subsequent_prefill = any(seqlen > 1 for seqlen in seqlens)
        if first_prefill:
            assert all([pos == 0 for pos in seqpos]), (seqpos)
            mask = BlockDiagonalCausalMask.from_seqlens(seqlens).make_local_attention(self.sliding_window)
        elif subsequent_prefill:
            mask = BlockDiagonalMask.from_seqlens(
                q_seqlen=seqlens,
                kv_seqlen=[s + cached_s.clamp(max=self.sliding_window).item() for (s, cached_s) in zip(seqlens, self.kv_seqlens)]
            ).make_local_attention_from_bottomright(self.sliding_window)
        else:
            mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
                q_seqlen=seqlens,
                kv_padding=self.sliding_window,
                kv_seqlen=(self.kv_seqlens + cached_elements).clamp(max=self.sliding_window).tolist()
            )

        return RotatingCacheInputMetadata(
            positions=positions,
            to_cache_mask=to_cache_mask,
            cached_elements=cached_elements,
            cache_positions=cache_positions[to_cache_mask],
            prefill=first_prefill or subsequent_prefill,
            mask=mask,
            seqlens=seqlens,
        )

The cache.py file in Mistral's original PyTorch implementation manages memory and computational efficiency through several specialized structures and functions. Let's break down the key components and their purposes:

  1. RotatingCacheInputMetadata:

    • This data class holds metadata required for managing the cache during model inference. It includes:
      • positions: Absolute positions of tokens for rotary embeddings.
      • to_cache_mask: Indicates which elements in the sequences need to be cached.
      • cached_elements: The count of elements cached per sequence.
      • cache_positions: Specifies where tokens should be stored in the cache.
      • prefill: Indicates if the cache is being pre-filled.
      • mask: An AttentionBias object for managing attention masks.
      • seqlens: The lengths of the input sequences.
  2. CacheView:

    • This class provides a view into the cache for a specific layer of the model. It facilitates updating and accessing the cache efficiently.
    • The update method updates the cache with new key (xk) and value (xv) tensors.
    • interleave_kv method interleaves the keys and values from the cache and the current input, preparing them for the attention computation.
  3. RotatingBufferCache:

    • This class implements a rotating buffer cache that allows handling variable-length sequences more efficiently.
    • The cache has a fixed size (defined by sliding_window), and older values are overwritten as new ones are added, reducing memory usage.
    • The cache is structured to support multiple layers (n_layers), a maximum batch size (max_batch_size), and dimensions for key-value heads (n_kv_heads, head_dim).
  4. Interleave and Unrotate Functions:

    • interleave_list combines two lists of tensors (l1 and l2) by interleaving their elements. This is used to merge cache and input tensors for attention computation.
    • unrotate rearranges elements in the cache to their correct positions based on sequence length, addressing the rotating nature of the cache.

These components work together to optimize the handling of attention in the Mistral model. The rotating buffer cache and interleaving mechanisms allow the model to efficiently manage memory and computational resources, especially for longer sequences. This design is a practical solution to the challenge of managing large-scale attention computations in modern transformer models.
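
To see the rotating buffer in action, here is a toy sketch of the core idea: writes wrap around modulo the window size, and unrotate puts the surviving entries back into chronological order. The numbers are made up for illustration; the real cache stores key/value tensors, not integers:

import torch

sliding_window = 4
cache = torch.zeros(sliding_window, dtype=torch.long)

# Write six "tokens" (here just the integers 1..6) into a 4-slot buffer.
for pos, token in enumerate(range(1, 7)):
    cache[pos % sliding_window] = token      # same modulo trick as cache_positions

print(cache)                                  # tensor([5, 6, 3, 4]) -- physically rotated

# unrotate: the oldest surviving entry sits at index seqlen % sliding_window.
seqlen = 6
position = seqlen % sliding_window            # 2
print(torch.cat([cache[position:], cache[:position]]))   # tensor([3, 4, 5, 6])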

In moe.py:

The code in moe.py defines a Mixture of Experts (MoE) layer that a TransformerBlock can slot in as its feed-forward module; this is the machinery behind the Mixtral model we'll look at next. Let's break down the code to understand its functionality:

  1. MoeArgs Class:

    • A dataclass representing arguments for the MoE layer.
    • num_experts: The total number of expert modules in the MoE layer.
    • num_experts_per_tok: The number of experts that each token in the input batch will be processed by.
  2. MoeLayer Class:

    • A subclass of nn.Module that implements the MoE layer functionality.
    • experts: A list of expert modules (nn.Module). Each expert is a neural network that specializes in processing a specific subset of the data.
    • gate: A neural network module that acts as a gating mechanism to decide which expert should process each token in the input.
    • args: An instance of MoeArgs containing configuration for the MoE layer.
  3. Forward Method:

    • The core of the MoE layer where the actual processing happens.
    • gate_logits = self.gate(inputs): The input is passed through the gating network to determine the logits for expert selection.
    • torch.topk(gate_logits, self.args.num_experts_per_tok): For each token in the input, it selects the top num_experts_per_tok experts based on the gate logits.
    • weights: The logits are converted into probabilities using softmax, which determine how much each expert contributes to the final output.
    • results = torch.zeros_like(inputs): Initializes the output tensor.
    • The for loop iterates over each expert. For each expert, it finds the batch indices where the expert is selected (torch.where(selected_experts == i)). It then applies the expert to the relevant subset of the input (expert(inputs[batch_idx])) and scales the expert's output by the corresponding weights. The results from all experts are aggregated to form the final output.

In summary, the MoE layer allows different parts of the input data to be processed by different experts. The gating mechanism decides which experts are most relevant for each part of the data. This can lead to more efficient and specialized processing, as each expert can be tailored to different types of input characteristics.
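
Since we haven't reproduced moe.py in full here, the following is a minimal sketch of such a layer, reconstructed from the description above; the actual file in the Mistral repository may differ in details like type annotations and assertions:

import torch
from torch import nn
from dataclasses import dataclass
from typing import List


@dataclass
class MoeArgs:
    num_experts: int
    num_experts_per_tok: int


class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # One gate logit per expert, for every token in the (flattened) batch.
        gate_logits = self.gate(inputs)
        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)
        # Turn the selected logits into per-token mixing weights.
        weights = nn.functional.softmax(weights, dim=-1, dtype=torch.float).to(inputs.dtype)
        results = torch.zeros_like(inputs)
        for i, expert in enumerate(self.experts):
            # Which tokens routed to expert i, and at which of their top-k slots.
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx])
        return results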

The MoE layer is not used in the Mistral 7B model we're dealing with here, but it is the backbone of Mistral AI's larger model, Mixtral 8x7B.

Comparing the MLX and PyTorch Implementations of the Mistral Model

Now that we've explored the key components of the Mistral model, let's compare the MLX and PyTorch implementations to understand how they differ.

Based on both codebases, it appears that the MLX implementation of the Mistral model is a simplified version compared to the original PyTorch implementation. The differences suggest that some key features and optimizations present in the original PyTorch code may not be fully replicated in the MLX version.

Key observations supporting this conclusion include:

  1. Grouped-Query Attention and Sliding Window Attention: These attention mechanisms are explicitly called out in the Mistral model's documentation and are fundamental to its efficiency and performance. The MLX version does keep the grouped-query head split (the repeats logic), but its attention computation applies no sliding-window masking. In the original PyTorch implementation, the window is enforced through the attention masks fed into optimizations like memory_efficient_attention.

  2. Rotating Buffer Cache: This feature keeps memory usage bounded in the original PyTorch implementation, but there's no sign of a similar mechanism in the provided MLX code; the MLX example simply grows its key-value cache by concatenation. Cache management is crucial in large language models for handling long sequences efficiently.

  3. Complexity in Cache Management and Attention Mechanics: The original implementation shows a more complex approach to handling cache and attention, potentially offering more flexibility and efficiency. The MLX implementation, while functional, seems to lack this level of complexity.

In load_model():

def load_model(folder: str):
    ...
    config.pop("sliding_window", None)
    ...

The load_model function drives the point home: the sliding_window entry is simply popped from the config and never used, which confirms that the MLX port of the Mistral model is a simplified version of the original PyTorch implementation.

To sum it up, the MLX version of the Mistral model effectively encapsulates its core aspects, albeit in a more streamlined form. Key complexities present in the original PyTorch implementation seem to be omitted, likely as a deliberate choice aimed at enhancing accessibility and aligning with MLX's framework. It's important to note that these implementations are intended as examples to familiarize users with MLX's functionalities, rather than being optimized for production environments. This approach aligns with the objective of providing clear, manageable examples for educational purposes within the MLX context.

Helper Functions for Menny Mistral

All the other helper functions I modified and added are essentially identical to those used for Menny Llama.

Menny Mistral - Putting It All Together

The complete code for Menny Mistral uses the following model from mlx-community:

https://huggingface.co/mlx-community/Mistral-7B-Instruct-v0.2/tree/main

But feel free to convert any model you want to use. Just make sure you use an instruct model.

# Copyright © 2023 Apple Inc.

import argparse
import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from sentencepiece import SentencePieceProcessor
import streamlit as st


@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    head_dim: int
    hidden_dim: int
    n_heads: int
    n_kv_heads: int
    norm_eps: float
    vocab_size: int
    rope_theta: float = 10000


class RMSNorm(nn.Module):
    def __init__(self, dims: int, eps: float = 1e-5):
        super().__init__()
        self.weight = mx.ones((dims,))
        self.eps = eps

    def _norm(self, x):
        return x * mx.rsqrt(x.square().mean(-1, keepdims=True) + self.eps)

    def __call__(self, x):
        output = self._norm(x.astype(mx.float32)).astype(x.dtype)
        return self.weight * output


class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args

        self.n_heads: int = args.n_heads
        self.n_kv_heads: int = args.n_kv_heads

        self.repeats = self.n_heads // self.n_kv_heads

        self.scale = self.args.head_dim**-0.5

        self.wq = nn.Linear(args.dim, args.n_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, args.n_kv_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * args.head_dim, args.dim, bias=False)
        self.rope = nn.RoPE(args.head_dim, traditional=True, base=args.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        B, L, D = x.shape

        queries, keys, values = self.wq(x), self.wk(x), self.wv(x)

        # Prepare the queries, keys and values for the attention computation
        queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3)

        def repeat(a):
            a = mx.concatenate([mx.expand_dims(a, 2)] * self.repeats, axis=2)
            return a.reshape([B, self.n_heads, L, -1])

        keys, values = map(repeat, (keys, values))

        if cache is not None:
            key_cache, value_cache = cache
            queries = self.rope(queries, offset=key_cache.shape[2])
            keys = self.rope(keys, offset=key_cache.shape[2])
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        scores = (queries * self.scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores += mask
        scores = mx.softmax(scores.astype(mx.float32), axis=-1).astype(scores.dtype)
        output = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return self.wo(output), (keys, values)


class FeedForward(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()

        self.w1 = nn.Linear(args.dim, args.hidden_dim, bias=False)
        self.w2 = nn.Linear(args.hidden_dim, args.dim, bias=False)
        self.w3 = nn.Linear(args.dim, args.hidden_dim, bias=False)

    def __call__(self, x) -> mx.array:
        return self.w2(nn.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args=args)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.args = args

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Tuple[mx.array, mx.array]] = None,
    ) -> mx.array:
        r, cache = self.attention(self.attention_norm(x), mask, cache)
        h = x + r
        r = self.feed_forward(self.ffn_norm(h))
        out = h + r
        return out, cache


class Mistral(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        assert self.vocab_size > 0
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.layers = [TransformerBlock(args=args) for _ in range(args.n_layers)]
        self.norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def __call__(
        self,
        inputs: mx.array,
        cache=None,
    ):
        h = self.tok_embeddings(inputs)

        mask = None
        if h.shape[1] > 1:
            mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1])
            mask = mask.astype(h.dtype)

        if cache is None:
            cache = [None] * len(self.layers)

        for e, layer in enumerate(self.layers):
            h, cache[e] = layer(h, mask, cache[e])

        return self.output(self.norm(h)), cache


class Tokenizer:
    def __init__(self, model_path: str):
        assert Path(model_path).exists(), model_path
        self._model = SentencePieceProcessor(model_file=model_path)
        self._sep = "▁"
        assert self._model.vocab_size() == self._model.get_piece_size()

    @property
    def eos_id(self) -> int:
        return self._model.eos_id()

    @property
    def pad_id(self) -> int:
        return self._model.pad_id()

    def encode(self, s: str) -> List[int]:
        return [self._model.bos_id(), *self._model.encode(s)]

    def decode(self, t: List[int]) -> str:
        out = self._model.decode(t)
        if t and self._model.id_to_piece(t[0])[0] == self._sep:
            return " " + out
        return out


def load_model(folder: str):
    model_path = Path(folder)
    tokenizer = Tokenizer(str(model_path / "tokenizer.model"))
    with open(model_path / "config.json", "r") as f:
        config = json.loads(f.read())
        config.pop("sliding_window", None)
        config.pop("model_type", None)
        quantization = config.pop("quantization", None)
        model_args = ModelArgs(**config)
    weights = mx.load(str(model_path / "weights.npz"))
    weights = tree_unflatten(list(weights.items()))
    model = Mistral(model_args)
    if quantization is not None:
        nn.QuantizedLinear.quantize_module(model, **quantization)
    model.update(weights)
    mx.eval(model.parameters())
    return model, tokenizer


def generate(prompt, model, tokenizer, temp=0.7, max_tokens=200, write_every=10):
    x = mx.array([tokenizer.encode(prompt)])
    cache = None
    response_accumulator = ""
    response_part = ""

    for token_index in range(max_tokens):
        logits, cache = model(x, cache)
        y = sample(logits[:, -1, :], temp)
        generated_text = tokenizer.decode([y.item()])

        # Append generated text to the accumulator
        response_accumulator += generated_text

        # Check if it's time to yield part of the response
        if (token_index + 1) % write_every == 0 or token_index == max_tokens - 1:
            response_part = response_accumulator.lstrip().replace("\n", " ")
            yield response_part
            response_accumulator = ""  # Reset accumulator after yielding

        x = y[:, None]  # Update input for next token generation


def sample(logits, temp):
    if temp == 0:
        return mx.argmax(logits, axis=-1)
    else:
        return mx.random.categorical(logits * (1 / temp))



@st.cache_resource
def load_cached_model(model_path):
    return load_model(model_path)


if __name__ == "__main__":
    SEED = 42
    MODEL_PATH = "/Users/wankyuchoi/cwk-llm-models/Mistral-7B-Instruct-v0.2-mlx"
    MAX_TOKENS = 200
    TEMP = 0.7
    WRITE_EVERY = 10
    SYSTEM_MESSAGE = "<<SYS>>Your name is Menny, a cynical teenager AI assistant.<</SYS>>"
    MODEL_NAME = "Menny Mistral"
    MODEL_AVATAR = "./images/menny-avatar.png"
    HUMAN_AVATAR = "./images/human-avatar.png"
    MODEL_IMAGE = "./images/menny-mistral.png"

    mx.random.seed(SEED)
    model, tokenizer = load_cached_model(MODEL_PATH)

    # Streamlit UI setup
    st.sidebar.title("Chatbot Settings")
    max_tokens = st.sidebar.slider("Max Tokens", 50, 500, MAX_TOKENS)
    temp = st.sidebar.slider("Temperature", 0.1, 1.0, TEMP)

    st.title(MODEL_NAME)
    st.image(MODEL_IMAGE, width=500)

    user_input = st.chat_input("Your Message")

    if user_input:
        human_message = st.chat_message('human', avatar=HUMAN_AVATAR)
        human_message.write(user_input)


        full_prompt = SYSTEM_MESSAGE + f"\n\n[INST] {user_input} [/INST]\n"
        full_response = ""

        ai_message = st.chat_message('assistant', avatar=MODEL_AVATAR)

        ai_message.write("Thinking...")
        ai_message_placeholder = st.empty()  # Create a placeholder for AI message

        for response_chunk in generate(full_prompt, model, tokenizer, temp=temp, max_tokens=max_tokens):
            full_response += response_chunk
            ai_message_placeholder.markdown(full_response, unsafe_allow_html=True)  # Update the placeholder content

The implementation of "Menny Mistral," similar to "Menny Llama," showcases how the Mistral model can be used in a chatbot application, much like the Llama model was used before. Here are the key points of comparison and the unique aspects of Menny Mistral:

  1. Similarities with Menny Llama:

    • The overall structure and functioning of Menny Mistral are quite similar to Menny Llama. This includes how the model processes input (tokenization), generates responses, and handles the conversation flow.
    • The underlying principles of machine learning and language model use remain consistent. Both models leverage advanced language understanding and generation capabilities to interact with users.
  2. Differences in Pure Inference Time:

    • While the core functionality is similar, there might be differences in how quickly each model processes information and generates responses, known as "pure inference time." This could be due to internal optimizations of the Mistral model or to differences in how MLX handles the two architectures under the hood.
  3. Crude Streaming Mechanism:

    • A notable addition in Menny Mistral is the implementation of a crude streaming mechanism. This feature allows the chatbot to show responses to the user as they are being generated, rather than waiting for the entire response to be formulated before displaying it.
    • This streaming aspect enhances user interaction by providing a more dynamic and responsive experience. It can make the chatbot feel more conversational and immediate, as users see responses being built in real-time.
    • The term "crude" here implies that the streaming mechanism is basic or not fully refined. It serves the purpose of demonstrating the capability but might not have the sophistication or polish of a fully developed feature in a production environment.

Overall, Menny Mistral serves as an effective demonstration of how the Mistral model can be adapted for a chatbot application, with an added feature to enhance user experience through streaming responses.
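
If you'd like to see the streaming behavior without Streamlit, you can consume the generate generator directly in a plain Python session. This is just an illustrative sketch, assuming the script above is saved as menny_mistral.py; the model path below is a placeholder you'd point at your own converted MLX model folder:

# Illustrative console consumer for the streaming generator defined above.
from menny_mistral import load_model, generate   # assuming the script is saved as menny_mistral.py

model, tokenizer = load_model("/your/actual/model/path/Mistral-7B-Instruct-v0.2-mlx")  # placeholder path

prompt = "<<SYS>>Your name is Menny, a cynical teenager AI assistant.<</SYS>>\n\n[INST] Hi Menny! [/INST]\n"

for chunk in generate(prompt, model, tokenizer, temp=0.7, max_tokens=200, write_every=10):
    print(chunk, end="", flush=True)   # print each partial chunk as soon as it is yielded
print()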

To run the Menny Mistral chatbot, it's crucial to ensure that the MODEL_PATH in your script points to the correct directory where your MLX model is stored.

  1. Set the Model Path:

    • First, locate the folder where your MLX model is saved.
    • Replace "/path/to/your/mlx-model" in the MODEL_PATH variable with the actual path to your model.
    MODEL_PATH = "/your/actual/model/path/Mistral-7B-Instruct-v0.2-mlx"
  2. Run the Chatbot:

    • Open your terminal or command prompt.
    • Navigate to the directory where your menny_mistral.py script is located.
    • Run the following command:
    streamlit run menny_mistral.py
  3. Interact with Menny Mistral:

    • Once the Streamlit application launches, you can interact with the chatbot via the user interface.
    • Adjust the settings such as "Max Tokens" and "Temperature" in the sidebar as needed for different experiences.

Remember, the provided path in MODEL_PATH must be accurate and point to where the model and its associated files are located for the chatbot to function correctly.

Embracing the Evolution: From Menny Llama to Menny Mistral

menny-mistral.png

In this section, we've witnessed the seamless metamorphosis of Menny Llama into Menny Mistral. This transformation is more than just a change in name; it represents the adaptability and scalability inherent in understanding the core principles of AI models. As we delve into the intricate world of transformers, the similarities in their foundational architecture across different models become strikingly evident.

This uniformity is a vivid illustration of the profound impact of object-oriented methodologies, transcending mere coding practices to shape how we learn and comprehend complex subjects. As we gear up to explore the more complex realms of Mixtral 8x7B, our solid grounding in these principles reassures us of our preparedness to confront and master new challenges.

Thus, our journey continues, fueled by the power of object-oriented learning, ready to unravel the mysteries of Mixtral 8x7B. Join us as we venture into this next phase, where the convergence of technology and education unveils exciting new horizons.