Authors: Abheesht Sharma, Fabien Hertschuh
Date created: 2025/04/28
Last modified: 2025/04/28
Description: Using ScaNN for faster retrieval.
Retrieval models are designed to quickly identify a small set of highly relevant candidates from a vast pool of items, often numbering in the millions or even hundreds of millions. To effectively respond to a user's context and behavior in real time, these models must perform this task in just milliseconds.

Approximate nearest neighbor (ANN) search is the key technology that enables this level of efficiency. In this tutorial, we demonstrate how to leverage ScaNN, a cutting-edge nearest neighbor retrieval library, to effortlessly scale retrieval to millions of items.

Developed by Google Research, ScaNN is a high-performance library designed for dense vector similarity search at scale. It efficiently indexes a database of candidate embeddings, enabling rapid search at inference time. By leveraging advanced vector compression techniques and carefully tuned algorithms, ScaNN strikes an optimal balance between speed and accuracy. As a result, it can significantly outperform brute-force search, delivering fast retrieval with minimal loss in accuracy.

We begin with the same code as the basic retrieval example. The data processing, model building, and training are exactly the same. Feel free to skip this part if you have gone through the basic retrieval example before.

Note: ScaNN does not have its own standalone layer in KerasRS because the ScaNN library is TensorFlow-only. In this example, we use the ScaNN library directly and demonstrate its usage with KerasRS.

Let's install the `scann` library and import all the necessary packages. We will also set the backend to JAX.
!pip install -q keras-rs
!pip install -q scann
import os
os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"`
import time
import uuid
import keras
import tensorflow as tf # Needed for the dataset
import tensorflow_datasets as tfds
from scann import scann_ops
import keras_rs
# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")
# Get user and movie counts so that we can define embedding layers for both.
users_count = (
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)
movies_count = movies.cardinality().numpy()
# Preprocess the dataset, by selecting only the relevant columns.
def preprocess_rating(x):
    return (
        # Input is the user IDs
        tf.strings.to_number(x["user_id"], out_type=tf.int32),
        # Labels are movie IDs + ratings between 0 and 1.
        {
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
            "rating": (x["user_rating"] - 1.0) / 4.0,
        },
    )
shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
# Train-test split.
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
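As a quick sanity check (this snippet is an optional addition, not part of the original pipeline), we can inspect one preprocessed batch to confirm the shapes and the rating rescaling:

# Optional sanity check: peek at one preprocessed batch.
for user_ids, labels in train_ratings.take(1):
    print(user_ids.shape)            # (1000,) batch of user IDs
    print(labels["movie_id"].shape)  # (1000,) batch of movie IDs
    print(labels["rating"][:5])      # ratings rescaled to [0, 1]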
class RetrievalModel(keras.Model):
    def __init__(
        self,
        num_users,
        num_candidates,
        embedding_dimension=32,
        **kwargs,
    ):
        super().__init__(**kwargs)
        # Our query tower, simply an embedding table.
        self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
        # Our candidate tower, simply an embedding table.
        self.candidate_embedding = keras.layers.Embedding(
            num_candidates, embedding_dimension
        )
        self.loss_fn = keras.losses.MeanSquaredError()

    def build(self, input_shape):
        self.user_embedding.build(input_shape)
        self.candidate_embedding.build(input_shape)
        super().build(input_shape)

    def call(self, inputs, training=False):
        user_embeddings = self.user_embedding(inputs)
        result = {
            "user_embeddings": user_embeddings,
        }
        return result

    def compute_loss(self, x, y, y_pred, sample_weight, training=True):
        candidate_id, rating = y["movie_id"], y["rating"]
        user_embeddings = y_pred["user_embeddings"]
        candidate_embeddings = self.candidate_embedding(candidate_id)

        labels = keras.ops.expand_dims(rating, -1)
        # Compute the affinity score by multiplying the two embeddings.
        scores = keras.ops.sum(
            keras.ops.multiply(user_embeddings, candidate_embeddings),
            axis=1,
            keepdims=True,
        )
        return self.loss_fn(labels, scores, sample_weight)
model = RetrievalModel(users_count + 1000, movies_count + 1000)
model.compile(optimizer=keras.optimizers.Adagrad(learning_rate=0.1))
history = model.fit(
    train_ratings, validation_data=test_ratings, validation_freq=5, epochs=50
)
Epoch 1/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - loss: 0.4772
Epoch 2/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.4772
Epoch 3/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771
Epoch 4/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4771
Epoch 5/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 25ms/step - loss: 0.4771 - val_loss: 0.4835
Epoch 6/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770
Epoch 7/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4770
Epoch 8/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769
Epoch 9/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769
Epoch 10/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4769 - val_loss: 0.4835
Epoch 11/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768
Epoch 12/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4768
Epoch 13/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767
Epoch 14/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4767
Epoch 15/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766 - val_loss: 0.4834
Epoch 16/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4766
Epoch 17/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765
Epoch 18/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4765
Epoch 19/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4764
Epoch 20/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4763 - val_loss: 0.4833
Epoch 21/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4762
Epoch 22/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4761
Epoch 23/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4760
Epoch 24/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4759
Epoch 25/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4758 - val_loss: 0.4829
Epoch 26/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4757
Epoch 27/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4756
Epoch 28/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4754
Epoch 29/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4752
Epoch 30/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4750 - val_loss: 0.4823
Epoch 31/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4748
Epoch 32/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4746
Epoch 33/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4744
Epoch 34/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4741
Epoch 35/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4738 - val_loss: 0.4810
Epoch 36/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4734
Epoch 37/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4730
Epoch 38/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 1ms/step - loss: 0.4726
Epoch 39/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4721
Epoch 40/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4715 - val_loss: 0.4788
Epoch 41/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4709
Epoch 42/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4702
Epoch 43/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4695
Epoch 44/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4686
Epoch 45/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4677 - val_loss: 0.4749
Epoch 46/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4666
Epoch 47/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4654
Epoch 48/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4641
Epoch 49/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4627
Epoch 50/50
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 0.4610 - val_loss: 0.4679
Before trying out ScaNN, let's start with the brute-force approach: for a given user, compute the scores for all movies, sort them, and pick the top-k. Naturally, this does not scale well as the number of movies grows.
candidate_embeddings = keras.ops.array(model.candidate_embedding.embeddings.numpy())
# Artificially duplicate candidate embeddings to simulate a large number of
# movies.
candidate_embeddings = keras.ops.concatenate(
    [candidate_embeddings]
    + [
        candidate_embeddings
        * keras.random.uniform(keras.ops.shape(candidate_embeddings))
        for _ in range(100)
    ],
    axis=0,
)
user_embedding = model.user_embedding(keras.ops.array([10, 5, 42, 345]))
# Define the brute force retrieval layer.
brute_force_layer = keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings=candidate_embeddings,
    k=10,
    return_scores=False,
)
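Under the hood, brute-force retrieval boils down to a dense score computation followed by a top-k. The sketch below illustrates the idea with `keras.ops`; it is equivalent in spirit to what the layer above computes, not the layer's actual implementation:

# Illustrative sketch only: score every candidate against every query with
# a dot product, then keep the 10 highest-scoring candidates per query.
scores = keras.ops.matmul(
    user_embedding, keras.ops.transpose(candidate_embeddings)
)
top_scores, top_ids = keras.ops.top_k(scores, k=10)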
Now, let's do a forward pass on the layer. Note that in previous tutorials, we attached this layer as an attribute of the model class and then called `.predict()`. That would obviously be faster, since the forward pass compiles to XLA code, but ScaNN cannot be compiled this way, so we do a plain, uncompiled forward pass here to keep the comparison fair.
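For reference, the compiled pattern from the earlier tutorials looks roughly like the sketch below. This is an illustrative assumption and is deliberately not run here, precisely because we want an uncompiled comparison:

# Sketch of the compiled variant (not executed in this tutorial): attaching
# the retrieval layer to a keras.Model lets .predict() compile the forward
# pass (e.g. to XLA on the JAX backend).
class CompiledRetrieval(keras.Model):
    def __init__(self, user_embedding, retrieval_layer, **kwargs):
        super().__init__(**kwargs)
        self.user_embedding = user_embedding
        self.retrieval_layer = retrieval_layer

    def call(self, user_ids):
        return self.retrieval_layer(self.user_embedding(user_ids))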
t0 = time.time()
pred_movie_ids = brute_force_layer(user_embedding)
print("Time taken by brute force layer (sec):", time.time() - t0)
Time taken by brute force layer (sec): 0.6420145034790039
Now, let's retrieve movies using ScaNN. We will build the searcher using Google Research's ScaNN library and then call it. To fully understand all the arguments, please refer to the ScaNN README file.
def build_scann(
    candidates,
    k=10,
    distance_measure="dot_product",
    dimensions_per_block=2,
    num_reordering_candidates=500,
    num_leaves=100,
    num_leaves_to_search=30,
    training_iterations=12,
):
    builder = scann_ops.builder(
        db=candidates,
        num_neighbors=k,
        distance_measure=distance_measure,
    )

    builder = builder.tree(
        num_leaves=num_leaves,
        num_leaves_to_search=num_leaves_to_search,
        training_iterations=training_iterations,
    )
    builder = builder.score_ah(dimensions_per_block=dimensions_per_block)

    if num_reordering_candidates is not None:
        builder = builder.reorder(num_reordering_candidates)

    # Set a unique name to prevent unintentional sharing between
    # ScaNN instances.
    searcher = builder.build(shared_name=str(uuid.uuid4()))
    return searcher
def run_scann(searcher):
    pred_movie_ids = searcher.search_batched_parallel(
        user_embedding,
        final_num_neighbors=10,
    ).indices
    return pred_movie_ids
searcher = build_scann(candidates=candidate_embeddings)
t0 = time.time()
pred_movie_ids = run_scann(searcher)
print("Time taken by ScANN (sec):", time.time() - t0)
Time taken by ScaNN (sec): 0.0032401084899902344
You can see a clear improvement in latency: going by the timings printed above, ScaNN (~0.003 sec) is roughly 200 times faster than the brute-force layer (~0.64 sec)!
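Latency is only half the story; it is also worth checking how closely ScaNN's approximate results match the exact ones. The snippet below (an optional addition, not part of the original tutorial) computes the per-query overlap between the brute-force and ScaNN results, i.e. ScaNN's recall@10 against exact search:

# Optional accuracy check: fraction of exact top-10 neighbors that ScaNN
# also returned, per query.
brute_force_ids = keras.ops.convert_to_numpy(brute_force_layer(user_embedding))
scann_ids = keras.ops.convert_to_numpy(run_scann(searcher))

for i in range(len(brute_force_ids)):
    overlap = len(set(brute_force_ids[i]) & set(scann_ids[i]))
    print(f"Query {i}: recall@10 = {overlap / 10:.2f}")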