JaxLayer

[source]

JaxLayer

keras.layers.JaxLayer(
    call_fn, init_fn=None, params=None, state=None, seed=None, **kwargs
)

Keras Layer,用于封装 JAX 模型。

当使用 JAX 作为 Keras 的后端时,此层允许在 Keras 中使用 JAX 组件。

模型函数

此层接受 JAX 模型,形式为一个函数 call_fn,该函数必须接受以下具有精确名称的参数

  • params: 模型的可训练参数。
  • state (可选): 模型的非可训练状态。如果模型没有非可训练状态,则可以省略。
  • rng (可选): 一个 jax.random.PRNGKey 实例。如果模型在训练或推理期间都不需要 RNG,则可以省略。
  • inputs: 模型的输入,一个 JAX 数组或一个数组的 PyTree
  • training (可选): 一个参数,用于指定我们处于训练模式还是推理模式,训练模式下会传递 True。如果模型在训练模式和推理模式下的行为相同,则可以省略。

inputs 参数是强制性的。模型的输入必须通过单个参数提供。如果 JAX 模型将多个输入作为单独的参数,则必须将它们组合成单个结构,例如在 tupledict 中。

模型权重初始化

模型的 paramsstate 的初始化可以由此层处理,在这种情况下,必须提供 init_fn 参数。 这允许使用正确的形状动态初始化模型。或者,如果形状已知,则可以使用 params 参数以及可选的 state 参数来创建已初始化的模型。

如果提供 init_fn 函数,则它必须接受以下具有精确名称的参数

  • rng: 一个 jax.random.PRNGKey 实例。
  • inputs: 一个 JAX 数组或一个数组的 PyTree,其中包含占位符值,以提供输入的形状。
  • training (可选): 一个参数,用于指定我们处于训练模式还是推理模式。True 始终传递给 init_fn。 无论 call_fn 是否具有 training 参数,都可以省略。

具有非可训练状态的模型

对于具有非可训练状态的 JAX 模型

  • call_fn 必须具有 state 参数
  • call_fn 必须返回一个 tuple,其中包含模型的输出和模型新的非可训练状态
  • init_fn 必须返回一个 tuple,其中包含模型的初始可训练参数和模型的初始非可训练状态。

此代码显示了具有非可训练状态的模型的 call_fninit_fn 签名的可能组合。在此示例中,模型在 call_fn 中具有 training 参数和 rng 参数。

def stateful_call(params, state, rng, inputs, training):
    outputs = ...
    new_state = ...
    return outputs, new_state

def stateful_init(rng, inputs):
    initial_params = ...
    initial_state = ...
    return initial_params, initial_state

不具有非可训练状态的模型

对于不具有非可训练状态的 JAX 模型

  • call_fn 必须没有 state 参数
  • call_fn 必须仅返回模型的输出
  • init_fn 必须仅返回模型的初始可训练参数。

此代码显示了不具有非可训练状态的模型的 call_fninit_fn 签名的可能组合。在此示例中,模型在 call_fn 中没有 training 参数,也没有 rng 参数。

def stateless_call(params, inputs):
    outputs = ...
    return outputs

def stateless_init(rng, inputs):
    initial_params = ...
    return initial_params

符合所需的签名

如果模型具有与 JaxLayer 所需签名不同的签名,则可以轻松编写一个包装器方法来调整参数。 此示例显示了一个模型,该模型具有作为单独参数的多个输入,在 dict 中期望多个 RNG,并且具有与 training 含义相反的 deterministic 参数。 为了符合要求,输入使用 tuple 组合在单个结构中,RNG 被拆分并用于填充预期的 dict,并且布尔标志被取反

def my_model_fn(params, rngs, input1, input2, deterministic):
    ...
    if not deterministic:
        dropout_rng = rngs["dropout"]
        keep = jax.random.bernoulli(dropout_rng, dropout_rate, x.shape)
        x = jax.numpy.where(keep, x / dropout_rate, 0)
        ...
    ...
    return outputs

def my_model_wrapper_fn(params, rng, inputs, training):
    input1, input2 = inputs
    rng1, rng2 = jax.random.split(rng)
    rngs = {"dropout": rng1, "preprocessing": rng2}
    deterministic = not training
    return my_model_fn(params, rngs, input1, input2, deterministic)

keras_layer = JaxLayer(my_model_wrapper_fn, params=initial_params)

与 Haiku 模块一起使用

JaxLayer 允许使用 Haiku 组件,形式为 haiku.Module。 这是通过按照 Haiku 模式转换模块,然后在 call_fn 参数中传递 module.apply,并在需要时在 init_fn 参数中传递 module.init 来实现的。

如果模型具有非可训练状态,则应使用 haiku.transform_with_state 进行转换。 如果模型没有非可训练状态,则应使用 haiku.transform 进行转换。 此外,可选地,如果模块在 "apply" 中不使用 RNG,则可以使用 haiku.without_apply_rng 进行转换。

以下示例显示了如何从 Haiku 模块创建 JaxLayer,该模块通过 hk.next_rng_key() 使用随机数生成器,并接受训练位置参数

class MyHaikuModule(hk.Module):
    def __call__(self, x, training):
        x = hk.Conv2D(32, (3, 3))(x)
        x = jax.nn.relu(x)
        x = hk.AvgPool((1, 2, 2, 1), (1, 2, 2, 1), "VALID")(x)
        x = hk.Flatten()(x)
        x = hk.Linear(200)(x)
        if training:
            x = hk.dropout(rng=hk.next_rng_key(), rate=0.3, x=x)
        x = jax.nn.relu(x)
        x = hk.Linear(10)(x)
        x = jax.nn.softmax(x)
        return x

def my_haiku_module_fn(inputs, training):
    module = MyHaikuModule()
    return module(inputs, training)

transformed_module = hk.transform(my_haiku_module_fn)

keras_layer = JaxLayer(
    call_fn=transformed_module.apply,
    init_fn=transformed_module.init,
)

参数

  • call_fn: 调用模型的函数。 有关它接受的参数列表及其返回的输出,请参见以上描述。 init_fn: 调用以初始化模型的函数。 有关它接受的参数列表及其返回的输出,请参见以上描述。 如果为 None,则必须提供 params 和/或 state
  • params: 一个 PyTree,包含模型的所有可训练参数。 这允许传递已训练的参数或控制初始化。 如果 paramsstate 均为 None,则在构建时调用 init_fn 以初始化模型的可训练参数。
  • state: 一个 PyTree,包含模型的所有非可训练状态。 这允许传递学习到的状态或控制初始化。 如果 paramsstate 均为 None,并且 call_fn 接受 state 参数,则在构建时调用 init_fn 以初始化模型的非可训练状态。
  • seed: 随机数生成器的种子。 可选。
  • dtype: 层的计算和权重的 dtype。 也可以是 keras.DTypePolicy。 可选。 默认为默认策略。