作者: Abheesht Sharma, Fabien Hertschuh
创建日期 2025/04/28
最后修改日期 2025/04/28
描述: 使用一个模型同时进行召回和排序。
在基本召回和基本排序教程中,我们分别为召回和排序任务创建了单独的模型。然而,在许多情况下,构建一个单一的、联合的多任务模型可以比为每个任务创建独立模型带来更好的性能。在数据分布不均匀的情况下尤其如此——例如,丰富的数据(如点击)与稀疏的数据(如购买、退货或人工评论)。在这种场景下,联合模型可以利用从丰富数据中学到的表示来改进稀疏数据上的预测,这种技术称为迁移学习。例如,研究表明,通过结合使用丰富点击日志数据的辅助任务,训练用于从稀疏调查数据预测用户评分的模型可以得到显著增强。
在本例中,我们使用 MovieLens 数据集开发一个多目标推荐系统。我们整合了隐式反馈(例如,观看电影)和显式反馈(例如,评分),以创建一个更强大、更有效的推荐模型。对于前者,我们预测“观看电影”,即用户是否观看过某部电影;对于后者,我们预测用户对某部电影给出的评分。
首先,我们导入必要的包。
!pip install -q keras-rs
import os
os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"`
import keras
import tensorflow as tf # Needed for the dataset
import tensorflow_datasets as tfds
import keras_rs
我们使用 MovieLens 数据集。数据加载和处理步骤与之前的教程类似,因此这里不再详细讨论。
# Ratings data with user and movie data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")
获取用户和电影计数,以便我们可以定义嵌入层。
users_count = (
ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
.reduce(tf.constant(0, tf.int32), tf.maximum)
.numpy()
)
movies_count = movies.cardinality().numpy()
我们的输入是 "user_id"
和 "movie_id"
。排序任务的标签是 "user_rating"
。"user_rating"
是一个介于 0 到 4 之间的整数。我们将其约束在 [0, 1]
范围内。
def preprocess_rating(x):
return (
{
"user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
},
(x["user_rating"] - 1.0) / 4.0,
)
shuffled_ratings = ratings.map(preprocess_rating).shuffle(
100_000, seed=42, reshuffle_each_iteration=False
)
将数据集分割成训练集和测试集。
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
我们以类似于基本召回和基本排序指南的方式构建模型。
对于召回任务(即预测用户是否观看过某部电影),我们计算相应的用户和电影嵌入的相似度,并使用交叉熵损失,其中正样本对被标记为 1,批次中的所有其他样本都被视为“负样本”。我们报告此任务的 Top-K 准确率。
对于排序任务(即给定一个用户-电影对,预测评分),我们将用户和电影嵌入连接起来,并将其传递给一个密集模块。我们在这里使用 MSE 损失,并报告均方根误差 (RMSE)。
最终损失是上述两个损失的加权组合,权重分别为 "retrieval_loss_wt"
和 "ranking_loss_wt"
。这些权重决定了模型将侧重于哪个任务。
class MultiTaskModel(keras.Model):
def __init__(
self,
num_users,
num_candidates,
embedding_dimension=32,
layer_sizes=(256, 128),
retrieval_loss_wt=1.0,
ranking_loss_wt=1.0,
**kwargs,
):
super().__init__(**kwargs)
# Our query tower, simply an embedding table.
self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
# Our candidate tower, simply an embedding table.
self.candidate_embedding = keras.layers.Embedding(
num_candidates, embedding_dimension
)
# Rating model.
self.rating_model = keras.Sequential(
[
keras.layers.Dense(layer_size, activation="relu")
for layer_size in layer_sizes
]
+ [keras.layers.Dense(1)]
)
# The layer that performs the retrieval.
self.retrieval = keras_rs.layers.BruteForceRetrieval(k=10, return_scores=False)
self.retrieval_loss_fn = keras.losses.CategoricalCrossentropy(
from_logits=True,
reduction="sum",
)
self.ranking_loss_fn = keras.losses.MeanSquaredError()
# Top-k accuracy for retrieval
self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy(
k=10, from_sorted_ids=True
)
# RMSE for ranking
self.rmse_metric = keras.metrics.RootMeanSquaredError()
# Attributes.
self.num_users = num_users
self.num_candidates = num_candidates
self.embedding_dimension = embedding_dimension
self.layer_sizes = layer_sizes
self.retrieval_loss_wt = retrieval_loss_wt
self.ranking_loss_wt = ranking_loss_wt
def build(self, input_shape):
self.user_embedding.build(input_shape)
self.candidate_embedding.build(input_shape)
# In this case, the candidates are directly the movie embeddings.
# We take a shortcut and directly reuse the variable.
self.retrieval.candidate_embeddings = self.candidate_embedding.embeddings
self.retrieval.build(input_shape)
self.rating_model.build((None, 2 * self.embedding_dimension))
super().build(input_shape)
def call(self, inputs, training=False):
# Unpack inputs. Note that we have the if condition throughout this
# `call()` method so that we can do a `.predict()` for the retrieval
# task.
user_id = inputs["user_id"]
if "movie_id" in inputs:
movie_id = inputs["movie_id"]
result = {}
# Get user, movie embeddings.
user_embeddings = self.user_embedding(user_id)
result["user_embeddings"] = user_embeddings
if "movie_id" in inputs:
candidate_embeddings = self.candidate_embedding(movie_id)
result["candidate_embeddings"] = candidate_embeddings
# Pass both embeddings through the rating block of the model.
rating = self.rating_model(
keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1)
)
result["rating"] = rating
if not training:
# Skip the retrieval of top movies during training as the
# predictions are not used.
result["predictions"] = self.retrieval(user_embeddings)
return result
def compute_loss(self, x, y, y_pred, sample_weight, training=True):
user_embeddings = y_pred["user_embeddings"]
candidate_embeddings = y_pred["candidate_embeddings"]
# 1. Retrieval
# Compute the affinity score by multiplying the two embeddings.
scores = keras.ops.matmul(
user_embeddings,
keras.ops.transpose(candidate_embeddings),
)
# Retrieval labels: One-hot vectors
num_users = keras.ops.shape(user_embeddings)[0]
num_candidates = keras.ops.shape(candidate_embeddings)[0]
retrieval_labels = keras.ops.eye(num_users, num_candidates)
# Retrieval loss
retrieval_loss = self.retrieval_loss_fn(retrieval_labels, scores, sample_weight)
# 2. Ranking
ratings = y
pred_rating = y_pred["rating"]
# Ranking labels are just ratings.
ranking_labels = keras.ops.expand_dims(ratings, -1)
# Ranking loss
ranking_loss = self.ranking_loss_fn(ranking_labels, pred_rating, sample_weight)
# Total loss is a weighted combination of the two losses.
total_loss = (
self.retrieval_loss_wt * retrieval_loss
+ self.ranking_loss_wt * ranking_loss
)
return total_loss
def compute_metrics(self, x, y, y_pred, sample_weight=None):
# RMSE can be computed irrespective of whether we are
# training/evaluating.
self.rmse_metric.update_state(
y,
y_pred["rating"],
sample_weight=sample_weight,
)
if "predictions" in y_pred:
# We are evaluating or predicting. Update `top_k_metric`.
movie_ids = x["movie_id"]
predictions = y_pred["predictions"]
# For `top_k_metric`, which is a `SparseTopKCategoricalAccuracy`, we
# only take top rated movies, and we put a weight of 0 for the rest.
rating_weight = keras.ops.cast(keras.ops.greater(y, 0.9), "float32")
sample_weight = (
rating_weight
if sample_weight is None
else keras.ops.multiply(rating_weight, sample_weight)
)
self.top_k_metric.update_state(
movie_ids, predictions, sample_weight=sample_weight
)
return self.get_metrics_result()
else:
# We are training. `top_k_metric` is not updated and is zero, so
# don't report it.
result = self.get_metrics_result()
result.pop(self.top_k_metric.name)
return result
我们将在这里训练三个不同的模型。通过传递正确的损失权重可以轻松实现这一点
# Rating-specialised model
model = MultiTaskModel(
num_users=users_count + 1,
num_candidates=movies_count + 1,
ranking_loss_wt=1.0,
retrieval_loss_wt=0.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)
model.evaluate(test_ratings)
# Retrieval-specialised model
model = MultiTaskModel(
num_users=users_count + 1,
num_candidates=movies_count + 1,
ranking_loss_wt=0.0,
retrieval_loss_wt=1.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)
model.evaluate(test_ratings)
# Multi-task model
model = MultiTaskModel(
num_users=users_count + 1,
num_candidates=movies_count + 1,
ranking_loss_wt=1.0,
retrieval_loss_wt=1.0,
)
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
model.fit(train_ratings, epochs=5)
model.evaluate(test_ratings)
Epoch 1/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 5s 13ms/step - loss: 0.1089 - root_mean_squared_error: 0.3242
Epoch 2/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0777 - root_mean_squared_error: 0.2788
Epoch 3/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0764 - root_mean_squared_error: 0.2763
Epoch 4/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 0.0742 - root_mean_squared_error: 0.2724
Epoch 5/5
20/20 ━━━━━━━━━━━━━━━━━━━━ 3s 28ms/step - loss: 0.0716 - root_mean_squared_error: 0.2675 - sparse_top_k_categorical_accuracy: 0.0063
Epoch 1/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - loss: 6855.5034 - root_mean_squared_error: 0.6792
Epoch 2/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 6523.5024 - root_mean_squared_error: 0.6524
Epoch 3/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6376.9727 - root_mean_squared_error: 0.6512
Epoch 4/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6288.7183 - root_mean_squared_error: 0.6527
Epoch 5/5
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 6551.5796 - root_mean_squared_error: 0.6573 - sparse_top_k_categorical_accuracy: 0.0197
Epoch 1/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 2s 11ms/step - loss: 6860.5400 - root_mean_squared_error: 0.3157
Epoch 2/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 6520.5342 - root_mean_squared_error: 0.2598
Epoch 3/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 2ms/step - loss: 6376.9668 - root_mean_squared_error: 0.2528
Epoch 4/5
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step - loss: 6291.7310 - root_mean_squared_error: 0.2502
Epoch 5/5
20/20 ━━━━━━━━━━━━━━━━━━━━ 1s 28ms/step - loss: 6552.1499 - root_mean_squared_error: 0.2483 - sparse_top_k_categorical_accuracy: 0.0178
[6554.578125, 0.01855670101940632, 0.25010260939598083]
让我们绘制一个指标表格并记录我们的观察结果
模型 | Top-K 准确率 (↑) | RMSE (↓) |
---|---|---|
评分专用 | 0.005 | 0.26 |
召回专用 | 0.020 | 0.78 |
多任务 | 0.022 | 0.25 |
正如预期的那样,评分专用模型的 RMSE 很好,但 Top-K 准确率较差。对于召回专用模型,情况则相反。
对于多任务模型,我们注意到该模型在两个任务上都表现良好(甚至略优于两个专用模型)。总的来说,我们可以期待多任务学习带来更好的结果,特别是在一个任务数据来源丰富而另一个任务在稀疏数据上训练的情况下。
现在,让我们进行预测!我们将首先进行召回,然后对于召回的电影列表,我们将使用同一个模型预测评分。
movie_id_to_movie_title = {
int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
}
movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset.
user_id = 5
retrieved_movie_ids = model.predict(
{
"user_id": keras.ops.array([user_id]),
}
)
retrieved_movie_ids = keras.ops.convert_to_numpy(retrieved_movie_ids["predictions"][0])
retrieved_movies = [movie_id_to_movie_title[x] for x in retrieved_movie_ids]
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 84ms/step
对于这些召回的电影,我们现在可以获得相应的评分。
pred_ratings = model.predict(
{
"user_id": keras.ops.array([user_id] * len(retrieved_movie_ids)),
"movie_id": keras.ops.array(retrieved_movie_ids),
}
)["rating"]
pred_ratings = keras.ops.convert_to_numpy(keras.ops.squeeze(pred_ratings, axis=1))
for movie_id, prediction in zip(retrieved_movie_ids, pred_ratings):
print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 418ms/step
b'Blob, The (1958)': 1.54
b'Little Rascals, The (1994)': 1.83
b'Jaws 3-D (1983)': 2.01
b'Black Beauty (1994)': 2.23
b'Burnt Offerings (1976)': 2.00
b'Mighty Morphin Power Rangers: The Movie (1995)': 2.11
b'Beverly Hillbillies, The (1993)': 2.12
b'Flintstones, The (1994)': 2.42
b'Heavy Metal (1981)': 2.67
b'Lassie (1994)': 2.02