KerasRS / API 文档 / 检索层 / HardNegativeMining 层

HardNegativeMining 层

[源代码]

HardNegativeMining

keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)

过滤 logits 和标签以返回难负样本。

输出将包含请求数量的难负样本以及正候选样本的 logits 和标签。

参数

  • num_hard_negatives: 要返回的难负样本数量。
  • **kwargs: 传递给基类的参数。

示例

# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(
    num_hard_negatives=10
)

# This will retrieve the top 10 negative candidates plus the positive
# candidate from `labels` for each row.
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)

[源代码]

call 方法

HardNegativeMining.call(logits: Any, labels: Any)

通过按查询的难负样本挖掘来过滤 logits 和标签。

结果将包含 num_hard_negatives 个负样本以及正候选样本的 logits 和标签。

参数

  • logits: Logits 张量,通常为 [batch_size, num_candidates],但可以有更多维度或为 1D,如 [num_candidates]
  • labels: 独热标签张量,必须与 logits 的形状相同。

返回

一个包含两个张量的元组,其最后一个维度由 num_candidates 替换为 num_hard_negatives + 1

  • logits: [..., num_hard_negatives + 1] 的 logits 张量。
  • labels: [..., num_hard_negatives + 1] 的独热标签张量。