Keras 2 API 文档 / 层 API / 核心层 / 输入对象

输入对象

[源代码]

Input 函数

tf_keras.Input(
    shape=None,
    batch_size=None,
    name=None,
    dtype=None,
    sparse=None,
    tensor=None,
    ragged=None,
    type_spec=None,
    **kwargs
)

Input() 用于实例化一个 TF-Keras 张量。

TF-Keras 张量是一个类似符号张量的对象,我们通过添加一些属性来增强它,这样我们就可以仅通过知道模型的输入和输出就能构建一个 TF-Keras 模型。

例如,如果 abc 是 TF-Keras 张量,则可以执行以下操作:model = Model(input=[a, b], output=c)

参数

  • shape: 形状元组(整数),不包括批次大小。例如,shape=(32,) 表示预期输入将是 32 维向量的批次。该元组的元素可以是 None;'None' 元素表示形状未知的维度。
  • batch_size: 可选的静态批次大小(整数)。
  • name: 层的可选名称字符串。在模型中应唯一(不要重复使用相同的名称)。如果未提供,它将自动生成。
  • dtype: 输入期望的数据类型,以字符串表示(float32float64int32 等)。
  • sparse: 一个布尔值,指定要创建的占位符是否为稀疏的。'ragged' 和 'sparse' 中只能有一个为 True。请注意,如果 sparse 为 False,仍然可以将稀疏张量传递到输入中 - 它们将使用默认值 0 进行稠密化。
  • tensor: 可选的现有张量,用于包装到 Input 层中。如果设置,该层将使用此张量的 tf.TypeSpec 而不是创建一个新的占位符张量。
  • ragged: 一个布尔值,指定要创建的占位符是否为不规则的。'ragged' 和 'sparse' 中只能有一个为 True。在这种情况下,'shape' 参数中的 'None' 值表示不规则维度。有关 RaggedTensors 的更多信息,请参阅 本指南
  • type_spec: 用于创建输入占位符的 tf.TypeSpec 对象。提供此参数时,除 name 外的所有其他参数都必须为 None。
  • **kwargs: 已弃用的参数支持。支持 batch_shapebatch_input_shape

返回值

一个 tensor

示例

# this is a logistic regression in Keras
x = Input(shape=(32,))
y = Dense(16, activation='softmax')(x)
model = Model(x, y)

请注意,即使启用了渴望执行,Input 也会生成一个类似符号张量的对象(即占位符)。此类似符号张量的对象可以与以张量作为输入的较低级 TensorFlow 操作一起使用,如下所示

x = Input(shape=(32,))
y = tf.square(x)  # This op will be treated like a layer
model = Model(x, y)

(此行为不适用于诸如控制流之类的更高级别的 TensorFlow API,以及 tf.GradientTape 直接监视。)

但是,生成的模型不会跟踪用作 TensorFlow 操作输入的任何变量。所有变量的使用都必须发生在 TF-Keras 层内,以确保模型的权重会跟踪它们。

TF-Keras Input 还可以从任意 tf.TypeSpec 创建占位符,例如

x = Input(type_spec=tf.RaggedTensorSpec(shape=[None, None],
                                        dtype=tf.float32, ragged_rank=1))
y = x.values
model = Model(x, y)

传递任意 tf.TypeSpec 时,它必须表示整个批次的签名,而不仅仅是一个示例。

引发异常

  • ValueError: 如果同时提供了 sparseragged
  • ValueError: 如果同时提供了 shape 和(batch_input_shapebatch_shape)。
  • ValueError: 如果 shapetensortype_spec 均为 None。
  • ValueError: 如果在传递 type_spec 的同时,除了 type_spec 之外的参数不为 None。
  • ValueError: 如果提供了任何无法识别的参数。