Как создать оптимизированную для «матрёшки» модель встраивания предложений для сверхбыстрого поиска с усечением до 64 измерений

В этом руководстве мы уточняем модель встраивания Sentence-Transformers, используя Matryoshka Representation Learning, чтобы первые измерения вектора несли наиболее полезный семантический сигнал.

Мы обучаем модель с помощью MatryoshkaLoss на тройных данных, а затем проверяем ключевое обещание MRL, сравнивая качество поиска после усечения вложений до 64, 128 и 256 измерений.

Установка необходимых библиотек

Сначала установим необходимые библиотеки и импортируем все необходимые модули для обучения и оценки.

«`python
!pip -q install -U sentence-transformers datasets accelerate

import math
import random
import numpy as np
import torch

from datasets import load_dataset
from torch.utils.data import DataLoader

from sentence_transformers import SentenceTransformer, InputExample
from sentence_transformers import losses
from sentencetransformers.util import cossim
«`

Затем зададим детерминированное начальное значение, чтобы наше поведение при выборке и обучении оставалось согласованным при каждом запуске.

«`python
def set_seed(seed=42):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manualseedall(seed)

set_seed(42)
«`

Оценка метрик поиска

«`python
@torch.no_grad()
def retrievalmetricsmrrrecallat_k(
model,
queries,
corpus,
qrels,
dims_list=(64, 128, 256, None),
k=10,
batch_size=64,
):
device = «cuda» if torch.cuda.is_available() else «cpu»
model.to(device)

qids = list(queries.keys())
docids = list(corpus.keys())

q_texts = [queries[qid] for qid in qids]
d_texts = [corpus[did] for did in docids]

qemb = model.encode(qtexts, batchsize=batchsize, converttotensor=True, normalize_embeddings=True)
demb = model.encode(dtexts, batchsize=batchsize, converttotensor=True, normalize_embeddings=True)

results = {}

for dim in dims_list:
if dim is None:
qe = q_emb
de = d_emb
dim_name = «full»
else:
qe = q_emb[:, :dim]
de = d_emb[:, :dim]
dim_name = str(dim)
qe = torch.nn.functional.normalize(qe, p=2, dim=1)
de = torch.nn.functional.normalize(de, p=2, dim=1)

sims = cos_sim(qe, de)

mrr_total = 0.0
recall_total = 0.0

for i, qid in enumerate(qids):
rel = qrels.get(qid, set())
if not rel:
continue

topk = torch.topk(sims[i], k=min(k, sims.shape[1]), largest=True).indices.tolist()
topk_docids = [docids[j] for j in topk]

recalltotal += 1.0 if any(d in rel for d in topkdocids) else 0.0

rr = 0.0
for rank, d in enumerate(topk_docids, start=1):
if d in rel:
rr = 1.0 / rank
break
mrr_total += rr

denom = max(1, len(qids))
results[dimname] = {f»MRR@{k}»: mrrtotal / denom, f»Recall@{k}»: recall_total / denom}

return results
«`

Загрузка и подготовка данных

«`python
DATASET_ID = «sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1»
SUBSET = «triplet-hard»
SPLIT = «train»

TRAIN_SAMPLES = 4000
EVAL_QUERIES = 300

stream = loaddataset(DATASETID, SUBSET, split=SPLIT, streaming=True)

train_examples = []
eval_queries = {}
eval_corpus = {}
eval_qrels = {}

docidcounter = 0
qid_counter = 0

for row in stream:
q = (row.get(«query») or «»).strip()
pos = (row.get(«positive») or «»).strip()
neg = (row.get(«negative») or «»).strip()

if not q or not pos or not neg:
continue

train_examples.append(InputExample(texts=[q, pos, neg]))

if len(evalqueries) < EVALQUERIES:
qid = f»q{qid_counter}»
qid_counter += 1

posid = f»d{docidcounter}»; docid_counter += 1
negid = f»d{docidcounter}»; docid_counter += 1

eval_queries[qid] = q
evalcorpus[posid] = pos
evalcorpus[negid] = neg
evalqrels[qid] = {posid}

if len(trainexamples) >= TRAINSAMPLES and len(evalqueries) >= EVALQUERIES:
break

print(len(trainexamples), len(evalqueries), len(eval_corpus))
«`

