
Faster retrieval with Scalable Nearest Neighbours (ScaNN)

Authors: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Use ScaNN for faster retrieval.

View in Colab • GitHub source


Introduction

Retrieval models are designed to quickly identify a small set of highly relevant candidates from vast pools of data, often comprising millions or even hundreds of millions of items. To effectively respond to a user's context and behavior in real time, these models must perform this task in mere milliseconds.

Approximate nearest neighbor (ANN) search is the key technology that enables this level of efficiency. In this tutorial, we demonstrate how to leverage ScaNN, a cutting-edge nearest-neighbor retrieval library, to effortlessly scale retrieval to millions of items.

ScaNN, developed by Google Research, is a high-performance library designed for dense-vector similarity search at scale. It efficiently indexes a database of candidate embeddings, enabling rapid search during inference. By leveraging advanced vector-compression techniques and carefully tuned algorithms, ScaNN strikes an optimal balance between speed and accuracy. As a result, it can significantly outperform brute-force methods, delivering fast retrieval with minimal loss of accuracy.

We start with the same code as the basic retrieval example: the data processing, model building, and training are identical. If you have already gone through the basic retrieval example, feel free to skip this part.

Note: ScaNN does not have a standalone layer in KerasRS because the ScaNN library is TensorFlow-only. In this example, we use the ScaNN library directly and demonstrate how to use it with KerasRS.


Imports

Let's install the `scann` library and import all the necessary packages. We also set the backend to JAX.

!pip install -q keras-rs
!pip install -q scann
import os

os.environ["KERAS_BACKEND"] = "jax"  # `"tensorflow"`/`"torch"`

import time
import uuid

import keras
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds
from scann import scann_ops

import keras_rs

Preparing the dataset

# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

# Get user and movie counts so that we can define embedding layers for both.
users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

movies_count = movies.cardinality().numpy()


# Preprocess the dataset, by selecting only the relevant columns.
def preprocess_rating(x):
    return (
        # Input is the user IDs
        tf.strings.to_number(x["user_id"], out_type=tf.int32),
        # Labels are movie IDs + ratings between 0 and 1.
        {
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
            "rating": (x["user_rating"] - 1.0) / 4.0,
        },
    )


shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
# Train-test split.
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
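
As a quick sanity check (not part of the original example), we can peek at one preprocessed batch to confirm the batch shapes and that ratings have been rescaled to [0, 1]:

# Illustrative check only: inspect a single batch of the training data.
for user_ids, labels in train_ratings.take(1):
    print(user_ids.shape)            # (1000,) batch of user IDs
    print(labels["movie_id"].shape)  # (1000,) batch of movie IDs
    print(labels["rating"][:5])      # ratings rescaled to [0, 1]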

Implementing the model

class RetrievalModel(keras.Model):
    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )

        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)

        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # Compute the affinity score by multiplying the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)
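
To make the scoring concrete, here is a tiny sketch of the dot-product affinity computed in `compute_loss`, using made-up embeddings (not part of the original example):

import keras

# Toy 4-dimensional embeddings for a single (user, movie) pair.
user_emb = keras.ops.array([[0.1, 0.2, 0.3, 0.4]])
movie_emb = keras.ops.array([[0.4, 0.3, 0.2, 0.1]])

# Same computation as in compute_loss: elementwise product, then sum.
score = keras.ops.sum(
    keras.ops.multiply(user_emb, movie_emb), axis=1, keepdims=True
)
print(score)  # [[0.2]] = 0.04 + 0.06 + 0.06 + 0.04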

Training the model

model = RetrievalModel(users_count + 1000, movies_count + 1000)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))

history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)
Epoch 1/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - loss: 0.4772

Epoch 2/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.4772

Epoch 3/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771

Epoch 4/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771

Epoch 5/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 25ms/step - loss: 0.4771 - val_loss: 0.4835

Epoch 6/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770

Epoch 7/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770

Epoch 8/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769

Epoch 9/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769

Epoch 10/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769 - val_loss: 0.4835

Epoch 11/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768

Epoch 12/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768

Epoch 13/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767

Epoch 14/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767

Epoch 15/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766 - val_loss: 0.4834

Epoch 16/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766

Epoch 17/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765

Epoch 18/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765

Epoch 19/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4764

Epoch 20/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4763 - val_loss: 0.4833

Epoch 21/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4762

Epoch 22/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4761

Epoch 23/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4760

Epoch 24/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4759

Epoch 25/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4758 - val_loss: 0.4829

Epoch 26/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4757

Epoch 27/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4756

Epoch 28/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4754

Epoch 29/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4752

Epoch 30/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4750 - val_loss: 0.4823

Epoch 31/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4748

Epoch 32/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4746

Epoch 33/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4744

Epoch 34/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4741

Epoch 35/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4738 - val_loss: 0.4810

Epoch 36/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4734

Epoch 37/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4730

Epoch 38/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4726

Epoch 39/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4721

Epoch 40/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4715 - val_loss: 0.4788

Epoch 41/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4709

Epoch 42/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4702

Epoch 43/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4695

Epoch 44/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4686

Epoch 45/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4677 - val_loss: 0.4749

Epoch 46/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4666

Epoch 47/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4654

Epoch 48/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4641

Epoch 49/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4627

Epoch 50/50

80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4610 - val_loss: 0.4679

Making predictions

Before trying out ScaNN, let's start with the brute-force method: for a given user, we compute the scores of all movies, then sort them and pick the top k. Of course, this does not scale well as the number of movies grows.

candidate_embeddings = keras.ops.array(model.candidate_embedding.embeddings.numpy())
# Artificially duplicate candidate embeddings to simulate a large number of
# movies.
candidate_embeddings = keras.ops.concatenate(
    [candidate_embeddings]
    + [
        candidate_embeddings
        * keras.random.uniform(keras.ops.shape(candidate_embeddings))
        for _ in range(100)
    ],
    axis=0,
)
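
To make the scale of the comparison explicit, here is an illustrative check (not in the original example) of the enlarged candidate table, which now holds the original embeddings plus 100 randomly perturbed copies:

# 101 * (movies_count + 1000) rows, each a 32-dimensional embedding.
print(keras.ops.shape(candidate_embeddings))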

user_embedding = model.user_embedding(keras.ops.array([10, 5, 42, 345]))

# Define the brute force retrieval layer.
brute_force_layer = keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings=candidate_embeddings,
    k=10,
    return_scores=False,
)

Now, let's run a forward pass on this layer. Note that in previous tutorials, we made the above layer an attribute of the model class and then called `.predict()`. That would obviously be faster (since it compiles to XLA code), but since we cannot do the same with ScaNN, here we just do a normal forward pass without compilation to ensure a fair comparison.

t0 = time.time()
pred_movie_ids = brute_force_layer(user_embedding)
print("Time taken by brute force layer (sec):", time.time() - t0)
Time taken by brute force layer (sec): 0.6420145034790039

Now, let's retrieve movies using ScaNN. We will build the layer using Google Research's ScaNN library and then call it. To fully understand all the arguments, please refer to the ScaNN README.

def build_scann(
    candidates,
    k=10,
    distance_measure="dot_product",
    dimensions_per_block=2,
    num_reordering_candidates=500,
    num_leaves=100,
    num_leaves_to_search=30,
    training_iterations=12,
):
    builder = scann_ops.builder(
        db=candidates,
        num_neighbors=k,
        distance_measure=distance_measure,
    )

    builder = builder.tree(
        num_leaves=num_leaves,
        num_leaves_to_search=num_leaves_to_search,
        training_iterations=training_iterations,
    )
    builder = builder.score_ah(dimensions_per_block=dimensions_per_block)

    if num_reordering_candidates is not None:
        builder = builder.reorder(num_reordering_candidates)

    # Set a unique name to prevent unintentional sharing between
    # ScaNN instances.
    searcher = builder.build(shared_name=str(uuid.uuid4()))
    return searcher


def run_scann(searcher):
    pred_movie_ids = searcher.search_batched_parallel(
        user_embedding,
        final_num_neighbors=10,
    ).indices
    return pred_movie_ids


searcher = build_scann(candidates=candidate_embeddings)
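
At a high level, the three stages configured above trade speed for recall: `tree(...)` partitions the candidates into `num_leaves` clusters and searches only `num_leaves_to_search` of them per query, `score_ah(...)` scores candidates using compressed (asymmetric hashing) representations, and `reorder(...)` re-scores the top `num_reordering_candidates` candidates exactly. Increasing the number of leaves searched or the reordering budget improves accuracy at the cost of latency.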

t0 = time.time()
pred_movie_ids = run_scann(searcher)
print("Time taken by ScANN (sec):", time.time() - t0)
Time taken by ScANN (sec): 0.0032401084899902344

You can see a drastic improvement in latency: going by the timings above, ScaNN (about 0.003 sec) runs roughly two hundred times faster than the brute-force layer (about 0.64 sec)!
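
Speed is only half of the story: since ScaNN is approximate, it is also worth checking how much accuracy is traded away. Below is a minimal sketch (not part of the original example) that compares ScaNN's top 10 against the exact brute-force top 10 for the same queries:

import numpy as np

# Reuse the layer, searcher and user embeddings from above.
brute_ids = np.asarray(brute_force_layer(user_embedding))
scann_ids = np.asarray(run_scann(searcher))

# Per-query overlap between the exact and the approximate top 10.
for exact, approx in zip(brute_ids, scann_ids):
    recall = len(set(exact.tolist()) & set(approx.tolist())) / len(exact)
    print(f"Recall@10: {recall:.2f}")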