KerasRS / API 文档 / 检索层 / BruteForceRetrieval 层

BruteForceRetrieval 层

[源代码]

BruteForceRetrieval

keras_rs.layers.BruteForceRetrieval(
    candidate_embeddings: Optional[Any] = None,
    candidate_ids: Optional[Any] = None,
    k: int = 10,
    return_scores: bool = True,
    **kwargs: Any
)

暴力顶 k 检索。

此层维护一组候选对象,并能够为给定查询精确检索顶 k 个候选对象。它通过计算所有候选对象的查询分数并提取排在前列的候选对象来实现。返回的顶 k 个候选对象按分数排序。

默认情况下,此层返回一个包含最高分数和最高标识符的元组,但可以配置为返回一个包含最高标识符的单个张量。

候选对象的标识符可以指定为张量。如果未提供,则使用的 ID 仅是候选对象索引。

请注意,此层的序列化不保留候选对象,仅保存 kreturn_scores 参数。在反序列化层之后,必须调用 update_candidates

参数

  • candidate_embeddings:候选嵌入。如果为 None,则必须在使用此层之前通过 update_candidates 提供候选对象。
  • candidate_ids:候选对象的标识符。如果为 None,则返回候选对象的索引。
  • k:要检索的候选对象数量。
  • return_scores:当为 True 时,此层返回一个包含最高分数和最高标识符的元组。当为 False 时,此层返回一个包含最高标识符的单个张量。
  • **kwargs: 传递给基类的参数。

示例

retrieval = keras_rs.layers.BruteForceRetrieval(k=100)

# At some later point, we update the candidates.
retrieval.update_candidates(candidate_embeddings, candidate_ids)

# We can then retrieve the top candidates for any number of queries.
# Scores are stored highest first. Scores correspond to ids in the same row.
tops_scores, top_ids = retrieval(query_embeddings)

[源代码]

call 方法

BruteForceRetrieval.call(inputs: Any)

返回作为输入传入的查询的最高候选对象。

参数

  • inputs:要返回最高候选对象的查询。

返回

如果 returns_scores 为 True,则返回一个包含最高分数和最高标识符的元组;否则,返回一个包含最高标识符的张量。


[源代码]

update_candidates 方法

BruteForceRetrieval.update_candidates(
    candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)

更新候选对象集,并可选地更新其候选 ID。

参数

  • candidate_embeddings:候选嵌入。
  • candidate_ids:候选对象的标识符。如果为 None,则返回候选对象的索引。