Обучение модели

«`python
MODEL_ID = «BAAI/bge-base-en-v1.5»

device = «cuda» if torch.cuda.is_available() else «cpu»
model = SentenceTransformer(MODEL_ID, device=device)
fulldim = model.getsentenceembeddingdimension()

baseline = retrievalmetricsmrrrecallat_k(
model,
queries=eval_queries,
corpus=eval_corpus,
qrels=eval_qrels,
dims_list=(64, 128, 256, None),
k=10,
)
pretty_print(baseline, «BEFORE»)

batch_size = 16
epochs = 1
warmup_steps = 100

trainloader = DataLoader(trainexamples, batchsize=batchsize, shuffle=True, drop_last=True)

base_loss = losses.MultipleNegativesRankingLoss(model=model)

mrldims = [fulldim, 512, 256, 128, 64] if fulldim >= 768 else [fulldim, 256, 128, 64]
mrl_loss = losses.MatryoshkaLoss(
model=model,
loss=base_loss,
matryoshkadims=mrldims
)

model.fit(
trainobjectives=[(trainloader, mrl_loss)],
epochs=epochs,
warmupsteps=warmupsteps,
showprogressbar=True,
)

after = retrievalmetricsmrrrecallat_k(
model,
queries=eval_queries,
corpus=eval_corpus,
qrels=eval_qrels,
dims_list=(64, 128, 256, None),
k=10,
)
pretty_print(after, «AFTER»)

out_dir = «mrl-msmarco-demo»
model.save(out_dir)

m64 = SentenceTransformer(outdir, truncatedim=64)
emb = m64.encode(
[«what is the liberal arts?», «liberal arts covers humanities and sciences»],
normalize_embeddings=True
)
print(emb.shape)
«`

Мы создали модель встраивания, оптимизированную для «матрёшки», которая поддерживает высокую производительность поиска даже при усечении векторов до небольших префиксных измерений, таких как 64. Мы проверили эффект, сравнив базовые показатели поиска с показателями после обучения.

1. Какие библиотеки и модули используются для обучения и оценки модели встраивания предложений в данном руководстве?

Ответ:
В данном руководстве используются следующие библиотеки и модули: `sentence-transformers`, `datasets`, `accelerate`, `numpy`, `torch`, а также модули для работы с данными (`DataLoader`), модели встраивания (`SentenceTransformer`), функции потерь (`losses`), и утилиты (`util`).

2. Как задаётся начальное значение для обеспечения согласованности при каждом запуске модели?

Ответ:
Начальное значение задаётся с помощью функции `setseed`, которая принимает параметр `seed` и устанавливает начальное значение для `random`, `np.random`, `torch.manualseed` и `torch.cuda.manualseedall`. В примере используется значение `seed=42`.

3. Какие метрики используются для оценки качества поиска после усечения вложений до различных измерений?

Ответ:
Для оценки качества поиска используются метрики `MRR (Mean Reciprocal Rank)` и `Recall at K`. Они вычисляются с помощью функции `retrievalmetricsmrrrecallat_k`, которая принимает модель, запросы, корпус, релевантные документы и другие параметры.

4. Какие параметры используются при обучении модели с помощью `MatryoshkaLoss`?

Ответ:
При обучении модели с помощью `MatryoshkaLoss` используются следующие параметры: `model` (модель встраивания), `loss` (функция потерь), `matryoshka_dims` (список измерений для усечения). В примере используется `MultipleNegativesRankingLoss` в качестве базовой функции потерь и `MatryoshkaLoss` для обучения модели.

5. Как сохраняется обученная модель после обучения?

Ответ:
Обученная модель сохраняется с помощью метода `save` класса `SentenceTransformer`. В примере модель сохраняется в директорию `out_dir` с именем `mrl-msmarco-demo`.

Источник