KerasRS / API 文档 / 嵌入层 / DistributedEmbedding 层

DistributedEmbedding 层

[源代码]

DistributedEmbedding

keras_rs.layers.DistributedEmbedding(
    feature_configs: Union[
        keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
        tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
        Sequence[
            Union[
                keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
                tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
                ForwardRef("Nested[T]"),
            ]
        ],
        Mapping[
            str,
            Union[
                keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
                tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
                ForwardRef("Nested[T]"),
            ],
        ],
    ],
    table_stacking: Union[str, Sequence[str], Sequence[Sequence[str]]] = "auto",
    **kwargs: Any
)

DistributedEmbedding,一个用于加速大型嵌入查找的层。


注意:DistributedEmbedding 处于预览阶段。


DistributedEmbedding 是一个针对带 SparseCore 的 TPU 芯片优化的层,可以显著提高嵌入查找和嵌入训练的速度。它通过将多个查找组合成一个调用,并将嵌入表分片到可用的芯片上来实现。请注意,只有当嵌入表足够大以至于需要分片(因为它们不适合单个芯片)时,才能看到性能提升。更多详细信息请参阅下面的“放置”部分。

在其他硬件(GPU、CPU 和不带 SparseCore 的 TPU)上,DistributedEmbedding 提供相同的 API,没有任何特定的加速。除了通过 keras.distribution.set_distribution 设置的分发方案外,不应用任何特定的分发方案。

DistributedEmbedding 嵌入输入序列,并通过应用可配置的组合器函数将其缩减为单个嵌入。

配置

特征和表

DistributedEmbedding 嵌入层通过一组 keras_rs.layers.FeatureConfig 对象进行配置,这些对象本身引用 keras_rs.layers.TableConfig 对象。

  • TableConfig 定义了一个嵌入表,其中包含词汇大小、嵌入维度等参数,以及用于缩减的组合器和用于训练的优化器。
  • FeatureConfig 定义 DistributedEmbedding 将处理哪些输入特征以及使用哪个嵌入表。请注意,多个特征可以使用相同的嵌入表。
table1 = keras_rs.layers.TableConfig(
    name="table1",
    vocabulary_size=TABLE1_VOCABULARY_SIZE,
    embedding_dim=TABLE1_EMBEDDING_SIZE,
    placement="auto",
)
table2 = keras_rs.layers.TableConfig(
    name="table2",
    vocabulary_size=TABLE2_VOCABULARY_SIZE,
    embedding_dim=TABLE2_EMBEDDING_SIZE,
    placement="auto",
)

feature1 = keras_rs.layers.FeatureConfig(
    name="feature1",
    table=table1,
    input_shape=(PER_REPLICA_BATCH_SIZE,),
    output_shape=(PER_REPLICA_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
)
feature2 = keras_rs.layers.FeatureConfig(
    name="feature2",
    table=table2,
    input_shape=(PER_REPLICA_BATCH_SIZE,),
    output_shape=(PER_REPLICA_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
)

feature_configs = {
    "feature1": feature1,
    "feature2": feature2,
}

embedding = keras_rs.layers.DistributedEmbedding(feature_configs)

优化器

DistributedEmbedding 中的每个嵌入表都使用自己的优化器进行训练,该优化器独立于通过 model.compile() 在模型上设置的优化器。

请注意,并非所有优化器都受支持。目前,所有后端和加速器都支持以下优化器:

此外,并非所有优化器参数都受支持(例如 SGDnesterov 选项)。当使用不受支持的优化器或不受支持的优化器参数时,会引发错误。

放置

DistributedEmbedding 中的每个嵌入表都可以放置在 SparseCore 芯片上,或放置在加速器的默认设备上(例如 TPU 上 Tensor Cores 的 HBM)。这由 keras_rs.layers.TableConfigplacement 属性控制。

  • "sparsecore" 放置表示该表应放置在 SparseCore 芯片上。如果选择此选项但没有 SparseCore 芯片,则会引发错误。
  • "default_device" 放置表示即使 SparseCore 可用,该表也不应放置在 SparseCore 上。相反,该表放置在模型通常所在的设备上,即 TPU 和 GPU 上的 HBM。在这种情况下,如果适用,该表使用通过 keras.distribution.set_distribution 设置的方案进行分发。在 GPU、CPU 和不带 SparseCore 的 TPU 上,这是唯一可用的放置方式,也是 "auto" 选择的方式。
  • "auto" 放置表示如果 SparseCore 可用,则使用 "sparsecore";否则使用 "default_device"。这是未指定时的默认值。

优化 TPU 性能

  • 需要分片的大型表应使用 "sparsecore" 放置。
  • 足够小的表应使用 "default_device",并且通常应通过使用 keras.distribution.DataParallel 分发选项在 TPU 上复制。

在带 SparseCore 的 TPU 上与 TensorFlow 一起使用

输入

除了 tf.TensorDistributedEmbedding 还接受 tf.RaggedTensortf.SparseTensor 作为嵌入查找的输入。不规则张量必须在索引为 1 的维度上不规则。请注意,如果传递权重,每个权重张量必须与该特定特征的输入属于同一类,并且对于不规则张量使用完全相同的不规则行长度,对于稀疏张量使用相同的索引。DistributedEmbedding 的所有输出都是密集张量。

设置

要在带 TensorFlow 的 TPU 上使用 DistributedEmbedding,必须使用 tf.distribute.TPUStrategyDistributedEmbedding 层必须在 TPUStrategy 下创建。

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)

with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)

