TopPSampler

[源代码]

TopPSampler

keras_hub.samplers.TopPSampler(p=0.1, k=None, seed=None, **kwargs)

Top-P 采样器类。

此采样器实现了 Top-P 搜索算法。Top-P 搜索从输出概率总和大于 p 的最小子集中选择标记。换句话说,Top-P 将首先按可能性对标记预测进行排序,并忽略所选标记的累积概率超过 p 之后的所有标记,然后从剩余标记中选择一个标记。

参数

  • p: 浮点数,Top-P 的 p 值。
  • k: 整型。如果设置,此参数定义了一个启发式“Top-K”截止点,它在“Top-P”采样之前应用。所有不在 Top k 中的 logits 都将被丢弃,剩余的 logits 将被排序以找到 p 的截止点。设置此参数可以通过减少要排序的标记数量来显著加快采样速度。默认为 None
  • seed: int。随机种子。默认为 None

调用参数

{{call_args}}

示例

causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

# Pass by name to compile.
causal_lm.compile(sampler="top_p")
causal_lm.generate(["Keras is a"])

# Pass by object to compile.
sampler = keras_hub.samplers.TopPSampler(p=0.1, k=1_000)
causal_lm.compile(sampler=sampler)
causal_lm.generate(["Keras is a"])