Keras 3 API 文档 / 实用工具 / Keras 配置实用工具

Keras 配置实用工具

[源代码]

version 函数

keras.version()

[源代码]

clear_session 函数

keras.utils.clear_session(free_memory=True)

重置 Keras 生成的所有状态。

Keras 管理着一个全局状态,它用于实现函数式模型构建 API 并使自动生成的层名唯一。

如果你在循环中创建了许多模型,这个全局状态会随着时间消耗越来越多的内存,你可能需要清除它。调用 clear_session() 会释放全局状态:这有助于避免旧模型和层造成的混乱,尤其是在内存有限的情况下。

参数

  • free_memory:是否调用 Python 垃圾回收。通常建议调用它以确保已删除对象所使用的内存立即释放。但是,它可能需要几秒钟才能执行,因此在短循环中使用 clear_session() 时,你可能希望跳过它。

示例 1:在循环中创建模型时调用 clear_session()

for _ in range(100):
  # Without `clear_session()`, each iteration of this loop will
  # slightly increase the size of the global state managed by Keras
  model = keras.Sequential([
      keras.layers.Dense(10) for _ in range(10)])

for _ in range(100):
  # With `clear_session()` called at the beginning,
  # Keras starts with a blank state at each iteration
  # and memory consumption is constant over time.
  keras.backend.clear_session()
  model = keras.Sequential([
      keras.layers.Dense(10) for _ in range(10)])

示例 2:重置层名生成计数器

>>> layers = [keras.layers.Dense(10) for _ in range(10)]
>>> new_layer = keras.layers.Dense(10)
>>> print(new_layer.name)
dense_10
>>> keras.backend.clear_session()
>>> new_layer = keras.layers.Dense(10)
>>> print(new_layer.name)
dense

[源代码]

enable_traceback_filtering 函数

keras.config.enable_traceback_filtering()

开启回溯过滤。

原始的 Keras 回溯(也称为堆栈跟踪)包含许多内部帧,这可能难以阅读,同时对最终用户而言没有可操作性。默认情况下,Keras 会过滤它引发的大多数异常中的内部帧,以使回溯简短、可读,并专注于对你(你的代码)有可操作性的内容。

另请参阅 keras.config.disable_traceback_filtering()keras.config.is_traceback_filtering_enabled()

如果你之前通过 keras.config.disable_traceback_filtering() 禁用了回溯过滤,你可以通过 keras.config.enable_traceback_filtering() 重新启用它。


[源代码]

disable_traceback_filtering 函数

keras.config.disable_traceback_filtering()

关闭回溯过滤。

原始的 Keras 回溯(也称为堆栈跟踪)包含许多内部帧,这可能难以阅读,同时对最终用户而言没有可操作性。默认情况下,Keras 会过滤它引发的大多数异常中的内部帧,以使回溯简短、可读,并专注于对你(你的代码)有可操作性的内容。

另请参阅 keras.config.enable_traceback_filtering()keras.config.is_traceback_filtering_enabled()

如果你之前通过 keras.config.disable_traceback_filtering() 禁用了回溯过滤,你可以通过 keras.config.enable_traceback_filtering() 重新启用它。


[源代码]

is_traceback_filtering_enabled 函数

keras.config.is_traceback_filtering_enabled()

检查回溯过滤是否已启用。

原始的 Keras 回溯(也称为堆栈跟踪)包含许多内部帧,这可能难以阅读,同时对最终用户而言没有可操作性。默认情况下,Keras 会过滤它引发的大多数异常中的内部帧,以使回溯简短、可读,并专注于对你(你的代码)有可操作性的内容。

另请参阅 keras.config.enable_traceback_filtering()keras.config.disable_traceback_filtering()

如果你之前通过 keras.config.disable_traceback_filtering() 禁用了回溯过滤,你可以通过 keras.config.enable_traceback_filtering() 重新启用它。

返回

布尔值,如果回溯过滤已启用则为 True,否则为 False


[源代码]

enable_interactive_logging 函数

keras.config.enable_interactive_logging()

开启交互式日志记录。

当交互式日志记录启用时,Keras 通过标准输出显示日志。这在使用 Keras 的交互式环境(如 shell 或笔记本)中提供了最佳体验。


[源代码]

disable_interactive_logging 函数

keras.config.disable_interactive_logging()

关闭交互式日志记录。

当交互式日志记录禁用时,Keras 将日志发送到 absl.logging。这在使用 Keras 的非交互式方式(例如在服务器上运行训练或推理作业)时是最佳选择。


[源代码]

is_interactive_logging_enabled 函数

keras.config.is_interactive_logging_enabled()

检查交互式日志记录是否已启用。

要在将日志写入标准输出和 absl.logging 之间切换,你可以使用 keras.config.enable_interactive_logging()keras.config.disable_interactive_logging()

返回

布尔值,如果交互式日志记录已启用则为 True,否则为 False


[源代码]

enable_unsafe_deserialization 函数

keras.config.enable_unsafe_deserialization()

全局禁用安全模式,允许反序列化 lambda。


[源代码]

floatx 函数

keras.config.floatx()

以字符串形式返回默认浮点类型。

例如 'bfloat16', 'float16', 'float32', 'float64'

返回

字符串,当前的默认浮点类型。

示例

>>> keras.config.floatx()
'float32'

[源代码]

set_floatx 函数

keras.config.set_floatx(value)

设置默认浮点数据类型。

注意:不建议将其设置为 "float16" 进行训练,因为这很可能导致数值稳定性问题。相反,混合精度利用 float16float32 的混合。它可以通过调用 keras.mixed_precision.set_dtype_policy('mixed_float16') 来配置。

参数

  • value:字符串;'bfloat16''float16''float32''float64'

示例

>>> keras.config.floatx()
'float32'
>>> keras.config.set_floatx('float64')
>>> keras.config.floatx()
'float64'
>>> # Set it back to float32
>>> keras.config.set_floatx('float32')

引发

  • ValueError:在值无效的情况下。

[源代码]

image_data_format 函数

keras.config.image_data_format()

返回默认图像数据格式约定。

返回

一个字符串,要么是 'channels_first',要么是 'channels_last'

示例

>>> keras.config.image_data_format()
'channels_last'

[源代码]

set_image_data_format 函数

keras.config.set_image_data_format(data_format)

设置图像数据格式约定的值。

参数

  • data_format:字符串。'channels_first''channels_last'

示例

>>> keras.config.image_data_format()
'channels_last'
>>> keras.config.set_image_data_format('channels_first')
>>> keras.config.image_data_format()
'channels_first'
>>> # Set it back to `'channels_last'`
>>> keras.config.set_image_data_format('channels_last')

[源代码]

epsilon 函数

keras.config.epsilon()

返回在数值表达式中使用的模糊因子的值。

返回

一个浮点数。

示例

>>> keras.config.epsilon()
1e-07

[源代码]

set_epsilon 函数

keras.config.set_epsilon(value)

设置在数值表达式中使用的模糊因子的值。

参数

  • value:浮点数。epsilon 的新值。

示例

>>> keras.config.epsilon()
1e-07
>>> keras.config.set_epsilon(1e-5)
>>> keras.config.epsilon()
1e-05
>>> # Set it back to the default value.
>>> keras.config.set_epsilon(1e-7)

[源代码]

backend 函数

keras.config.backend()

用于确定当前后端的公共方法。

返回

字符串,Keras 当前使用的后端名称。"tensorflow""torch""jax" 之一。

示例

>>> keras.config.backend()
'tensorflow'