В этом руководстве мы уточняем модель встраивания Sentence-Transformers, используя Matryoshka Representation Learning, чтобы первые измерения вектора несли наиболее полезный семантический сигнал.
Мы обучаем модель с помощью MatryoshkaLoss на тройных данных, а затем проверяем ключевое обещание MRL, сравнивая качество поиска после усечения вложений до 64, 128 и 256 измерений.
Установка необходимых библиотек
Сначала установим необходимые библиотеки и импортируем все необходимые модули для обучения и оценки.
«`python
!pip -q install -U sentence-transformers datasets accelerate
import math
import random
import numpy as np
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers import losses
from sentencetransformers.util import cossim
«`
Затем зададим детерминированное начальное значение, чтобы наше поведение при выборке и обучении оставалось согласованным при каждом запуске.
«`python
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manualseedall(seed)
set_seed(42)
«`
Оценка метрик поиска
«`python
@torch.no_grad()
def retrievalmetricsmrrrecallat_k(
model,
queries,
corpus,
qrels,
dims_list=(64, 128, 256, None),
k=10,
batch_size=64,
):
device = «cuda» if torch.cuda.is_available() else «cpu»
model.to(device)
qids = list(queries.keys())
docids = list(corpus.keys())
q_texts = [queries[qid] for qid in qids]
d_texts = [corpus[did] for did in docids]
qemb = model.encode(qtexts, batchsize=batchsize, converttotensor=True, normalize_embeddings=True)
demb = model.encode(dtexts, batchsize=batchsize, converttotensor=True, normalize_embeddings=True)
results = {}
for dim in dims_list:
if dim is None:
qe = q_emb
de = d_emb
dim_name = «full»
else:
qe = q_emb[:, :dim]
de = d_emb[:, :dim]
dim_name = str(dim)
qe = torch.nn.functional.normalize(qe, p=2, dim=1)
de = torch.nn.functional.normalize(de, p=2, dim=1)
sims = cos_sim(qe, de)
mrr_total = 0.0
recall_total = 0.0
for i, qid in enumerate(qids):
rel = qrels.get(qid, set())
if not rel:
continue
topk = torch.topk(sims[i], k=min(k, sims.shape[1]), largest=True).indices.tolist()
topk_docids = [docids[j] for j in topk]
recalltotal += 1.0 if any(d in rel for d in topkdocids) else 0.0
rr = 0.0
for rank, d in enumerate(topk_docids, start=1):
if d in rel:
rr = 1.0 / rank
break
mrr_total += rr
denom = max(1, len(qids))
results[dimname] = {f»MRR@{k}»: mrrtotal / denom, f»Recall@{k}»: recall_total / denom}
return results
«`
Загрузка и подготовка данных
«`python
DATASET_ID = «sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1»
SUBSET = «triplet-hard»
SPLIT = «train»
TRAIN_SAMPLES = 4000
EVAL_QUERIES = 300
stream = loaddataset(DATASETID, SUBSET, split=SPLIT, streaming=True)
train_examples = []
eval_queries = {}
eval_corpus = {}
eval_qrels = {}
docidcounter = 0
qid_counter = 0
for row in stream:
q = (row.get(«query») or «»).strip()
pos = (row.get(«positive») or «»).strip()
neg = (row.get(«negative») or «»).strip()
if not q or not pos or not neg:
continue
train_examples.append(InputExample(texts=[q, pos, neg]))
if len(evalqueries) < EVALQUERIES:
qid = f»q{qid_counter}»
qid_counter += 1
posid = f»d{docidcounter}»; docid_counter += 1
negid = f»d{docidcounter}»; docid_counter += 1
eval_queries[qid] = q
evalcorpus[posid] = pos
evalcorpus[negid] = neg
evalqrels[qid] = {posid}
if len(trainexamples) >= TRAINSAMPLES and len(evalqueries) >= EVALQUERIES:
break
print(len(trainexamples), len(evalqueries), len(eval_corpus))
«`
Обучение модели
«`python
MODEL_ID = «BAAI/bge-base-en-v1.5»
device = «cuda» if torch.cuda.is_available() else «cpu»
model = SentenceTransformer(MODEL_ID, device=device)
fulldim = model.getsentenceembeddingdimension()
baseline = retrievalmetricsmrrrecallat_k(
model,
queries=eval_queries,
corpus=eval_corpus,
qrels=eval_qrels,
dims_list=(64, 128, 256, None),
k=10,
)
pretty_print(baseline, «BEFORE»)
batch_size = 16
epochs = 1
warmup_steps = 100
trainloader = DataLoader(trainexamples, batchsize=batchsize, shuffle=True, drop_last=True)
base_loss = losses.MultipleNegativesRankingLoss(model=model)
mrldims = [fulldim, 512, 256, 128, 64] if fulldim >= 768 else [fulldim, 256, 128, 64]
mrl_loss = losses.MatryoshkaLoss(
model=model,
loss=base_loss,
matryoshkadims=mrldims
)
model.fit(
trainobjectives=[(trainloader, mrl_loss)],
epochs=epochs,
warmupsteps=warmupsteps,
showprogressbar=True,
)
after = retrievalmetricsmrrrecallat_k(
model,
queries=eval_queries,
corpus=eval_corpus,
qrels=eval_qrels,
dims_list=(64, 128, 256, None),
k=10,
)
pretty_print(after, «AFTER»)
out_dir = «mrl-msmarco-demo»
model.save(out_dir)
m64 = SentenceTransformer(outdir, truncatedim=64)
emb = m64.encode(
[«what is the liberal arts?», «liberal arts covers humanities and sciences»],
normalize_embeddings=True
)
print(emb.shape)
«`
Мы создали модель встраивания, оптимизированную для «матрёшки», которая поддерживает высокую производительность поиска даже при усечении векторов до небольших префиксных измерений, таких как 64. Мы проверили эффект, сравнив базовые показатели поиска с показателями после обучения.
1. Какие библиотеки и модули используются для обучения и оценки модели встраивания предложений в данном руководстве?
Ответ:
В данном руководстве используются следующие библиотеки и модули: `sentence-transformers`, `datasets`, `accelerate`, `numpy`, `torch`, а также модули для работы с данными (`DataLoader`), модели встраивания (`SentenceTransformer`), функции потерь (`losses`), и утилиты (`util`).
2. Как задаётся начальное значение для обеспечения согласованности при каждом запуске модели?
Ответ:
Начальное значение задаётся с помощью функции `setseed`, которая принимает параметр `seed` и устанавливает начальное значение для `random`, `np.random`, `torch.manualseed` и `torch.cuda.manualseedall`. В примере используется значение `seed=42`.
3. Какие метрики используются для оценки качества поиска после усечения вложений до различных измерений?
Ответ:
Для оценки качества поиска используются метрики `MRR (Mean Reciprocal Rank)` и `Recall at K`. Они вычисляются с помощью функции `retrievalmetricsmrrrecallat_k`, которая принимает модель, запросы, корпус, релевантные документы и другие параметры.
4. Какие параметры используются при обучении модели с помощью `MatryoshkaLoss`?
Ответ:
При обучении модели с помощью `MatryoshkaLoss` используются следующие параметры: `model` (модель встраивания), `loss` (функция потерь), `matryoshka_dims` (список измерений для усечения). В примере используется `MultipleNegativesRankingLoss` в качестве базовой функции потерь и `MatryoshkaLoss` для обучения модели.
5. Как сохраняется обученная модель после обучения?
Ответ:
Обученная модель сохраняется с помощью метода `save` класса `SentenceTransformer`. В примере модель сохраняется в директорию `out_dir` с именем `mrl-msmarco-demo`.