# Fine-Tune Mistral Nemo 12bn

In this article, we'll explore how to fine-tune the Mistral-Nemo 12B model for a specialized use case using the ChatML template, quantize it to FP8, and serve it through vLLM's OpenAI-compatible API, so it can act as a drop-in replacement for OpenAI's GPT API.

Fine-tuning is essential when aiming for high-quality performance in a specific application, such as crafting precise responses to particular types of questions or structuring unstructured data in a tailored way. However, if your goal is to build a generalized model like GPT, Claude, LLaMA, or Mistral, fine-tuning alone—especially with limited high-quality, diverse data—won't suffice. Therefore, we'll focus on optimizing the model to excel in a targeted, specialized task where fine-tuning can truly shine.

### Coverage

In this tutorial, we will cover the following:

* Fine-tuning Mistral-Nemo 12B using QLoRA for efficient memory usage
* Formatting datasets with ChatML for instruction-following tasks
* Quantizing the model to FP8 to roughly halve its size so it fits on 24GB GPUs
* Deploying the model as an API with vLLM on a 24GB GPU for fast and cost-effective inference

There are several ways to fine-tune an LLM like Mistral-Nemo 12B. One option is Mistral's "official" fine-tuning package, mistral-finetune (available [here](https://github.com/mistralai/mistral-finetune/tree/main)). Alternatively, you can use the Hugging Face stack (`transformers`, `peft`, `trl`), or, for the most memory-efficient approach, combine Unsloth with Hugging Face. In this tutorial, we'll use Unsloth alongside the Hugging Face stack to keep memory usage low.

If you prefer working with notebooks, Unsloth also offers [notebook-based tutorials](https://docs.unsloth.ai/get-started/unsloth-notebooks) to help you get started. Personally, I prefer a more "raw" approach, so notebooks aren't my go-to, but they’re a great resource if that’s your style.

### Requirements

For this tutorial, you'll need access to a GPU with at least 24GB of VRAM (16GB may work with some limitations). Suitable options include the RTX 3090/4090, A100, H100, L4, RTX 6000 Ada, or AMD's MI300. We used an RTX 6000 Ada, so the instructions below assume CUDA.

On the software side, make sure to have all dependencies listed in the `requirements.txt` file installed before proceeding.

```
datasets
transformers
peft
...
```

### Data

High-quality data is absolutely essential for achieving good results when fine-tuning a model. In fact, it's likely the most critical factor in the entire process. If your data is flawed or contains errors, the model will inevitably learn and replicate those patterns, leading to subpar performance. The more issues in your data, the more compromised your model will be. Therefore, before beginning the fine-tuning process, it’s vital to ensure your data accurately reflects the goals you aim to achieve.

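For this tutorial, we will use data in the ChatML format. ChatML (Chat Markup Language) structures a conversation as a sequence of messages with roles such as system, user, and assistant. Variants of this role-based structure are used by most instruction-tuned models today, although each model family renders it with its own special tokens. If you've previously worked with the GPT API, you're likely familiar with this format in its JSON form: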

```json
[{"role": "system", "content" : "You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible.\nKnowledge cutoff: 2021-09-01\nCurrent date: 2023-03-02"},
{"role": "user", "content" : "How are you?"},
{"role": "assistant", "content" : "I am doing well"}]
```
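
To make the link between these JSON messages and ChatML concrete, here is a minimal sketch that renders a message list into the classic ChatML markup with `<|im_start|>` and `<|im_end|>` delimiters. It is for illustration only: `to_chatml` is a hypothetical helper, and the training script below does not use it.

```python
def to_chatml(messages, add_generation_prompt=False):
    """Render a list of {"role", "content"} dicts as classic ChatML text (illustrative helper)."""
    parts = []
    for msg in messages:
        if msg["role"] not in {"system", "user", "assistant"}:
            raise ValueError(f"Unexpected role: {msg['role']!r}")
        # Each turn is wrapped as <|im_start|>{role}\n{content}<|im_end|>
        parts.append(f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n")
    if add_generation_prompt:
        # Open an assistant turn so a model would continue with its reply
        parts.append("<|im_start|>assistant\n")
    return "".join(parts)


example = [
    {"role": "system", "content": "Answer as concisely as possible."},
    {"role": "user", "content": "How are you?"},
    {"role": "assistant", "content": "I am doing well"},
]
print(to_chatml(example))
```

Mistral-Nemo's own chat template renders the same messages with different markers (`[INST]` ... `[/INST]`), which is why the script below takes its template from the Instruct tokenizer rather than hard-coding ChatML tokens.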

Our goal is to fine-tune the model so it understands a variant of this format, which later lets us use it as a drop-in replacement through the OpenAI Python package and vLLM. For training, we'll store the conversations as JSONL (JSON Lines), a format in which every line is a standalone JSON object. In our dataset, each line holds one chat conversation, which keeps large datasets easy to stream and manage.

Each line should end with a message from the "assistant" role, since that reply is the completion the model learns to generate. (Note that the training script below computes the loss over the whole formatted conversation, including the prompt, not only over the assistant reply.) While conversations can be as long as the token limit allows, preparing multi-turn datasets can be quite complex, so this tutorial focuses on single-turn instructions.

We’ll be working with a dataset in the file `training_data.jsonl` formatted according to these guidelines.

```
{"messages": []}....
```
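
Before training, it can pay to check every line against the rules above: valid JSON, a non-empty `messages` list, known roles, a final `assistant` turn, and (for this tutorial) a single user turn. The sketch below is a minimal validator under those assumptions; `validate_jsonl` and the sample record are hypothetical and not part of the original project.

```python
import json

ALLOWED_ROLES = {"system", "user", "assistant"}


def validate_jsonl(lines):
    """Check JSONL chat records; return (line_number, problem) pairs, empty if all pass."""
    problems = []
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # tolerate blank lines such as a trailing newline
        try:
            record = json.loads(line)
        except json.JSONDecodeError as exc:
            problems.append((lineno, f"invalid JSON: {exc.msg}"))
            continue
        messages = record.get("messages") if isinstance(record, dict) else None
        if not isinstance(messages, list) or not messages:
            problems.append((lineno, "missing or empty 'messages' list"))
            continue
        if not all(isinstance(m, dict) and isinstance(m.get("content"), str) for m in messages):
            problems.append((lineno, "every message needs a string 'content'"))
            continue
        roles = [m.get("role") for m in messages]
        if any(role not in ALLOWED_ROLES for role in roles):
            problems.append((lineno, f"unexpected role in {roles}"))
        elif roles[-1] != "assistant":
            problems.append((lineno, "last message is not from 'assistant'"))
        elif roles.count("user") != 1:
            # This tutorial sticks to single-turn examples
            problems.append((lineno, "expected exactly one 'user' message"))
    return problems


# Hypothetical record in the expected shape (one conversation per line)
sample = json.dumps({"messages": [
    {"role": "system", "content": "Extract the company name."},
    {"role": "user", "content": "Acme Corp. opened a new office in Berlin."},
    {"role": "assistant", "content": "Acme Corp."},
]})
print(validate_jsonl([sample, '{"messages": []}']))

# For a real file:
# with open("training_data.jsonl", encoding="utf-8") as f:
#     print(validate_jsonl(f))
```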

### Model Fine-Tuning

With our data in the correct format, we can move on to fine-tuning. We'll use Hugging Face and Unsloth with QLoRA instead of a full fine-tune, which cuts memory usage substantially while retaining most of the quality a full fine-tune would reach. Let's dive into the fine-tuning script!

{% hint style="info" %}
LoRA (Low-Rank Adaptation) fine-tunes a large language model by freezing its original weights and adding small trainable low-rank matrices to selected layers, so only a tiny fraction of the parameters is updated. This saves memory and compute without sacrificing (much) performance.

QLoRA (Quantized Low-Rank Adaptation) combines LoRA with quantization: the frozen base model is loaded in 4-bit precision, which shrinks its memory footprint, while the LoRA matrices are trained in higher precision. This makes it possible to fine-tune large models on a single GPU without significant performance loss.
{% endhint %}
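
For reference, the standard LoRA formulation (general to LoRA, not specific to this tutorial) keeps a pretrained d × k weight matrix W0 frozen and learns a low-rank update scaled by alpha / r, where `r` and `lora_alpha` in the script below correspond to r and α:

$$
W = W_0 + \frac{\alpha}{r} B A, \qquad B \in \mathbb{R}^{d \times r}, \quad A \in \mathbb{R}^{r \times k}, \quad r \ll \min(d, k)
$$

$$
\underbrace{r\,(d + k)}_{\text{trainable (LoRA)}} \;\ll\; \underbrace{d\,k}_{\text{trainable (full fine-tune)}}
$$

B is initialized to zero, so training starts from the unchanged base model. In QLoRA, W0 is additionally stored in 4-bit NF4 and dequantized on the fly during the forward pass.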

Before we begin the fine-tuning process, we need to select the right model. Many models, including Mistral, LLaMA, and Gemma, typically come in two versions:

1. **Base**
2. **Instruct**

The Instruct model is a refined, further fine-tuned version of the Base model. While the Base model is essentially a "raw" text completion model, the Instruct version is optimized to follow instructions and respond to specific chat formats more effectively.

The choice between the two depends on the problem you're aiming to solve. The Base model offers more flexibility and adaptability, making it a better option if you're looking to fine-tune for a custom use case. For this tutorial, we'll be working with the Base version of Mistral-Nemo 12B to fine-tune it according to our specific needs.

{% hint style="info" %}
We ran into quite a few issues fine-tuning Mistral Nemo Base with the ChatML format: the EOS token wasn't always added correctly, so generation did not stop. The script below, however, does work.
{% endhint %}

Since we're working with the Base model, we have the flexibility to define a custom text completion template that suits our specific requirements, such as function calling or other specialized tags. For simplicity, however, we'll follow the chat format used by the Instruct model.

That format is defined by the chat template stored in the Instruct model's tokenizer. Since the template is just a string, you can also define it directly if you prefer; Hugging Face offers detailed [documentation](https://huggingface.co/docs/transformers/main/en/chat_templating) on how to create your own templates. For ease of use in this tutorial, though, we'll simply copy the template from the Instruct model's tokenizer.

```python
import os
from unsloth import FastLanguageModel, is_bfloat16_supported
from datasets import load_dataset
from trl import SFTTrainer
from transformers import AutoTokenizer, TrainingArguments

# Load dataset
def load_training_data(file_path):
    return load_dataset("json", data_files=file_path, split="train")

# Load model and tokenizer with shared configurations
def load_model_and_tokenizer(model_name, max_seq_length, dtype, load_in_4bit, hf_token):
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=load_in_4bit,
        token=hf_token
    )
    return model, tokenizer

# Load only the instruct model's tokenizer: we just need its chat template.
# max_seq_length, dtype and load_in_4bit are unused and kept so the call site stays unchanged.
def load_instruct_tokenizer(instruct_model_name, max_seq_length, dtype, load_in_4bit, hf_token):
    # AutoTokenizer avoids downloading and loading the 12B instruct weights into VRAM
    return AutoTokenizer.from_pretrained(instruct_model_name, token=hf_token)

# Prepare QLoRA model
def prepare_lora_model(model, r, target_modules, lora_alpha, lora_dropout, bias, gradient_checkpointing, random_state, rslora, loftq_config):
    return FastLanguageModel.get_peft_model(
        model,
        r=r,
        target_modules=target_modules,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        bias=bias,
        use_gradient_checkpointing=gradient_checkpointing,
        random_state=random_state,
        use_rslora=rslora,
        loftq_config=loftq_config
    )

# Helper function to stitch roles for prompt formatting
def stitch_roles(prompt):
    try:
        system_message = next((msg for msg in prompt if msg['role'] == 'system'), None)
        user_message = next((msg for msg in prompt if msg['role'] == 'user'), None)

        if system_message and user_message:
            user_message['content'] = system_message['content'] + '\n\n' + user_message['content']
            prompt = [msg for msg in prompt if msg['role'] != 'system']
    except Exception as e:
        print(f"ERROR: Failed to stitch roles in the prompt: {e}")
        raise

    return prompt

# Formatting function for prompts
def create_prompt(prompt, tokenizer):
    prompt = stitch_roles(prompt["messages"])
    return tokenizer.apply_chat_template(prompt, tokenize=False)

# Main training function
def train_model(train_data_file, model_name, instruct_model_name, max_seq_length, dtype, load_in_4bit, output_dir, num_epochs=6):
    hf_token = os.getenv("HF_TOKEN")

    # Load dataset
    dataset = load_training_data(train_data_file)

    # Load base model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_name, max_seq_length, dtype, load_in_4bit, hf_token)
    
    # Load instruct tokenizer and update tokenizer with instruct chat template
    instruct_tokenizer = load_instruct_tokenizer(instruct_model_name, max_seq_length, dtype, load_in_4bit, hf_token)
    tokenizer.chat_template = instruct_tokenizer.chat_template
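    # Reuse EOS as padding. Caveat: collators that mask pad tokens in the labels may then
    # also mask EOS; if generation does not stop, consider a dedicated pad token instead.
    tokenizer.pad_token = tokenizer.eos_token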
    tokenizer.padding_side = 'right'  # Prevent warnings
    
    # Prepare QLoRA model
    model = prepare_lora_model(
        model=model,
        r=16,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_alpha=16,
        lora_dropout=0,
        bias="none",
        gradient_checkpointing="unsloth",
        random_state=3407,
        rslora=False,
        loftq_config=None
    )

    # Define trainer
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=lambda prompt: create_prompt(prompt, tokenizer),
        max_seq_length=max_seq_length,
        packing=True,  # Can make training 5x faster for short sequences
        args=TrainingArguments(
            per_device_train_batch_size=6,
            gradient_accumulation_steps=1,
            warmup_steps=5,
            num_train_epochs=num_epochs,
            learning_rate=3e-5,
            fp16=not is_bfloat16_supported(),
            bf16=is_bfloat16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            lr_scheduler_type="constant_with_warmup",  # plain "constant" ignores warmup_steps
            seed=3407,
            output_dir=output_dir,
            save_total_limit=5,
            save_steps=100  # Save every 100 steps
        )
    )

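    # Start training
    trainer_stats = trainer.train()
    # Save the final LoRA adapter (and tokenizer); save_steps only covers multiples of 100
    trainer.save_model(output_dir)
    return trainer_stats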

if __name__ == "__main__":
    # Define parameters
    TRAIN_DATA_FILE = "training_data.jsonl"
    MODEL_NAME = "unsloth/Mistral-Nemo-Base-2407"
    INSTRUCT_MODEL_NAME = "unsloth/Mistral-Nemo-Instruct-2407"
    MAX_SEQ_LENGTH = 1024
    DTYPE = None  # None lets Unsloth pick bfloat16 or float16 for your GPU
    LOAD_IN_4BIT = True
    OUTPUT_DIR = "outputs-fusionbase-seraphon-12b-8k"

    # Run the training process
    train_stats = train_model(
        train_data_file=TRAIN_DATA_FILE,
        model_name=MODEL_NAME,
        instruct_model_name=INSTRUCT_MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        dtype=DTYPE,
        load_in_4bit=LOAD_IN_4BIT,
        output_dir=OUTPUT_DIR,
        num_epochs=6
    )
    print("Training completed with stats:", train_stats)

```
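
Given the EOS problems mentioned above, it may be worth inspecting a few formatted examples before launching a long run. The sketch below is a hypothetical helper (not part of the original script) that applies a formatting function and checks that each rendered example ends with the tokenizer's EOS token string. That is what you would expect if the chat template appends `eos_token` after assistant turns, as Mistral's templates do as far as we know.

```python
from types import SimpleNamespace


def check_formatted_examples(dataset, tokenizer, format_fn, n=3):
    """Format the first n examples and raise if any does not end with the EOS token string."""
    rendered = []
    for i in range(min(n, len(dataset))):
        text = format_fn(dataset[i])
        if not text.rstrip().endswith(tokenizer.eos_token):
            raise ValueError(
                f"Example {i} does not end with {tokenizer.eos_token!r}: ...{text[-80:]!r}"
            )
        rendered.append(text)
    return rendered


# In train_model(), after setting tokenizer.chat_template, something like:
#     for text in check_formatted_examples(dataset, tokenizer, lambda ex: create_prompt(ex, tokenizer)):
#         print(text)

# Stand-alone demo with a stub tokenizer and a trivial formatter
stub_tokenizer = SimpleNamespace(eos_token="</s>")
demo = [{"text": "prompt text, then the assistant reply</s>"}]
print(check_formatted_examples(demo, stub_tokenizer, lambda ex: ex["text"]))
```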

Depending on your available resources, two key parameters you can adjust are the LoRA rank (`r=16`) and the LoRA alpha (`lora_alpha=16`). The rank sets the size of the low-rank matrices and therefore how many parameters are trained: a higher rank means a larger adapter and more memory. Alpha does not add parameters; it scales the learned update (by `lora_alpha / r` in standard LoRA), so it behaves more like a multiplier on the adapter's learning rate. There are some excellent benchmarks available, like [this article](https://magazine.sebastianraschka.com/p/practical-tips-for-finetuning-llms), which explores different rank and alpha settings. We've also seen strong results with higher values such as `r=256` and `lora_alpha=512`.
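
To see how rank translates into adapter size, the sketch below counts trainable LoRA parameters for the seven target modules used in the script, at r × (d_in + d_out) per adapted matrix. The layer shapes are our assumption of Mistral-Nemo's configuration (hidden size 5120, 32 query heads and 8 key/value heads of dimension 128, MLP size 14336, 40 layers); verify them against the model's `config.json`.

```python
# Assumed Mistral-Nemo-12B shapes; verify against the model's config.json
HIDDEN = 5120
HEAD_DIM = 128
N_HEADS = 32
N_KV_HEADS = 8
INTERMEDIATE = 14336
N_LAYERS = 40

# (in_features, out_features) of each projection targeted in the training script
TARGET_SHAPES = {
    "q_proj": (HIDDEN, N_HEADS * HEAD_DIM),
    "k_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "v_proj": (HIDDEN, N_KV_HEADS * HEAD_DIM),
    "o_proj": (N_HEADS * HEAD_DIM, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_trainable_params(r):
    """LoRA adds A (r x d_in) and B (d_out x r) per target matrix: r * (d_in + d_out) params."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in TARGET_SHAPES.values())
    return per_layer * N_LAYERS


def lora_scale(r, alpha, use_rslora=False):
    """Scaling applied to the low-rank update: alpha / r, or alpha / sqrt(r) with rsLoRA."""
    return alpha / (r ** 0.5) if use_rslora else alpha / r


for r, alpha in [(16, 16), (256, 512)]:
    params = lora_trainable_params(r)
    print(f"r={r:>3} alpha={alpha:>3}: {params / 1e6:.1f}M trainable params, "
          f"scale={lora_scale(r, alpha):.1f}, ~{params * 2 / 1e9:.2f} GB in 16-bit")
```

Under these assumed shapes, rank 16 gives an adapter of roughly 57M parameters and rank 256 roughly 912M, while alpha leaves both counts unchanged and only alters the scale.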

Once training is complete, the QLoRA adapter will be saved in the output directory specified by the `TrainingArguments`. Keep in mind, this adapter is not the full model—it contains only the fine-tuned parameters. Depending on your needs, you can either use the adapter as is, or merge it with the base model to create a fully self-contained, standalone model. In our case, we aim to create a standalone model, so we will merge the adapter back into the base model.

```
<ADD code for merging>
```
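
As a stand-in for the merge step, here is a minimal sketch using the standard PEFT API (`PeftModel.from_pretrained` followed by `merge_and_unload`); it is not necessarily what the authors used. It assumes the adapter directory is one of the checkpoints written by the training script, that the tokenizer was saved alongside it, and that the 16-bit base model `unsloth/Mistral-Nemo-Base-2407` is merged on CPU (roughly 25 GB of RAM for BF16 weights). `merge_adapter`, the paths, and any repository name are hypothetical placeholders.

```python
def merge_adapter(adapter_dir, output_dir, base_model="unsloth/Mistral-Nemo-Base-2407",
                  tokenizer_dir=None, push_to_repo=None, hf_token=None):
    """Merge a trained LoRA adapter into the 16-bit base model, then save or push the result."""
    # Heavy imports live inside the function so the module can be imported without them
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load the base weights in BF16 (not 4-bit) so the merged weights are not quantized;
    # without a device_map this loads on CPU and needs roughly 25 GB of RAM for a 12B model
    base = AutoModelForCausalLM.from_pretrained(base_model, torch_dtype=torch.bfloat16, token=hf_token)

    # Attach the adapter, then fold the low-rank update into the base weights
    model = PeftModel.from_pretrained(base, adapter_dir)
    merged = model.merge_and_unload()

    # Assumes the tokenizer (carrying the Instruct chat template) was saved next to the adapter
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir or adapter_dir)

    if push_to_repo:
        merged.push_to_hub(push_to_repo, token=hf_token, private=True)
        tokenizer.push_to_hub(push_to_repo, token=hf_token, private=True)
        return push_to_repo
    merged.save_pretrained(output_dir, safe_serialization=True)
    tokenizer.save_pretrained(output_dir)
    return output_dir


# Example (hypothetical paths; point adapter_dir at a checkpoint-* folder or wherever the adapter was saved):
# merge_adapter("outputs-fusionbase-seraphon-12b-8k/checkpoint-<step>", "mistral-nemo-12b-merged")
```

Loading the base in 16-bit rather than 4-bit avoids baking quantization error into the merged weights.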

The merged model can either be pushed directly to Hugging Face or saved locally in a folder. Keep in mind that merging does not quantize the model, so the result keeps the size of the original 16-bit Base model. If you need a smaller model, quantize it afterward, as described in the next section.

#### Quantize FP8 for vLLM

FP8 quantization stores each weight in 8 bits instead of the 16 bits used by BF16/FP16, roughly halving the model's size (from about 24 GB to about 12 GB for Mistral-Nemo 12B). This makes it much more manageable on 24GB VRAM GPUs such as the RTX 3090/4090 or the widely used L4. Quantization can slightly reduce model quality, but with FP8 the loss is usually small.

{% hint style="info" %}
FP8 (8-bit Floating Point) is a numerical format used to represent model weights and activations in deep learning, reducing memory usage and computational overhead compared to traditional formats like FP32. It enables faster processing and lower resource consumption while maintaining acceptable precision for training and inference in large models.
{% endhint %}

Although vLLM supports dynamic quantization to FP8, there's a limitation: the model must first be loaded at its full size. This means if the model isn’t already quantized beforehand, it won’t fit into a 24GB VRAM card. For optimal results, pre-quantizing the model ensures it can be loaded and run on these GPUs.
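
A back-of-the-envelope calculation shows why. The sketch below assumes about 12.2 billion parameters, vLLM's default `gpu_memory_utilization` of 0.9 on a 24 GiB card, and a BF16 KV cache using our assumed Nemo shapes (40 layers, 8 key/value heads of dimension 128). It ignores activations, the CUDA context, and CUDA graphs, so real headroom will be smaller.

```python
GIB = 1024 ** 3

N_PARAMS = 12.2e9          # approximate parameter count of Mistral-Nemo 12B (assumption)
GPU_GIB = 24               # e.g. RTX 3090/4090 or L4
GPU_UTILIZATION = 0.9      # vLLM's default gpu_memory_utilization

# KV cache per token: keys and values, for every layer and KV head, 2 bytes each (assumed shapes)
N_LAYERS, N_KV_HEADS, HEAD_DIM, KV_BYTES = 40, 8, 128, 2
KV_BYTES_PER_TOKEN = 2 * N_LAYERS * N_KV_HEADS * HEAD_DIM * KV_BYTES


def weight_gib(bytes_per_param):
    """Memory taken by the weights alone, in GiB."""
    return N_PARAMS * bytes_per_param / GIB


def kv_cache_tokens(bytes_per_param):
    """Tokens of KV cache that fit in the budget after the weights, or 0 if the weights don't fit."""
    budget = GPU_GIB * GPU_UTILIZATION * GIB
    free = budget - N_PARAMS * bytes_per_param
    return max(0, int(free // KV_BYTES_PER_TOKEN))


for name, nbytes in [("BF16", 2), ("FP8", 1)]:
    print(f"{name}: weights ~{weight_gib(nbytes):.1f} GiB, "
          f"room for ~{kv_cache_tokens(nbytes):,} KV-cache tokens")
```

Under these assumptions, the BF16 weights alone exceed the default memory budget, while the FP8 weights leave room for tens of thousands of tokens of KV cache.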

```
<ADD CODE FOR QUANTIZATION>
```
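
As one possible approach to pre-quantization, here is a sketch using llm-compressor, the library the vLLM documentation points to for producing FP8 checkpoints; it is a swapped-in tool, not necessarily what the authors used. It applies the `FP8_DYNAMIC` scheme (FP8 weights with dynamic per-token activation scales, so no calibration data is needed, as we understand it) to all linear layers except `lm_head`. `quantize_fp8` and the paths are hypothetical, and because the import location of `oneshot` has moved between llm-compressor releases, the sketch tries both.

```python
def quantize_fp8(merged_dir, output_dir):
    """Quantize a merged BF16 checkpoint to FP8 (weights and dynamic activations) for vLLM."""
    # Heavy imports live inside the function so the module can be imported without them
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from llmcompressor.modifiers.quantization import QuantizationModifier
    try:
        from llmcompressor import oneshot  # newer llm-compressor releases
    except ImportError:
        from llmcompressor.transformers import oneshot  # older releases

    # device_map="auto" lets accelerate offload to CPU if the BF16 model exceeds GPU memory
    model = AutoModelForCausalLM.from_pretrained(merged_dir, torch_dtype="auto", device_map="auto")
    tokenizer = AutoTokenizer.from_pretrained(merged_dir)

    # FP8_DYNAMIC needs no calibration dataset; keep the output head in higher precision
    recipe = QuantizationModifier(targets="Linear", scheme="FP8_DYNAMIC", ignore=["lm_head"])
    oneshot(model=model, recipe=recipe)

    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    return output_dir


# Example (hypothetical paths):
# quantize_fp8("mistral-nemo-12b-merged", "mistral-nemo-12b-fp8")
```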

### vLLM OpenAI compatible API

vLLM is a fast and user-friendly library for LLM inference and model serving (check out the [documentation here](https://docs.vllm.ai/en/latest/)). Deploying a model with vLLM is likely the simplest part of the process, offering two main options for serving the model:

1. Serving directly from a local folder
2. Serving from Hugging Face (in this case, the model is downloaded from Hugging Face but still served locally)

Both methods allow for quick and efficient deployment with minimal setup.

```
<ADD CODE TO SERVE MODEL>
```
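
As a sketch of the serving side: vLLM's CLI can serve either a local folder or a Hugging Face repo ID with `vllm serve <path-or-repo-id>`. Mistral-Nemo's configuration allows a much longer context than a 24 GB card can hold in its KV cache, so you will likely need to cap it with a flag such as `--max-model-len 8192`. By default the server listens on `http://localhost:8000/v1`. The Python snippet below queries it with the official `openai` package; the model name must match what you passed to `vllm serve`, and `build_messages` and `ask` are hypothetical helpers. `build_messages` folds the system prompt into the user turn, mirroring the `stitch_roles` preprocessing used during training.

```python
def build_messages(user_prompt, system_prompt=None):
    """Mirror the training-time preprocessing: prepend the system prompt to the user turn."""
    content = f"{system_prompt}\n\n{user_prompt}" if system_prompt else user_prompt
    return [{"role": "user", "content": content}]


def ask(user_prompt, model, system_prompt=None, base_url="http://localhost:8000/v1", api_key="EMPTY"):
    """Send one chat request to a vLLM OpenAI-compatible server and return the reply text."""
    from openai import OpenAI  # imported lazily so build_messages works without the package

    # vLLM accepts any API key unless the server was started with --api-key
    client = OpenAI(base_url=base_url, api_key=api_key)
    response = client.chat.completions.create(
        model=model,
        messages=build_messages(user_prompt, system_prompt),
        temperature=0.0,
        max_tokens=512,
    )
    return response.choices[0].message.content


# Example (hypothetical model path, as passed to `vllm serve`):
# print(ask("Acme Corp. opened a new office in Berlin.", model="mistral-nemo-12b-fp8",
#           system_prompt="Extract the company name."))
```

Because the server speaks the OpenAI protocol, existing code that uses the `openai` package only needs its `base_url` and model name changed.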

### Conclusion

And that's it! You've successfully fine-tuned, quantized, and deployed Mistral-Nemo. For a quick and cost-effective solution to fine-tune and deploy models, consider using platforms like [Runpod](https://runpod.io?ref=qzfldfvc). It simplifies the process and helps you scale efficiently.

