Keras 3 API 文档 / 层 API / 核心层 / InputSpec 对象

InputSpec 对象

[源代码]

InputSpec

keras.InputSpec(
    dtype=None,
    shape=None,
    ndim=None,
    max_ndim=None,
    min_ndim=None,
    axes=None,
    allow_last_axis_squeeze=False,
    name=None,
    optional=False,
)

指定层每个输入的秩、数据类型和形状。

层可以(如果适用)公开一个 input_spec 属性:一个 InputSpec 实例,或一个 InputSpec 实例的嵌套结构(每个输入张量一个)。这些对象使层能够对 Layer.__call__ 的第一个参数运行输入兼容性检查,包括输入结构、输入秩、输入形状和输入数据类型。

形状中的 None 条目与任何维度兼容。

参数

  • dtype:输入的预期数据类型。
  • shape:形状元组,输入的预期形状(动态轴可以包含 None)。包含批次大小。
  • ndim:整数,输入的预期秩。
  • max_ndim:整数,输入的最大秩。
  • min_ndim:整数,输入的最小秩。
  • axes:将整数轴映射到特定维度值的字典。
  • allow_last_axis_squeeze:如果为 True,允许秩为 N+1 的输入,只要输入的最后一轴为 1;也允许秩为 N-1 的输入,只要规范的最后一轴为 1。
  • name:当将数据作为字典传递时,对应此输入的预期键。
  • optional:布尔值,表示输入是否可选。可选输入可以接受 None 值。

示例

class MyLayer(Layer):
    def __init__(self):
        super().__init__()
        # The layer will accept inputs with
        # shape (*, 28, 28) & (*, 28, 28, 1)
        # and raise an appropriate error message otherwise.
        self.input_spec = InputSpec(
            shape=(None, 28, 28, 1),
            allow_last_axis_squeeze=True)