# Train a BERT classifier with synthetic data

Do we still need specifically trained classifiers in the age of zero-shot models and LLMs? The answer is: it depends. As of this writing (09/2024), even small LLMs with around 2 billion parameters run quite slowly (if at all) on limited hardware without fast GPUs. Scalability and cost of operation are therefore one reason to use a specialized classifier. Another big factor is the complexity of the problem: with a large number of classes that are close to each other, and with no additional data available, our experience is that even base-size BERT models perform better after fine-tuning than state-of-the-art (SOTA) LLMs used zero-shot.

Currently, there are very few benchmarks that compare "classic" language models like BERT with LLMs. One such experiment was published by Huggingface (<https://huggingface.co/blog/Lora-for-sequence-classification-with-Roberta-Llama-Mistral>); it shows that the RoBERTa-based model outperformed the LLM variants by a large margin in terms of speed and cost.

{% hint style="info" %}
To have a BERT-like classifier perform well, its base model must have been pre-trained on data that contains the "knowledge" and "semantics" needed to understand the text. The base models released by Google, Meta and Microsoft are very good at general English. However, in our experiments with a complex set of labels, their performance was always mediocre. There are fully pre-trained German models, such as deepset/gbert-base, but these are not trained on specialized domain knowledge either; whether that matters depends on the use case.

We achieved the best results for very domain-specific classifications by using domain adaptation techniques such as continued pre-training on texts that contain the domain knowledge. Be aware that this requires a lot of data and compute resources.
{% endhint %}
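Continued pre-training of a BERT-style model typically means running further masked language modelling (MLM) on in-domain text. As a rough illustration, and not necessarily the setup behind our results, the sketch below implements the masking scheme from the original BERT paper: about 15% of the tokens are selected as prediction targets, of which 80% are replaced by the mask token, 10% by a random token and 10% are left unchanged. Huggingface's `DataCollatorForLanguageModeling` applies the same scheme by default (`mlm_probability=0.15`) and additionally skips special tokens, which this sketch ignores. The function name and the token ids, `mask_id` and `vocab_size` in the usage example are made-up values.

```python
import random

IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default


def mask_tokens(token_ids, mask_id, vocab_size, mlm_probability=0.15, rng=None):
    """Return (inputs, labels) for one MLM training example using 80/10/10 masking."""
    rng = rng or random.Random()
    inputs, labels = [], []
    for tok in token_ids:
        if rng.random() < mlm_probability:
            labels.append(tok)  # the model has to reconstruct the original token here
            roll = rng.random()
            if roll < 0.8:
                inputs.append(mask_id)  # 80%: replace with the mask token
            elif roll < 0.9:
                inputs.append(rng.randrange(vocab_size))  # 10%: random token
            else:
                inputs.append(tok)  # 10%: keep the original token
        else:
            labels.append(IGNORE_INDEX)  # not a prediction target
            inputs.append(tok)
    return inputs, labels


# Usage with made-up token ids and a fixed seed for reproducibility
inputs, labels = mask_tokens(list(range(1000, 1200)), mask_id=103, vocab_size=30522,
                             rng=random.Random(0))
```

Only the selected positions carry a real label, so the loss is computed on roughly 15% of the tokens while the model still sees the full (partially corrupted) context.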

### Coverage

In this tutorial, we cover how to structure a training dataset, augment it with synthetic data and finally train a BERT-based classification model. To fine-tune the model, we use the Huggingface library and ecosystem.

* Fine-tune a BERT-based classifier model
* Structure a training dataset
* Create synthetic samples to augment the training data

Finally, we evaluate the fine-tuned model and explain the metrics used to do so.

### Requirements

For this tutorial, you'll need access to a GPU. The amount of VRAM available is a less important factor than for LLMs since BERT-based models are usually relatively small. However, more VRAM allows for larger batch sizes.

On the software side, make sure to have all dependencies listed in the `requirements.txt` file installed before proceeding.


### Data

The structure of the dataset is simple: the model takes a text as input and produces a label as output, and the training dataset mirrors exactly that. We therefore use a CSV file with just two columns: `text` and `label`.
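A minimal sketch of writing and reading such a file with Python's standard `csv` module. The helper names `write_dataset` and `read_dataset` and the sample row are our own; pandas or the Huggingface `datasets` CSV loader work just as well. Quoting matters here, because answers often contain commas and line breaks, which the `csv` module escapes for us.

```python
import csv
import io


def write_dataset(rows, fh):
    """Write (text, label) pairs as a CSV with the header `text,label`."""
    writer = csv.writer(fh)
    writer.writerow(["text", "label"])
    writer.writerows(rows)  # fields containing commas or newlines are quoted automatically


def read_dataset(fh):
    """Read the CSV back as a list of {'text': ..., 'label': ...} dicts."""
    reader = csv.DictReader(fh)
    if reader.fieldnames != ["text", "label"]:
        raise ValueError(f"expected columns text,label, got {reader.fieldnames}")
    return list(reader)


# Round trip through an in-memory buffer; for real files use open(path, "w", newline="")
buffer = io.StringIO(newline="")
write_dataset([("What is BERT?\nA language model, released by Google.", "gpt-4o")], buffer)
buffer.seek(0)
rows = read_dataset(buffer)
```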

We use question-answer pairs as input and the model that generated the answer as the label, so our classifier should learn to detect which model answered the question. For simplicity, we use gpt-4o, gpt-4o-mini and gpt-3.5.

<table><thead><tr><th width="586">text</th><th>label</th></tr></thead><tbody><tr><td>&#x3C;div class="tzej">We repair your phone within 3 work days&#x3C;/div></td><td>gpt-4o</td></tr><tr><td>&#x3C;div class="footer">&#x3C;a href="/contact">Contact us&#x3C;/a>&#x3C;/div></td><td>gpt-4o-mini</td></tr><tr><td>hwejrh wuerzweuir</td><td>gpt-3.5</td></tr></tbody></table>

We can easily generate the data by passing a question to the corresponding model API.

```
<script to train the model>
```
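Independently of the script above, the data-generation step could look roughly like the following sketch, which uses the official `openai` Python package (v1 client). The mapping of our labels to API model ids, the prompt and the `Question: ... / Answer: ...` text format are assumptions for illustration; `gpt-3.5-turbo` is the API name behind our `gpt-3.5` label. The `ask` parameter lets you swap in another client, or a stub for testing. Asking every model the same questions keeps the classes balanced and should push the classifier towards differences in style rather than topic.

```python
# Hypothetical mapping from dataset label to API model id
MODEL_IDS = {
    "gpt-4o": "gpt-4o",
    "gpt-4o-mini": "gpt-4o-mini",
    "gpt-3.5": "gpt-3.5-turbo",
}


def ask_openai(model_id, question):
    """Query the OpenAI Chat Completions API (needs `openai` and OPENAI_API_KEY)."""
    from openai import OpenAI  # imported lazily so the rest runs without the package

    client = OpenAI()
    response = client.chat.completions.create(
        model=model_id,
        messages=[{"role": "user", "content": question}],
    )
    return response.choices[0].message.content


def generate_rows(questions, ask=ask_openai):
    """Answer every question with every model and yield (text, label) rows."""
    for question in questions:
        for label, model_id in MODEL_IDS.items():
            answer = ask(model_id, question)
            text = f"Question: {question}\nAnswer: {answer}"  # assumed input format
            yield text, label
```

For a dry run without API calls, pass a stub such as `ask=lambda model_id, question: "..."`.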

Note that the semantics of the label names are usually not leveraged, because the labels are not embedded: the model only sees them as class indices. It therefore does not matter whether a label is called "product\_description" or "asduhu7u".
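Before training, the label strings are mapped to integer ids, and only those ids reach the model. A minimal sketch; the helper name `build_label_maps` is ours. The resulting dictionaries can be passed to `AutoModelForSequenceClassification.from_pretrained(..., num_labels=len(label2id), id2label=id2label, label2id=label2id)` so that the saved model config keeps the readable names.

```python
def build_label_maps(labels):
    """Map each distinct label string to an integer id (and back)."""
    names = sorted(set(labels))  # sorting keeps the ids reproducible across runs
    label2id = {name: idx for idx, name in enumerate(names)}
    id2label = {idx: name for name, idx in label2id.items()}
    return label2id, id2label


labels = ["gpt-4o", "gpt-4o-mini", "gpt-3.5", "gpt-4o"]
label2id, id2label = build_label_maps(labels)
encoded = [label2id[name] for name in labels]  # what the model actually trains on
```

Renaming the labels to arbitrary strings that sort the same way (say `"b"`, `"c"`, `"a"`) yields exactly the same integer targets, which is why the names themselves carry no information for the model.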

### Model Fine-Tuning

Once the data is in a shape we are satisfied with, we can start the actual fine-tuning. This process is straightforward thanks to the abstractions the Huggingface library offers, though there are a few points we need to adjust. To evaluate the training run, we should track both the training and the evaluation loss to see whether, and how, the model is still learning. It is also important to see how the model actually performs in terms of precision, recall and F1 score.
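Reading the two loss curves side by side is the usual way to spot overfitting: the training loss keeps falling while the evaluation loss starts to rise. A small sketch of that check on per-epoch losses; the function names, the `patience` rule and the numbers are illustrative, not results from our runs. The Huggingface library also ships an `EarlyStoppingCallback` that stops training once the monitored metric has not improved for a given number of evaluations.

```python
def best_epoch(eval_losses):
    """1-based epoch with the lowest evaluation loss."""
    return min(range(len(eval_losses)), key=eval_losses.__getitem__) + 1


def overfitting_detected(train_losses, eval_losses, patience=2):
    """True once eval loss has risen for `patience` consecutive epochs while train loss fell."""
    streak = 0
    for i in range(1, len(eval_losses)):
        diverging = eval_losses[i] > eval_losses[i - 1] and train_losses[i] < train_losses[i - 1]
        streak = streak + 1 if diverging else 0
        if streak >= patience:
            return True
    return False


# Made-up per-epoch losses for illustration
train = [1.00, 0.70, 0.50, 0.35, 0.25]
evals = [0.90, 0.70, 0.65, 0.70, 0.80]
print(best_epoch(evals), overfitting_detected(train, evals))
```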

{% hint style="info" %}
**Understanding Precision, Recall, and F1 Score**

Imagine you're sorting apples from a mixed basket of fruits, and you want only the apples.

* **Precision** is about accuracy in your picks. It answers the question: *Out of all the fruits you picked as apples, how many really are apples?*
  * **Example:** You picked 5 fruits and said they are apples. If 4 are actually apples and 1 is a pear, your precision is 4 out of 5.
* **Recall** is about completeness. It answers: *Out of all the actual apples in the basket, how many did you successfully pick?*
  * **Example:** There are 10 apples in total. If you picked 4 of them, your recall is 4 out of 10.
* **F1 Score** is the harmonic mean of precision and recall, `2 × precision × recall / (precision + recall)`. It combines both into one number that shows the balance between them, which is useful when you want a single measure of overall performance.

So, if you're really good at picking apples accurately (high precision) but miss many apples in the basket (low recall), the F1 score helps you see that balance.
{% endhint %}
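The apple example as a quick calculation. Because F1 is a harmonic mean, it stays closer to the weaker of the two scores than the arithmetic mean would; the helper name is ours.

```python
def precision_recall_f1(true_positives, predicted_positives, actual_positives):
    """Precision, recall and F1 from raw counts."""
    precision = true_positives / predicted_positives  # correct picks / all picks
    recall = true_positives / actual_positives  # correct picks / all real apples
    f1 = 2 * precision * recall / (precision + recall)  # harmonic mean
    return precision, recall, f1


# 5 fruits picked, 4 of them apples; 10 apples in the basket
precision, recall, f1 = precision_recall_f1(true_positives=4, predicted_positives=5, actual_positives=10)
print(f"precision={precision:.2f} recall={recall:.2f} f1={f1:.3f}")  # precision=0.80 recall=0.40 f1=0.533
```

The arithmetic mean of 0.8 and 0.4 would be 0.6; F1 is pulled down towards the low recall.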

To add these additional scores to the evaluation steps of our model training, we implement a short function that computes them for us. Below you'll find the full training script, including the evaluation function.

```
script
```
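For orientation, this is one way such an evaluation function can look. The Huggingface `Trainer` calls `compute_metrics` with the evaluation predictions, which unpack into `(logits, labels)`, and logs every returned key with an `eval_` prefix. This sketch uses NumPy only and macro-averages the scores, so every class counts equally regardless of its size; it is an illustration and not necessarily how the script above averages. scikit-learn's `precision_recall_fscore_support(labels, preds, average="macro", zero_division=0)` should give the same precision, recall and F1.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy plus macro-averaged precision, recall and F1 for the Trainer."""
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)  # predicted class id per sample
    classes = np.union1d(labels, preds)  # classes seen in either labels or predictions
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = np.sum((preds == c) & (labels == c))
        fp = np.sum((preds == c) & (labels != c))
        fn = np.sum((preds != c) & (labels == c))
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(2 * p * r / (p + r) if p + r else 0.0)
    return {
        "accuracy": float(np.mean(preds == labels)),
        "precision": float(np.mean(precisions)),
        "recall": float(np.mean(recalls)),
        "f1": float(np.mean(f1s)),
    }
```

Pass it to the trainer as `Trainer(..., compute_metrics=compute_metrics)`.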

#### Hyperparameters

Let's break down each of the training settings to understand what they do and how changing them might affect the training of a machine learning model.
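For reference, here are the values discussed below gathered in one place. We keep them in a plain dict so the snippet runs without the library installed; with transformers installed you would pass them as `TrainingArguments(**training_config)`. Depending on your transformers version, `evaluation_strategy` may need to be written as `eval_strategy`. The small check function is our own and only mirrors one validation the library performs itself: `load_best_model_at_end=True` is rejected unless evaluation and checkpoint saving follow the same strategy, which is why both are set to `"epoch"`.

```python
# Values from this section as keyword arguments; with transformers installed:
#     from transformers import TrainingArguments
#     args = TrainingArguments(**training_config)
training_config = dict(
    output_dir="./results",
    learning_rate=5e-5,
    lr_scheduler_type="cosine_with_restarts",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=500,
    save_total_limit=5,
    evaluation_strategy="epoch",  # spelled eval_strategy in newer transformers releases
    save_strategy="epoch",
    label_smoothing_factor=0.05,
    load_best_model_at_end=True,
)


def check_best_model_settings(cfg):
    """Our own check mirroring one TrainingArguments rule: load_best_model_at_end
    needs evaluation and saving to follow the same strategy."""
    eval_strategy = cfg.get("eval_strategy", cfg.get("evaluation_strategy", "no"))
    save_strategy = cfg.get("save_strategy", "steps")  # library default
    if cfg.get("load_best_model_at_end") and eval_strategy != save_strategy:
        raise ValueError("load_best_model_at_end requires matching eval and save strategies")


check_best_model_settings(training_config)  # passes: both strategies are "epoch"
```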

***

**`output_dir='./results'`**

* **Purpose:** This specifies where the model's outputs—like checkpoints and logs—will be saved.
* **Effect of Changing:** If you set this to a different directory, all the saved models and logs will be stored there instead.

***

**`learning_rate=5e-5`**

* **Purpose:** The learning rate determines how quickly the model updates its parameters in response to the calculated error each time it learns.
* **Effects of Changing:**
  * **Increasing the Learning Rate:** The model learns faster but risks overshooting the optimal solution, which can prevent it from converging properly.
  * **Decreasing the Learning Rate:** The model learns more slowly and carefully, which can lead to better convergence but might get stuck in local minima (suboptimal solutions) and take a very long time to train.
* **Important to Know:** A learning rate that's too low can make the model get stuck in a local minimum, where it thinks it has found the best solution but hasn't explored enough to find the true best.
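A toy illustration of overshooting versus slow learning, using plain gradient descent on `f(x) = x**2`, whose minimum is at 0. This function has no local minima, so it only shows the step-size trade-off, not the local-minimum issue; real fine-tuning uses AdamW on a far more complex loss surface. The function name is ours.

```python
def gradient_descent(learning_rate, steps=50, x0=1.0):
    """Minimise f(x) = x**2 (gradient 2x) and return the final x; 0 is optimal."""
    x = x0
    for _ in range(steps):
        x -= learning_rate * 2 * x  # step against the gradient
    return x


for lr in (0.001, 0.1, 1.1):
    print(f"lr={lr}: x after 50 steps = {gradient_descent(lr):.3g}")
```

With 0.001 the value barely moves (about 0.9 after 50 steps), 0.1 converges close to 0, and 1.1 jumps past the minimum on every step with growing amplitude, so it diverges.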

***

**`lr_scheduler_type='cosine_with_restarts'`**

* **Purpose:** The learning rate scheduler adjusts the learning rate during training according to a specific schedule.
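* **Cosine with Restarts:** This scheduler decreases the learning rate following a cosine curve and can restart it periodically. Periodically increasing the learning rate again can help the model escape local minima. Note that in the Huggingface `Trainer` the number of cycles defaults to one, so the schedule only actually restarts if you request more cycles, e.g. via `lr_scheduler_kwargs={"num_cycles": 3}`.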
* **Effect of Changing:**
  * **Different Schedulers:** Using a different scheduler (like 'constant') will adjust the learning rate differently, which can impact how well and how quickly the model learns.
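To see the shape of this schedule, the sketch below computes the learning-rate multiplier per optimizer step. It is modelled on the formula behind transformers' `get_cosine_with_hard_restarts_schedule_with_warmup`; the function itself is ours. The actual learning rate is `learning_rate` times this multiplier.

```python
import math


def cosine_with_hard_restarts(step, total_steps, warmup_steps=0, num_cycles=1):
    """Learning-rate multiplier in [0, 1] at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warm-up
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    if progress >= 1.0:
        return 0.0
    # Position inside the current cycle; it wraps back to 0 at each restart.
    cycle_position = (num_cycles * progress) % 1.0
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycle_position)))


for cycles in (1, 2):
    print(cycles, [round(cosine_with_hard_restarts(s, 100, num_cycles=cycles), 2) for s in range(0, 100, 25)])
```

With two cycles the multiplier jumps back to 1.0 halfway through training; with a single cycle it simply decays to zero once.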

***

**`per_device_train_batch_size=64`**

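* **Purpose:** This is the number of training samples processed on each device (GPU) in one training step. Without multiple GPUs or gradient accumulation, it is also the number of samples the model processes before each parameter update.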
* **Effects of Changing:**
  * **Increasing Batch Size:** Can speed up training because it processes more data at once, but requires more memory. It also makes the gradient estimates more stable.
  * **Decreasing Batch Size:** Uses less memory and can make the model more responsive to the data, but the training might be noisier and less stable.

***

**`per_device_eval_batch_size=64`**

* **Purpose:** Similar to the training batch size, but used during evaluation to assess model performance.
* **Effect of Changing:** Adjusting this can speed up or slow down evaluation, affecting how quickly you can get feedback on the model's performance.

***

**`num_train_epochs=10`**

* **Purpose:** Specifies how many times the model will go through the entire training dataset.
* **Effects of Changing:**
  * **Increasing Epochs:** Gives the model more opportunities to learn from the data, which can improve performance but may lead to overfitting, where the model learns the training data too well and doesn't perform well on new data.
  * **Decreasing Epochs:** May prevent the model from fully learning the patterns in the data, leading to underfitting.

***

**`weight_decay=0.01`**

* **Purpose:** Weight decay is a regularization technique that penalizes large weights in the model to prevent overfitting.
* **Effects of Changing:**
  * **Increasing Weight Decay:** Encourages the model to keep weights small, which can improve generalization to new data.
  * **Decreasing Weight Decay:** Allows the model to have larger weights, which might capture complex patterns but can increase the risk of overfitting.
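To get a feeling for the size of this effect: the `Trainer` uses AdamW by default, which applies weight decay in decoupled form, multiplying each decayed parameter by `1 - learning_rate * weight_decay` at every optimizer step in addition to the gradient-based update. The sketch below isolates that shrinkage and assumes a constant learning rate; the step count of 1570 is a hypothetical example, and with a decaying schedule the real effect is even smaller.

```python
def remaining_weight_fraction(learning_rate, weight_decay, steps):
    """Fraction of a weight left after `steps` decoupled decay steps, ignoring gradient updates."""
    return (1.0 - learning_rate * weight_decay) ** steps


# 5e-5 and 0.01 are the values used in this tutorial; 1570 steps is hypothetical
print(remaining_weight_fraction(5e-5, 0.01, 1570))
print(remaining_weight_fraction(5e-5, 0.1, 1570))
```

Both values stay above 0.99, so at these settings weight decay acts as a gentle, steady pull towards zero rather than a hard constraint.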

***

**`logging_dir='./logs'`**

* **Purpose:** Specifies where to save logs that record details about the training process.
* **Effect of Changing:** Directs logs to a different directory, which can help organize different training runs.

***

**`logging_steps=500`**

* **Purpose:** Determines how often (in steps) the training process logs information.
* **Effects of Changing:**
  * **Decreasing Logging Steps:** Logs information more frequently, providing more detailed insight but may slow down training due to the overhead.
  * **Increasing Logging Steps:** Logs less frequently, reducing overhead but providing less detailed information.
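Whether `logging_steps=500` produces a useful training-loss curve depends on how many optimizer steps a run has in total. A rough back-of-the-envelope sketch; the dataset sizes are hypothetical, and the Trainer's exact count can differ slightly depending on how it handles the last incomplete batch.

```python
import math


def total_optimizer_steps(num_samples, per_device_batch_size, num_epochs,
                          num_devices=1, grad_accum_steps=1):
    """Approximate number of parameter updates in a training run."""
    effective_batch = per_device_batch_size * num_devices * grad_accum_steps
    steps_per_epoch = math.ceil(num_samples / effective_batch)
    return steps_per_epoch * num_epochs


for num_samples in (1_000, 10_000):
    steps = total_optimizer_steps(num_samples, per_device_batch_size=64, num_epochs=10)
    print(f"{num_samples} samples -> {steps} steps -> {steps // 500} training-loss logs")
```

With 1,000 samples, a batch size of 64 and 10 epochs, the run has only 160 steps, so a logging interval of 500 never logs the training loss during training; lowering `logging_steps` or setting `logging_strategy="epoch"` avoids that.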

***

**`save_total_limit=5`**

* **Purpose:** Limits the number of saved model checkpoints to prevent using too much storage.
* **Effects of Changing:**
  * **Increasing the Limit:** Keeps more checkpoints, which can be useful for revisiting older model states but uses more storage.
  * **Decreasing the Limit:** Saves storage space but reduces the number of checkpoints you can revert to.

***

**`evaluation_strategy="epoch"`**

* **Purpose:** Determines how often the model evaluates its performance on a validation set. Newer versions of the transformers library call this argument `eval_strategy`.
* **Effects of Changing:**
  * **Setting to "steps":** Evaluates the model every few steps, providing more frequent feedback but increasing computation time.
  * **Setting to "no":** Disables evaluation during training, which might speed up training but leaves you without validation metrics until the end.

***

**`save_strategy="epoch"`**

* **Purpose:** Dictates how often the model saves its state (checkpoints).
* **Effects of Changing:**
  * **Saving More Frequently:** Provides more recovery points in case of interruption but can slow down training and use more storage.
  * **Saving Less Frequently:** Speeds up training and saves storage but risks losing more progress if training is interrupted.

***

**`label_smoothing_factor=0.05`**

* **Purpose:** Label smoothing softens the target labels, which can prevent the model from becoming too confident in its predictions.
* **Effects of Changing:**
  * **Increasing the Factor:** Makes the model less confident, which can improve generalization but might make learning slower.
  * **Decreasing the Factor:** Allows the model to be more confident, which can speed up learning but might increase overfitting.
* **Important to Know:** A high label smoothing factor can cause the model to underperform because it can't rely on clear signals from the labels.
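What label smoothing does to the training target, sketched in plain Python. With factor ε and K classes, the target puts `1 - ε + ε/K` on the correct class and `ε/K` on every other class; to our understanding this matches the formula used by the `Trainer`'s label smoother. The function names and logits are made up for illustration.

```python
import math


def smoothed_targets(label, num_classes, epsilon):
    """Target distribution: (1 - eps) on the true class plus eps spread evenly over all classes."""
    share = epsilon / num_classes
    return [share + (1.0 - epsilon if c == label else 0.0) for c in range(num_classes)]


def cross_entropy(logits, targets):
    """Cross-entropy between softmax(logits) and a target distribution."""
    top = max(logits)  # subtract the max for numerical stability
    log_norm = top + math.log(sum(math.exp(z - top) for z in logits))
    return -sum(t * (z - log_norm) for z, t in zip(logits, targets))


one_hot = smoothed_targets(0, 3, 0.0)    # [1.0, 0.0, 0.0]
smoothed = smoothed_targets(0, 3, 0.05)  # about [0.967, 0.017, 0.017]
confident, moderate = [10.0, 0.0, 0.0], [3.0, 0.0, 0.0]
for name, targets in (("one-hot", one_hot), ("smoothed", smoothed)):
    print(name, round(cross_entropy(confident, targets), 3), round(cross_entropy(moderate, targets), 3))
```

Without smoothing, the very confident logits get the lower loss; with smoothing, the moderate ones do, so the model is no longer rewarded for pushing its logits ever further apart.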

***

**`load_best_model_at_end=True`**

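* **Purpose:** After training, this setting ensures that the model loads the checkpoint that performed best on the validation set. By default, "best" means the lowest evaluation loss; set `metric_for_best_model` (e.g. to `"f1"`) to select by a different metric. This setting also requires the evaluation and save strategies to match.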
* **Effect of Changing:** Setting this to `False` means the model will keep its final state at the end of training, which might not be the best performing version.

***

**Key Considerations**

* **Learning Rate and Local Minima:** A small learning rate might cause the model to get stuck in a local minimum, where it settles on a solution that isn't the best overall. Using learning rate schedulers or adjusting the learning rate can help the model explore the solution space more effectively.
* **Batch Size and Model Stability:** Larger batch sizes can make the training process more stable because they provide a better estimate of the gradient, but they require more computational resources.
* **Overfitting vs. Underfitting:**
  * **Overfitting:** When the model learns the training data too well, including its noise and outliers, and performs poorly on new data.
  * **Underfitting:** When the model hasn't learned enough from the training data and performs poorly both on training and new data.
  * **Regularization Techniques:** Settings like `weight_decay` and `label_smoothing_factor` help prevent overfitting by encouraging the model to keep weights small and not be overly confident.
* **Evaluation and Saving Strategies:** Regular evaluation and saving can help monitor the model's progress and provide recovery points. However, they add computational overhead, so it's a balance between monitoring and efficiency.
* **Adjusting Hyperparameters:** Finding the right values for these settings often requires experimentation. What works best can depend on the specific dataset and the problem you're trying to solve.

### Using the Model

Once we are happy with the model's performance, we can finally use it for predictions. Using the model is straightforward: we basically only need the model and the label mapping.

```
<prediction function>
```
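The model-independent part of prediction is small: turn the logits for one text into probabilities and look up the label name. A sketch assuming you already have the logits as a Python list, for example via `model(**tokenizer(text, return_tensors="pt")).logits[0].tolist()` with the fine-tuned model and its tokenizer, and an `id2label` dict like the one stored in the model config. The function name, mapping and logits below are illustrative. Huggingface's `pipeline("text-classification", model=...)` wraps all of these steps if you prefer not to do them by hand.

```python
import math


def predict_label(logits, id2label):
    """Map one row of classifier logits to (label, probability) via softmax and argmax."""
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]


id2label = {0: "gpt-3.5", 1: "gpt-4o", 2: "gpt-4o-mini"}  # example mapping
label, confidence = predict_label([0.1, 2.0, 0.3], id2label)  # made-up logits
print(label, round(confidence, 2))
```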

That's basically it: you can now deploy the model anywhere you want!


