Keras 2 API 文档 / 工具 / 模型绘图工具

模型绘图工具

[源代码]

plot_model 函数

tf_keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)

将 TF-Keras 模型转换为 dot 格式并保存到文件。

示例

input = tf.keras.Input(shape=(100,), dtype='int32', name='input')
x = tf.keras.layers.Embedding(
    output_dim=512, input_dim=10000, input_length=100)(input)
x = tf.keras.layers.LSTM(32)(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
x = tf.keras.layers.Dense(64, activation='relu')(x)
output = tf.keras.layers.Dense(1, activation='sigmoid', name='output')(x)
model = tf.keras.Model(inputs=[input], outputs=[output])
dot_img_file = '/tmp/model_1.png'
tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True)

参数

  • model: TF-Keras 模型实例
  • to_file: 绘图图像的文件名。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层的数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给 PyDot 的 rankdir 参数,一个字符串,指定绘图的格式:'TB' 创建垂直绘图;'LR' 创建水平绘图。
  • expand_nested: 是否将嵌套模型扩展为集群。
  • dpi: 每英寸点数。
  • layer_range: list 类型输入,包含两个 str 项目,分别是起始层名称和结束层名称(均包含在内),指示要为其生成绘图的层范围。它也接受正则表达式模式而不是精确名称。在这种情况下,起始谓词将是它匹配 layer_range[0] 的第一个元素,结束谓词将是它匹配 layer_range[1] 的最后一个元素。默认为 None,它考虑模型的所有层。请注意,您必须传递范围,以使生成的子图必须完整。
  • show_layer_activations: 显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable: 是否显示层是否可训练。层可训练时显示“T”,不可训练时显示“NT”。

引发异常

  • ImportError: 如果 graphviz 或 pydot 不可用。
  • ValueError: 如果在构建模型之前调用 plot_model

返回值

如果安装了 Jupyter,则为 Jupyter Notebook 图像对象。这使得能够在笔记本中内联显示模型图。


[源代码]

model_to_dot 函数

tf_keras.utils.model_to_dot(
    model,
    show_shapes=False,
    show_dtype=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
    subgraph=False,
    layer_range=None,
    show_layer_activations=False,
    show_trainable=False,
)

将 TF-Keras 模型转换为 dot 格式。

参数

  • model: TF-Keras 模型实例。
  • show_shapes: 是否显示形状信息。
  • show_dtype: 是否显示层的数据类型。
  • show_layer_names: 是否显示层名称。
  • rankdir: 传递给 PyDot 的 rankdir 参数,一个字符串,指定绘图的格式:'TB' 创建垂直绘图;'LR' 创建水平绘图。
  • expand_nested: 是否将嵌套模型扩展为集群。
  • dpi: 每英寸点数。
  • subgraph: 是否返回 pydot.Cluster 实例。
  • layer_range: list 类型输入,包含两个 str 项目,分别是起始层名称和结束层名称(均包含在内),指示要为其生成 pydot.Dot 的层范围。它也接受正则表达式模式而不是精确名称。在这种情况下,起始谓词将是它匹配 layer_range[0] 的第一个元素,结束谓词将是它匹配 layer_range[1] 的最后一个元素。默认为 None,它考虑模型的所有层。请注意,您必须传递范围,以使生成的子图必须完整。
  • show_layer_activations: 显示层激活(仅适用于具有 activation 属性的层)。
  • show_trainable: 是否显示层是否可训练。层可训练时显示“T”,不可训练时显示“NT”。

返回值

一个 pydot.Dot 实例,表示 TF-Keras 模型,或者如果 subgraph=True,则为一个 pydot.Cluster 实例,表示嵌套模型。

引发异常

  • ValueError: 如果在构建模型之前调用 model_to_dot
  • ImportError: 如果 pydot 不可用。