HardNegativeMining
类keras_rs.layers.HardNegativeMining(num_hard_negatives: int, **kwargs: Any)
过滤 logits 和标签以返回难负样本。
输出将包含请求数量的难负样本以及正候选样本的 logits 和标签。
参数
示例
# Create layer with the configured number of hard negatives to mine.
hard_negative_mining = keras_rs.layers.HardNegativeMining(
num_hard_negatives=10
)
# This will retrieve the top 10 negative candidates plus the positive
# candidate from `labels` for each row.
out_logits, out_labels = hard_negative_mining(in_logits, in_labels)
call
方法HardNegativeMining.call(logits: Any, labels: Any)
通过按查询的难负样本挖掘来过滤 logits 和标签。
结果将包含 num_hard_negatives
个负样本以及正候选样本的 logits 和标签。
参数
[batch_size, num_candidates]
,但可以有更多维度或为 1D,如 [num_candidates]
。logits
的形状相同。返回
一个包含两个张量的元组,其最后一个维度由 num_candidates
替换为 num_hard_negatives + 1
。
[..., num_hard_negatives + 1]
的 logits 张量。[..., num_hard_negatives + 1]
的独热标签张量。