Retrieval
类keras_rs.layers.Retrieval(k: int = 10, return_scores: bool = True, **kwargs: Any)
检索基抽象类。
此层为所有检索层提供了一个通用接口。为了实现自定义检索层,应继承此抽象类。
参数
True
时,此层返回一个包含最高分数和最高标识符的元组。当为 False
时,此层返回一个包含最高标识符的单个张量。call
方法Retrieval.call(inputs: Any)
返回作为输入传入的查询的最高候选。
参数
返回
如果 returns_scores
为 True,则返回一个包含最高分数和最高标识符的元组,否则返回一个包含最高标识符的张量。
update_candidates
方法Retrieval.update_candidates(
candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)
更新候选集,并可选地更新其候选 ID。
参数
None
,则返回候选的索引。compute_score
方法Retrieval.compute_score(query_embedding: Any, candidate_embedding: Any)
计算查询和候选的标准点积分数。
参数
返回
查询和候选的点积。