Fine-Tuning BERT for Text Classification: A Step-by-Step Guide with Code Examples

Nov 23, 2024

In our last blog, we explored how to choose the right transformer model, highlighting BERT’s strengths in classification tasks. Now, we dive deeper into fine-tuning BERT with real-world implementations and hands-on code.

Introduction

Text classification is a cornerstone of natural language processing (NLP), enabling tasks such as sentiment analysis, spam detection, and topic categorization. At the forefront of NLP advancements is BERT (Bidirectional Encoder Representations from Transformers), a pre-trained transformer model renowned for its ability to understand context in text.

Fine-tuning BERT for classification leverages this contextual understanding and can deliver strong performance even on smaller datasets, provided they are clean and well prepared. This blog walks through the process step by step, with hands-on code and practical insights from a real-world example.

Data Preparation

Before fine-tuning BERT, it’s essential to prepare clean, balanced, and well-structured data to ensure the model learns meaningful patterns and generalizes effectively. In this tutorial, we’ll use a real-world example of resume text chunks, each representing different sections such as Contact Information, Education, Work Experience, and Skills.

To achieve this, ensure the data is free from irrelevant text or missing rows, balanced across classes, and diverse enough to improve model generalization. This will optimize the model’s performance during training and inference.
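The checks described above can be sketched with pandas. This is a minimal illustration, assuming a DataFrame with the text and section columns used later in this tutorial. The inline rows are made-up placeholders rather than the actual resume dataset, and the deduplication step is an extra assumption, not something the original pipeline is known to do.

```python
import pandas as pd

# Hypothetical rows standing in for the real resume chunks.
data = pd.DataFrame(
    {
        "text": [
            "B.Sc. in Computer Science, 2019",
            "Python, PyTorch, SQL",
            "  ",
            None,
            "Python, PyTorch, SQL",
            "Software Engineer at a fintech startup, 2020-2023",
        ],
        "section": [
            "Education",
            "Skills",
            "Skills",
            "Education",
            "Skills",
            "Work Experience",
        ],
    }
)

# Drop rows with missing text or labels, then rows whose text is only whitespace.
clean = data.dropna(subset=["text", "section"])
clean = clean[clean["text"].str.strip() != ""]

# Remove exact duplicates so repeated chunks do not dominate a class.
clean = clean.drop_duplicates(subset=["text", "section"]).reset_index(drop=True)

# Share of each class; a heavily skewed distribution suggests rebalancing.
class_share = clean["section"].value_counts(normalize=True)
print(class_share)
```

If one class accounts for most of the rows, collecting more examples of the rare classes or downsampling the dominant one are common ways to restore balance.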

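Sample Dataset

The dataset is a CSV file with two columns: text, which holds a resume chunk, and section, which holds its label (for example, Education or Skills).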

Implementation: Fine-Tuning the BERT Model

Importing Necessary Libraries

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from tqdm import tqdm

This sets up our environment for the next steps.

Prepare Data For Training

Data preparation is critical for training an effective classification model. Here’s what we focus on:

• Cleaning: Removing irrelevant or incomplete data ensures we work with a high-quality dataset.

• Label Encoding: Converting categorical labels (like “Education” or “Skills”) into numeric labels makes the data compatible with the model.

• Train-Validation Split: Holding out a validation set lets us measure how well the model generalizes to data it has not seen during training.

For example, with a dataset of resume chunks labeled by section:

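# Load the dataset and drop rows with missing text or labels
data = pd.read_csv("resume_data.csv")
data = data.dropna(subset=["text", "section"]).reset_index(drop=True)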

# Encode labels into numeric format
label_encoder = LabelEncoder()
data["section_encoded"] = label_encoder.fit_transform(data["section"])

# Split into training and validation sets
train_texts, val_texts, train_labels, val_labels = train_test_split(
    data["text"].values,
    data["section_encoded"].values,
    test_size=0.2,
    random_state=42,
)

This ensures the data is structured, labeled, and ready for tokenization.
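One possible refinement, not part of the split above, is to pass the stratify argument so that both sets keep the same class proportions, which matters when some sections are rare. The sketch below uses made-up labels and also shows how LabelEncoder maps section names to integers and back.

```python
from collections import Counter

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder

# Hypothetical, deliberately imbalanced labels: 10 Education, 30 Skills.
sections = ["Education"] * 10 + ["Skills"] * 30
texts = [f"chunk {i}" for i in range(len(sections))]

# LabelEncoder sorts the class names and assigns 0, 1, 2, ... in that order.
label_encoder = LabelEncoder()
encoded = label_encoder.fit_transform(sections)
print(dict(zip(label_encoder.classes_, range(len(label_encoder.classes_)))))

# stratify keeps the 1:3 class ratio in both the training and validation sets.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, encoded, test_size=0.2, random_state=42, stratify=encoded
)
print(Counter(val_labels.tolist()))

# inverse_transform recovers the original section names from the integers.
print(label_encoder.inverse_transform([0, 1]))
```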

Tokenizer and Dataset Class

After data preparation, we need to tokenize the text and format it for the BERT model:

Tokenizer

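The BertTokenizer splits text into WordPiece subword tokens that BERT can process; the uncased variant also lowercases the text first. It adds two special tokens:

• [CLS]: Marks the start of the input. Its final hidden state serves as the sequence representation used for classification.

• [SEP]: Marks the end of a segment and separates segments when the input contains more than one.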

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
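To make the subword idea concrete, here is a simplified sketch of the greedy longest-match-first strategy that WordPiece tokenization relies on. The tiny vocabulary is invented for illustration (the real bert-base-uncased vocabulary has roughly 30,000 entries), and the function omits the text normalization and punctuation splitting that BertTokenizer performs first.

```python
def wordpiece(word, vocab, unk_token="[UNK]"):
    """Split one lowercase word into subwords by greedy longest match."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first and shrink until it is in the vocab.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # "##" marks a continuation piece
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk_token]  # no piece fits, so the whole word is unknown
        pieces.append(match)
        start = end
    return pieces


# Toy vocabulary, assumed for this example only.
toy_vocab = {"token", "##ization", "##ize", "res", "##ume", "[UNK]"}

tokens = ["[CLS]"]
for w in "resume tokenization".split():
    tokens.extend(wordpiece(w, toy_vocab))
tokens.append("[SEP]")
print(tokens)
```

Splitting rare words into known pieces is what lets a fixed vocabulary cover text it has never seen as whole words.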


Dataset Class

The custom ClassificationDataset class takes raw text, tokenizes it, and prepares it for training. It also creates:

• Input IDs: Tokenized text converted into integers.

• Attention Masks: Flags to distinguish real tokens from padding.

• Labels: Encoded numeric labels for each input.

class ClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "label": torch.tensor(label, dtype=torch.long),
        }

# Define dataset
max_len = 128
train_dataset = ClassificationDataset(train_texts, train_labels, tokenizer, max_len)
val_dataset = ClassificationDataset(val_texts, val_labels, tokenizer, max_len)

# Create DataLoaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

This class ensures the data is in the correct format for BERT.
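What padding="max_length" and truncation=True do to a single example can be sketched without any library. The IDs 101, 102, and 0 are the [CLS], [SEP], and [PAD] IDs in the bert-base-uncased vocabulary; the other token IDs below are arbitrary placeholders, and the helper name is hypothetical.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token IDs in bert-base-uncased


def encode_fixed_length(token_ids, max_len):
    """Add special tokens, truncate, and pad to exactly max_len positions."""
    # Reserve two positions for [CLS] and [SEP], then cut anything that overflows.
    body = token_ids[: max_len - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # 1 marks a real token; 0 marks padding the model should ignore.
    attention_mask = [1] * len(input_ids)
    padding = max_len - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding
    return input_ids, attention_mask


# Placeholder IDs for a three-token chunk, padded to length 8.
ids, mask = encode_fixed_length([2000, 2001, 2002], max_len=8)
print(ids)   # [101, 2000, 2001, 2002, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Because every example ends up with the same length, the DataLoader can stack them into a single batch tensor. Chunks longer than max_len lose their tail, so it is worth checking how many examples in your data are affected.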

Model Setup

BERT is a pre-trained model that can be fine-tuned for specific tasks like classification:

1. Load Pre-trained BERT: We use bert-base-uncased, the base-size English BERT model that lowercases its input.

2. Specify the Number of Labels: In this case, the number of unique sections in the resume data.

3. Device Setup: Use a GPU if one is available for faster training; otherwise fall back to the CPU.

# Model and Device Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=len(label_encoder.classes_),
)
model.to(device)

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

The model is now ready for training.

Training Loop

The training loop is where the model learns patterns in the data:

1. Forward Pass: The input data is fed through the model to calculate predictions.

2. Loss Calculation: The loss measures how far the predictions are from actual labels.

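3. Backward Pass and Update: Backpropagation computes the gradient of the loss with respect to each weight, and the optimizer uses those gradients to update the weights and reduce the loss.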

# Training Loop
epochs = 3
for epoch in range(epochs):
    model.train()
    total_loss = 0
    loop = tqdm(train_loader, leave=True)

    for batch in loop:
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

        loop.set_description(f"Epoch {epoch + 1}/{epochs}")
        loop.set_postfix(loss=loss.item())

    print(f"Epoch {epoch + 1} average loss: {total_loss / len(train_loader):.4f}")

This trains the model to classify resume sections effectively.
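For single-label classification, the outputs.loss value returned when labels are passed to the model is a cross-entropy loss over the logits. The sketch below reproduces that calculation for one example in plain Python. The logits are invented numbers, and four classes are assumed to match the four resume sections named earlier.

```python
import math


def cross_entropy(logits, target):
    """Negative log-probability that softmax assigns to the target class."""
    # Subtract the max logit before exponentiating for numerical stability.
    shifted = [x - max(logits) for x in logits]
    log_sum_exp = math.log(sum(math.exp(x) for x in shifted))
    return log_sum_exp - shifted[target]


# Made-up logits for one chunk over four classes.
logits = [2.0, 0.5, -1.0, 0.1]

confident_loss = cross_entropy(logits, target=0)  # target has the highest logit
wrong_loss = cross_entropy(logits, target=2)      # target has the lowest logit
print(f"correct class favored: {confident_loss:.4f}")
print(f"correct class disfavored: {wrong_loss:.4f}")
```

The loss is small when the model already favors the true class and large when it does not. Averaging it over a batch gives the value the loop logs; loss.backward() then computes the gradients and optimizer.step() applies them.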

Evaluation

After training, the model is evaluated on the validation set. The code below computes accuracy; precision, recall, and F1 score can be derived from the same predictions and labels.

# Evaluation
model.eval()
correct = 0
total = 0

with torch.no_grad():
    for batch in val_loader:
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        predictions = torch.argmax(outputs.logits, dim=1)

        correct += (predictions == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
print(f"Validation Accuracy: {accuracy:.4f}")

The fine-tuned BERT model, trained on 1,839 labeled data points, achieved the following validation metrics:

• Accuracy: 0.8553

• F1 Score: 0.8572

• Precision: 0.8617

• Recall: 0.8553

These results suggest that the model generalizes reasonably well to unseen resume chunks, even with a modest dataset size.
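The evaluation loop above reports accuracy only. The reported recall equals the reported accuracy, which is what support-weighted averaging produces, so the sketch below assumes weighted averaging; that is an inference, not something the original code confirms. It is written in plain Python to show the arithmetic, and scikit-learn's precision_recall_fscore_support with average="weighted" computes the same quantities. The labels are toy values.

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Support-weighted precision, recall, and F1 over all classes in y_true."""
    support = Counter(y_true)
    predicted = Counter(y_pred)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    n = len(y_true)
    precision = recall = f1 = 0.0
    for cls, count in support.items():
        p = hits[cls] / predicted[cls] if predicted[cls] else 0.0
        r = hits[cls] / count
        f = 2 * p * r / (p + r) if (p + r) else 0.0
        weight = count / n  # each class contributes in proportion to its support
        precision += weight * p
        recall += weight * r
        f1 += weight * f
    return precision, recall, f1


# Toy labels: in the tutorial these would be collected from the validation loop.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]

precision, recall, f1 = weighted_metrics(y_true, y_pred)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"Accuracy: {accuracy:.4f}  Precision: {precision:.4f}  Recall: {recall:.4f}  F1: {f1:.4f}")
```

To use this with the validation loop, append predictions.tolist() and labels.tolist() to two Python lists inside the loop and pass them to the function afterwards.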

Saving the Model

Saving the trained model, tokenizer, and label encoder allows us to reuse them for predictions or further training.

model.save_pretrained("bert_resume_classifier")
tokenizer.save_pretrained("bert_resume_classifier")
torch.save(label_encoder, "label_encoder.pth")
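As a lighter-weight alternative to pickling the whole encoder (a suggestion here, not the approach used above), the class names alone can be stored as JSON. Because LabelEncoder assigns integers in the sorted order of its classes_ attribute, the list index is the encoded ID. The file name labels.json and the hard-coded section names are illustrative.

```python
import json
import os
import tempfile

# Stand-in for list(label_encoder.classes_); the order defines the integer IDs.
classes = ["Contact Information", "Education", "Skills", "Work Experience"]

# Hypothetical file name, written to a temporary directory for this demo.
path = os.path.join(tempfile.mkdtemp(), "labels.json")
with open(path, "w", encoding="utf-8") as f:
    json.dump(classes, f)

# At inference time, load the list and index it with the predicted class ID.
with open(path, encoding="utf-8") as f:
    id_to_label = json.load(f)

predicted_class = 2  # e.g. the argmax over the model's logits
print(id_to_label[predicted_class])
```

A plain JSON file is readable by any tool and avoids unpickling arbitrary objects when the model is loaded elsewhere.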

Inference

Finally, we use the saved model to classify new text data.

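# Load model, tokenizer, and label encoder
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BertForSequenceClassification.from_pretrained("bert_resume_classifier")
model.to(device)  # keep the model on the same device as the inputs
model.eval()  # disable dropout for inference
tokenizer = BertTokenizer.from_pretrained("bert_resume_classifier")
# The encoder is a pickled scikit-learn object, not a tensor, so recent PyTorch
# versions need weights_only=False. Only load files you trust.
label_encoder = torch.load("label_encoder.pth", weights_only=False)

# Sample input
sample_text = "AI Engineer from Researchify."
inputs = tokenizer(sample_text, return_tensors="pt", padding="max_length", truncation=True, max_length=128).to(device)

# Prediction (no gradients are needed at inference time)
with torch.no_grad():
    outputs = model(**inputs)
predicted_class = torch.argmax(outputs.logits, dim=1).item()
print(f"Predicted Section: {label_encoder.inverse_transform([predicted_class])[0]}")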

This demonstrates how to classify new resume chunks with the fine-tuned model.

To make this tutorial more accessible, I’ve provided a Colab notebook that includes all the code and explanations discussed in this blog. You can run the notebook in your browser, explore the implementation hands-on, and adapt it to your datasets with ease.

Access the notebook here: Colab Notebook

What’s Next?

In this blog, we explored how to fine-tune BERT for classification tasks, using real-world data and hands-on implementation. From data preparation to evaluation, each step was tailored to help you apply BERT effectively in your own projects.

Next, we’ll dive into advanced techniques like deploying models into production environments, scaling them for real-world applications, and understanding their capabilities and limitations. Stay tuned!

- Somasunder S, AI Engineer - Researchify Labs

Made with ❤️ in Bangalore, India.

copyright @ researchify.io 2024
