跳转到主要内容
Chal1ce blog

从Kaggle-Eddi认识对比学习和双塔模型

参考一下kdd cup的冠军方案,我们来学学对比学习和双塔模型

上次更新已经是两周前了,这几天参加了kaggle的 Eddi 数学误解挖掘赛,在讨论区前排选手分享了一个对比学习的双塔模型思路,而参考的训练代码则是 他在 KddCup 2024 OAG挑战赛中获得第一名的训练代码,这个思路也让我学到了很多,附上代码链接:simcse_deepspeed_mistrial_qlora_argu.py

这篇文章我主要来介绍以下几点:

1、大神的代码和思路

2、对比学习

3、双塔模型

1、大神的代码和思路

1)任务简介

image.png

这场比赛的任务是给你两份文件,其中一份文件里面包含了数学题、正确答案、错误答案,另一份文件里面包含造成错误答案的原因(误解)。你要训练一个模型来识别匹配出错误答案对应的误解。

2)大神思路

首先附上大佬的推理代码:点击查看

简单点介绍就是,使用对比学习的方法,训练一个双塔模型,使用 对比学习 来优化嵌入空间,让问题、错误答案和相关误解之间的距离更近,而与不相关误解之间的距离更远。

推理方面,使用训练好的模型对问题、正确答案、错误答案和误解集进行编码,找出最相近的25个误解。(比赛的评分函数是MAP@25)以下代码来自上方推理代码:

def inference(df, model, tokenizer, device):
    batch_size = 16
    max_length = 512
    sentences = list(df['query_text'].values)
    pids = list(df['order_index'].values)
    all_embeddings = []
    length_sorted_idx = np.argsort([-len(sen) for sen in sentences])
    sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
    for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=False):
        sentences_batch = sentences_sorted[start_index: start_index + batch_size]
        features = tokenizer(sentences_batch, max_length=max_length, padding=True, truncation=True,
                             return_tensors="pt")
        features = batch_to_device(features, device)
        with torch.no_grad():
            outputs = model.model(**features)
            embeddings = last_token_pool(outputs.last_hidden_state, features['attention_mask'])
            embeddings = torch.nn.functional.normalize(embeddings, dim=-1)
            embeddings = embeddings.detach().cpu().numpy().tolist()
        all_embeddings.extend(embeddings)

    all_embeddings = [np.array(all_embeddings[idx]).reshape(1, -1) for idx in np.argsort(length_sorted_idx)]

    sentence_embeddings = np.concatenate(all_embeddings, axis=0)
    result = {pids[i]: em for i, em in enumerate(sentence_embeddings)}
    return result
cosine_similarity = np.dot(query_em, sentence_embeddings.T).flatten()
sort_index = np.argsort(-cosine_similarity)[:25]

在模型方面,比赛中是用到了训练好的Mistral 7b的Lora权重,将Mistral 7b进行对比学习,参考的训练代码中因为有Deepspeed加速、LoRA 微调和 4-bit 量化,所以在设备条件允许的情况下,训练速度还是挺快的,不过我设备不足,在训练批次不够的情况下没办法达到最佳训练效果,退而求其次去搞了sentence transformer 和 deberta 系列模型。

在训练代码中,大佬定义的双塔模型用到了下面的思路,如果想看详情的话可以点我在文章顶部附上的链接:

  1. 查询编码器文档编码器 是相同的模型(权重共享)。
  2. 支持使用 LoRA 和量化技术以减少内存消耗并加速训练。
  3. 使用批次内负样本和跨设备负样本技术增强对比学习的效果。
  4. 支持多种句子池化方法,灵活提取句子嵌入。
  5. 支持分布式训练,并通过温度系数调整 softmax 平滑度。

数据方面

数据加载与预处理:

  • 通过 TrainDatasetForEmbedding 类加载查询和文档数据,构建正负样本。
  • 支持将多个文件合并为一个训练数据集,并提供对查询和文档的文本提示功能。

