KerasRS / API 文档 / 检索层 / 检索层

检索层

[源代码]

Retrieval

keras_rs.layers.Retrieval(k: int = 10, return_scores: bool = True, **kwargs: Any)

检索基抽象类。

此层为所有检索层提供了一个通用接口。为了实现自定义检索层,应继承此抽象类。

参数

  • k: int。要检索的候选数量。
  • return_scores: bool。当为 True 时,此层返回一个包含最高分数和最高标识符的元组。当为 False 时,此层返回一个包含最高标识符的单个张量。

[源代码]

call 方法

Retrieval.call(inputs: Any)

返回作为输入传入的查询的最高候选。

参数

  • inputs: 要返回最高候选的查询。

返回

如果 returns_scores 为 True,则返回一个包含最高分数和最高标识符的元组,否则返回一个包含最高标识符的张量。


[源代码]

update_candidates 方法

Retrieval.update_candidates(
    candidate_embeddings: Any, candidate_ids: Optional[Any] = None
)

更新候选集,并可选地更新其候选 ID。

参数

  • candidate_embeddings: 候选嵌入。
  • candidate_ids: 候选的标识符。如果为 None,则返回候选的索引。

[源代码]

compute_score 方法

Retrieval.compute_score(query_embedding: Any, candidate_embedding: Any)

计算查询和候选的标准点积分数。

参数

  • query_embedding: 对应于要检索最高候选的查询的查询嵌入张量。
  • candidate_embedding: 候选嵌入张量。

返回

查询和候选的点积。