在 Keras 模型中的用法

要使用 Keras 的 model.fit(),必须在 TPUStrategy 下编译模型。然后,可以直接调用 model.fit()model.evaluate()model.predict()。Keras 模型负责使用策略运行模型,并自动分发数据集。

with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
    model = create_model(embedding)
    model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")

model.fit(dataset, epochs=10)

直接调用

DistributedEmbedding 必须通过嵌套在 tf.function 中的 strategy.run 调用来调用。

@tf.function
def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
    def strategy_fn(st_fn_inputs, st_fn_weights):
        return embedding(st_fn_inputs, st_fn_weights)

    return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights)))

embedding_wrapper(my_inputs, my_weights)

使用数据集时,数据集必须是分布式。然后可以将迭代器传递给使用 strategy.runtf.function

dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def run_loop(iterator):
    def step(data):
        (inputs, weights), labels = data
        with tf.GradientTape() as tape:
            result = embedding(inputs, weights)
            loss = keras.losses.mean_squared_error(labels, result)
        tape.gradient(loss, embedding.trainable_variables)
        return result

    for _ in tf.range(4):
        result = strategy.run(step, args=(next(iterator),))

run_loop(iter(dataset))

在带 SparseCore 的 TPU 上与 JAX 一起使用

设置

要在带 JAX 的 TPU 上使用 DistributedEmbedding,必须创建并设置 Keras Distribution

distribution = keras.distribution.DataParallel(devices=jax.device("tpu"))
keras.distribution.set_distribution(distribution)

输入

对于 JAX,输入可以是密集张量,也可以是不规则(嵌套)NumPy 数组。要启用 jit_compile = True,必须显式调用输入上的 layer.preprocess(...),然后将预处理后的输出馈送到模型。有关详细信息,请参阅下一节关于预处理。

不规则输入数组必须在索引为 1 的维度上不规则。请注意,如果传递权重,每个权重张量必须与该特定特征的输入属于同一类,并且对于不规则张量使用完全相同的不规则行长度。DistributedEmbedding 的所有输出都是密集张量。

预处理

在 JAX 中,SparseCore 的使用需要专门格式化的数据,这取决于可用硬件的属性。这种数据重新格式化目前不支持即时编译,因此必须在将数据传递到模型之前应用。

预处理适用于密集或不规则 NumPy 数组,或可转换为密集或不规则 NumPy 数组的张量,例如 tf.RaggedTensor

添加预处理的一种简单方法是使用 python 生成器将函数附加到输入管道。

# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)

# Add preprocessing to a data input pipeline.
def preprocessed_dataset_generator(dataset):
    for (inputs, weights), labels in iter(dataset):
        yield embedding_layer.preprocess(
            inputs, weights, training=True
        ), labels

preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)

这个显式预处理阶段将输入和可选权重组合在一起,因此新数据可以直接传递到层或模型的 inputs 参数中。

注意:在多主机设置中进行数据并行时,数据需要正确地在主机之间分片。如果原始数据集类型为 tf.data.Dataset,则需要在应用预处理生成器之前手动分片。

# Manually shard the dataset across hosts.
train_dataset = distribution.distribute_dataset(train_dataset)
distribution.auto_shard_dataset = False  # Dataset is already sharded.