数据批量化与增强:

  • 使用 EmbedCollator 类将数据批量化,并对输入进行随机遮蔽。
  • 处理后的批量数据包含 tokenized 的查询和文档,可以直接用于模型训练。

大佬的整个训练流程高度优化,适用于大规模、分布式和多 GPU 环境。通过 DeepSpeed、LoRA 和 4-bit 量化的结合,显著降低了显存占用,提高了训练效率。以下是其关键特性:

  • 数据加载与并行计算:支持分布式数据加载和训练。
  • 优化器与调度器:使用 DeepSpeed 的高效优化器和自定义学习率调度。
  • 模型微调与量化:结合 LoRA 和量化技术进行高效微调。
  • 混合精度训练:减少显存占用,加速计算。
  • 动态评估与保存:通过检查点和提前停止机制,优化训练流程。

这个训练框架适合于大规模语义检索和对比学习场景,例如搜索引擎、问答系统以及向量数据库的构建,以后大家如果有需要的场景也可以参考大佬的代码。

下面来介绍一下对比学习和双塔模型。

2、对比学习与双塔模型

在现代NLP领域,对比学习(Contrastive Learning)和双塔模型(Bi-Encoder Model)在语义搜索、推荐系统、信息检索等任务中展现出了卓越的性能。

一、什么是对比学习?

1.1 对比学习的概念

对比学习是一种 自监督学习 方法,其核心思想是通过 拉近相似样本之间的距离(Positive Pair),同时 拉远不相似样本之间的距离(Negative Pair),从而学习到更有判别力的特征表示。它被广泛应用于视觉识别、自然语言处理和多模态任务中。

1.2 对比学习的目标

在对比学习中,模型的目标是最大化相似样本对(正样本对)的相似度,同时最小化与无关样本(负样本对)的相似度。通过这种方式,模型能够学习到一个语义一致的嵌入空间,使得相似的输入数据在嵌入空间中更加接近,而不相似的数据则彼此远离。

1.3 常见的对比学习损失函数

  • 对比损失(Contrastive Loss)

    • 用于度量样本对的相似性,常应用于孪生网络(Siamese Network)。
  • 信息噪声对比估计(InfoNCE Loss)

    • 通过最大化正样本对的相似度与随机负样本的相似度差异来优化模型,常用于 SimCLR、MoCo、CLIP 等预训练任务。

    • 公式如下:

L=logexp(sim(q,k+)/τ)i=0Kexp(sim(q,ki)/τ)\mathcal{L} = -\log \frac{\exp(\text{sim}(q, k^+)/\tau)}{\sum_{i=0}^K \exp(\text{sim}(q, k_i)/\tau)}
  • 其中:

    • qq:查询向量
    • k+k^+:正样本向量
    • kik_i:负样本向量
    • sim()\text{sim}():相似度度量(通常使用余弦相似度)
    • τ\tau:温度参数,用于调整对比度

二、什么是双塔模型?

2.1 双塔模型的概念

双塔模型是一种经典的 深度学习架构,主要用于 检索任务。它由两个独立的编码器(通常是 Transformer)组成,分别处理查询(Query)和文档(Document),从而生成各自的嵌入表示。这些嵌入表示可以用来计算查询与文档之间的相似度。

2.2 双塔模型的架构

+---------------------+ +---------------------+ | Query Encoder A | | Document Encoder B | +---------------------+ +---------------------+ | | Query Embedding Document Embedding | | \ | / \ | / Compute Similarity Score (e.g., Dot Product)

  • Encoder A:处理查询输入,将其编码为查询嵌入向量。
  • Encoder B:处理文档输入,将其编码为文档嵌入向量。
  • 相似度计算:通过查询嵌入和文档嵌入的点积或余弦相似度来计算匹配分数。

2.3 双塔模型的优势

  • 可扩展性强:可以预先生成和存储大量文档嵌入向量,仅在查询时计算一次相似度,大大提高检索效率。
  • 低延迟:适合于实时性要求高的应用场景,如搜索引擎、推荐系统等。
  • 解耦的编码方式:使得查询和文档的嵌入可以独立更新,便于处理动态内容。

