BruteForceRetrieval
类keras_rs.layers.BruteForceRetrieval(
candidate_embeddings: Optional[Any] = None,
candidate_ids: Optional[Any] = None,
k: int = 10,
return_scores: bool = True,
**kwargs: Any
)
暴力顶 k 检索。
此层维护一组候选对象,并能够为给定查询精确检索顶 k 个候选对象。它通过计算所有候选对象的查询分数并提取排在前列的候选对象来实现。返回的顶 k 个候选对象按分数排序。
默认情况下,此层返回一个包含最高分数和最高标识符的元组,但可以配置为返回一个包含最高标识符的单个张量。
候选对象的标识符可以指定为张量。如果未提供,则使用的 ID 仅是候选对象索引。
请注意,此层的序列化不保留候选对象,仅保存 k
和 return_scores
参数。在反序列化层之后,必须调用 update_candidates
。
参数
None
,则必须在使用此层之前通过 update_candidates
提供候选对象。None
,则返回候选对象的索引。True
时,此层返回一个包含最高分数和最高标识符的元组。当为 False
时,此层返回一个包含最高标识符的单个张量。示例
retrieval = keras_rs.layers.BruteForceRetrieval(k=100)
# At some later point, we update the candidates.
retrieval.update_candidates(candidate_embeddings, candidate_ids)
# We can then retrieve the top candidates for any number of queries.
# Scores are stored highest first. Scores correspond to ids in the same row.
tops_scores, top_ids = retrieval(query_embeddings)
call
方法BruteForceRetrieval.call(inputs: Any)
返回作为输入传入的查询的最高候选对象。
参数
返回
如果 returns_scores
为 True,则返回一个包含最高分数和最高标识符的元组;否则,返回一个包含最高标识符的张量。
update_candidates
方法BruteForceRetrieval.update_candidates(
candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)
更新候选对象集,并可选地更新其候选 ID。
参数
None
,则返回候选对象的索引。