# Add a preprocessing stage to the distributed data input pipeline.
train_dataset = preprocessed_dataset_generator(train_dataset)

如果原始数据集不是 tf.data.Dataset,则必须已在主机之间预分片。

在 Keras 模型中的用法

一旦设置了全局分发并定义了输入预处理管道,模型训练就可以正常进行。例如:

# Construct, compile, and fit the model using the preprocessed data.
model = keras.Sequential(
  [
    embedding_layer,
    keras.layers.Dense(2),
    keras.layers.Dense(3),
    keras.layers.Dense(4),
  ]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(preprocessed_train_dataset, epochs=10)

直接调用

DistributedEmbedding 层也可以直接调用。当与 JIT 编译一起使用时,需要显式预处理。

# Call the layer directly.
activations = embedding_layer(my_inputs, my_weights)

# Call the layer with JIT compilation and explicitly preprocessed inputs.
embedding_layer_jit = jax.jit(embedding_layer)
preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
activations = embedding_layer_jit(preprocessed_inputs)

类似地,对于自定义训练循环,必须在将数据传递到 JIT 编译的训练步骤之前应用预处理。

# Create an optimizer and loss function.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)

def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
    y_pred, non_trainable_variables = model.stateless_call(
        trainable_variables, non_trainable_variables, x, training=True
    )
    loss = keras.losses.mean_squared_error(y, y_pred)
    return loss, non_trainable_variables

grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)

# Create a JIT-compiled training step.
@jax.jit
def train_step(state, x, y):
    (
      trainable_variables,
      non_trainable_variables,
      optimizer_variables,
    ) = state
    (loss, non_trainable_variables), grads = grad_fn(
        trainable_variables, non_trainable_variables, x, y
    )
    trainable_variables, optimizer_variables = optimizer.stateless_apply(
        optimizer_variables, grads, trainable_variables
    )
    return loss, (
        trainable_variables,
        non_trainable_variables,
        optimizer_variables,
    )

# Build optimizer variables.
optimizer.build(model.trainable_variables)

# Assemble the training state.
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables

# Training loop.
for (inputs, weights), labels in train_dataset:
    # Explicitly preprocess the data.
    preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
    loss, state = train_step(state, preprocessed_inputs, labels)

参数

  • feature_configskeras_rs.layers.FeatureConfig 的嵌套结构。
  • table_stacking:要使用的表堆叠。None 表示不堆叠表。"auto" 表示自动堆叠表。表名列表或表名列表的列表表示将内部列表中的表堆叠在一起。请注意,较旧的 TPU 不支持表堆叠,在这种情况下,默认值 "auto" 将解释为不堆叠表。
  • **kwargs:要传递给层基类的其他参数。

[源代码]

call 方法

DistributedEmbedding.call(
    inputs: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
    ],
    weights: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
        NoneType,
    ] = None,
    training: bool = False,
)

在嵌入表中查找特征并应用缩减。

参数

  • inputs:要嵌入和缩减的 2D 张量的嵌套结构。结构必须与构造期间传递的 feature_configs 相同。或者,可以包含已预处理的输入(参见 preprocess)。
  • weights:在缩减之前要应用的可选的 2D 权重张量的嵌套结构。如果存在,结构必须与 inputs 相同,并且形状必须匹配。
  • training:我们是训练还是评估模型。

返回

密集 2D 张量的嵌套结构,它们是来自所传递特征的缩减嵌入。结构与 inputs 相同。


[源代码]

preprocess 方法

DistributedEmbedding.preprocess(
    inputs: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
    ],
    weights: Union[
        Any,
        Sequence[Union[Any, ForwardRef("Nested[T]")]],
        Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
        NoneType,
    ] = None,
    training: bool = False,
)

预处理数据并重新格式化以供模型使用。

对于 JAX 后端,将输入数据转换为与 SparseCore 一起使用所需的依赖于硬件的格式。只有为了启用 jit_compile = True 才需要显式调用 preprocess

对于非 JAX 后端,预处理将输入和权重捆绑在一起,并按设备放置分隔输入。此步骤完全是可选的。

参数

  • inputs:不规则或密集的样本 ID 集。
  • weights:可选的不规则或密集的样本权重集。
  • training:如果为 true,将更新内部参数,例如预处理数据所需的缓冲区大小。

返回

可直接馈送到层的 inputs 参数中的预处理输入集。