三、对比学习与双塔模型的关系

3.1 结合对比学习与双塔模型

双塔模型虽然结构简单,但其性能依赖于嵌入向量的质量。为了提升查询和文档嵌入的质量,对比学习被广泛应用于双塔模型的训练中。通过对比学习,模型可以更好地拉近查询与正确文档之间的距离,同时拉远与无关文档的距离,从而增强检索效果。

3.2 训练过程

  1. 构建正负样本对

    • 正样本对:查询与相关文档(如用户点击过的文档)。
    • 负样本对:查询与随机抽取的无关文档,或通过 难负样本挖掘(Hard Negative Mining)得到的文档。
  2. 编码输入

    • 使用双塔模型中的两个编码器分别对查询和文档进行编码,得到查询嵌入 ( q ) 和文档嵌入 ( d )。
  3. 计算相似度

    • 通过点积或余弦相似度计算查询和文档嵌入之间的匹配分数。
  4. 优化嵌入空间

    • 使用 InfoNCE 损失函数Triplet Loss 进行优化。
    • 优化目标是最大化查询与正样本文档的相似度,同时最小化查询与负样本文档的相似度。

四、对比学习与双塔模型的实际应用

4.1 语义搜索(Semantic Search)

通过对比学习训练的双塔模型,可以将查询和文档映射到相同的嵌入空间中,使得语义相似的查询和文档具有更高的相似度。这样一来,即便用户的查询与文档表述不同,模型也能通过语义理解匹配到相关内容。

4.2 信息检索(Information Retrieval)

双塔模型可以快速地在大规模文档库中进行检索。通过预先计算并存储文档嵌入向量,只需对查询进行编码并与文档嵌入进行相似度计算,即可快速找到最相关的文档,大大降低了在线计算的开销。

4.3 问答系统(Question Answering)

在问答系统中,双塔模型被用来匹配用户问题与知识库中的答案。通过对比学习,模型可以学习到更好的问题与答案之间的映射关系,从而提升问答系统的精确度。


五、实例代码:如何实现对比学习双塔模型

为了实现一个自定义的 BiEncoderModel,我们需要构建一个双塔(Bi-Encoder)架构。这个模型通常包含两个独立的编码器(可以是预训练的 Transformer 模型,如 BERT、RoBERTa 等),分别用于处理 查询(Query)文档(Document) ,并输出各自的嵌入表示。这种架构的主要优势在于其 高效的检索能力,尤其适合在大型文档库中进行相似性搜索。

接下来,我们将详细实现 BiEncoderModel,并支持以下特性:

  • Normalized:对输出的嵌入进行 L2 归一化,以确保计算余弦相似度时的稳定性。
  • Cross Device Negatives:支持跨设备的负样本选择,以提升对比学习的效果。
  • Temperature Scaling:引入温度参数来调节对比损失函数中的相似度分布。

一个双塔模型的例子如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer

class BiEncoderModel(nn.Module):
    def __init__(self, args, normalized=True, negatives_cross_device=False, temperature=0.05):
        super(BiEncoderModel, self).__init__()
        self.args = args
        self.normalized = normalized
        self.negatives_cross_device = negatives_cross_device
        self.temperature = temperature
        
        # 初始化查询和文档编码器,通常使用预训练的 Transformer 模型
        self.query_encoder = AutoModel.from_pretrained(args.model_name_or_path)
        self.doc_encoder = AutoModel.from_pretrained(args.model_name_or_path)
        
        # 将池化策略设置为 CLS 池化(获取 [CLS] token 的输出)
        self.pooling = self.mean_pooling
    
    def mean_pooling(self, model_output, attention_mask):
        """对 Transformer 输出进行平均池化处理"""
        token_embeddings = model_output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size())
        return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
    
    def forward(self, query_input_ids, query_attention_mask, doc_input_ids, doc_attention_mask):
        # 编码查询
        query_outputs = self.query_encoder(input_ids=query_input_ids, attention_mask=query_attention_mask)
        query_embeddings = self.pooling(query_outputs, query_attention_mask)
        
        # 编码文档
        doc_outputs = self.doc_encoder(input_ids=doc_input_ids, attention_mask=doc_attention_mask)
        doc_embeddings = self.pooling(doc_outputs, doc_attention_mask)
        
        # 归一化处理
        if self.normalized:
            query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
            doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
        
        return query_embeddings, doc_embeddings
    
    def compute_similarity(self, query_embeddings, doc_embeddings):
        # 计算查询和文档之间的相似度(点积)
        similarity_scores = torch.matmul(query_embeddings, doc_embeddings.T)
        return similarity_scores

    def compute_loss(self, query_embeddings, doc_embeddings):
        # 计算 InfoNCE 损失
        similarity_scores = self.compute_similarity(query_embeddings, doc_embeddings)
        labels = torch.arange(similarity_scores.size(0), device=similarity_scores.device)
        
        # 除以温度系数
        similarity_scores = similarity_scores / self.temperature
        
        # 计算交叉熵损失
        loss = F.cross_entropy(similarity_scores, labels)
        return loss

    def training_step(self, batch):
        # 从 batch 中提取查询和文档
        query_input_ids = batch['query_input_ids']
        query_attention_mask = batch['query_attention_mask']
        doc_input_ids = batch['doc_input_ids']
        doc_attention_mask = batch['doc_attention_mask']

        # 计算嵌入向量
        query_embeddings, doc_embeddings = self(
            query_input_ids=query_input_ids, 
            query_attention_mask=query_attention_mask,

解释实现细节

  1. 双编码器架构

    • self.query_encoderself.doc_encoder 分别用于处理查询和文档。它们通常是预训练的 BERT 或 RoBERTa 模型,可以使用相同的权重(共享参数)或者不同的权重(独立训练)。
  2. 池化策略

    • 采用了 平均池化mean_pooling)策略,从而对 Transformer 的输出序列进行池化得到句子级嵌入表示。
  3. L2 归一化

    • 为了确保余弦相似度的稳定性,对输出的嵌入进行了 L2 归一化处理。
  4. InfoNCE 损失

    • 使用 InfoNCE Loss 作为对比学习的损失函数,通过温度参数 self.temperature 控制相似度分布的平滑度。
  5. 跨设备负样本

    • 在本代码中,我们保留了跨设备负样本(negatives_cross_device)的功能接口,如果您希望扩展为跨设备训练,可以使用 torch.distributed 相关函数。

如何对其进行训练呢,参考的代码如下:

# 初始化参数
class Args:
    model_name_or_path = 'bert-base-uncased'
    learning_rate = 2e-5
    temperature = 0.05

args = Args()

# 初始化模型和优化器
model = BiEncoderModel(args)
optimizer = torch.optim.AdamW(model.parameters(), lr=args.learning_rate)

# 多轮训练
for epoch in range(3):
    for batch in train_dataloader:  # train_dataloader 是 DataLoader 对象
        optimizer.zero_grad()
        
        # 前向传播并计算损失
        loss = model.training_step(batch)
        
        # 反向传播与参数更新
        loss.backward()
        optimizer.step()
        
        print(f"Epoch {epoch+1} Loss: {loss.item()}")


六、总结

对比学习和双塔模型是现代 NLP 技术的强大组合。对比学习通过优化嵌入空间,使得模型能够更好地理解查询和文档之间的语义关系;而双塔模型则通过高效的编码结构,显著提升了检索和推荐任务的性能。二者结合应用,可以大幅提高语义搜索、信息检索和问答系统的效果,是构建智能搜索引擎和推荐系统的有力工具。

希望通过本文的介绍,您能够更好地理解对比学习和双塔模型的核心原理,并将其应用于实际项目中,提升系统的智能化水平。