DistributedEmbedding
类keras_rs.layers.DistributedEmbedding(
feature_configs: Union[
keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
Sequence[
Union[
keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
ForwardRef("Nested[T]"),
]
],
Mapping[
str,
Union[
keras_rs.src.layers.embedding.distributed_embedding_config.FeatureConfig,
tensorflow.python.tpu.tpu_embedding_v2_utils.FeatureConfig,
ForwardRef("Nested[T]"),
],
],
],
table_stacking: Union[str, Sequence[str], Sequence[Sequence[str]]] = "auto",
**kwargs: Any
)
DistributedEmbedding,一个用于加速大型嵌入查找的层。
DistributedEmbedding
处于预览阶段。DistributedEmbedding
是一个针对带 SparseCore 的 TPU 芯片优化的层,可以显著提高嵌入查找和嵌入训练的速度。它通过将多个查找组合成一个调用,并将嵌入表分片到可用的芯片上来实现。请注意,只有当嵌入表足够大以至于需要分片(因为它们不适合单个芯片)时,才能看到性能提升。更多详细信息请参阅下面的“放置”部分。
在其他硬件(GPU、CPU 和不带 SparseCore 的 TPU)上,DistributedEmbedding
提供相同的 API,没有任何特定的加速。除了通过 keras.distribution.set_distribution
设置的分发方案外,不应用任何特定的分发方案。
DistributedEmbedding
嵌入输入序列,并通过应用可配置的组合器函数将其缩减为单个嵌入。
DistributedEmbedding
嵌入层通过一组 keras_rs.layers.FeatureConfig
对象进行配置,这些对象本身引用 keras_rs.layers.TableConfig
对象。
TableConfig
定义了一个嵌入表,其中包含词汇大小、嵌入维度等参数,以及用于缩减的组合器和用于训练的优化器。FeatureConfig
定义 DistributedEmbedding
将处理哪些输入特征以及使用哪个嵌入表。请注意,多个特征可以使用相同的嵌入表。table1 = keras_rs.layers.TableConfig(
name="table1",
vocabulary_size=TABLE1_VOCABULARY_SIZE,
embedding_dim=TABLE1_EMBEDDING_SIZE,
placement="auto",
)
table2 = keras_rs.layers.TableConfig(
name="table2",
vocabulary_size=TABLE2_VOCABULARY_SIZE,
embedding_dim=TABLE2_EMBEDDING_SIZE,
placement="auto",
)
feature1 = keras_rs.layers.FeatureConfig(
name="feature1",
table=table1,
input_shape=(PER_REPLICA_BATCH_SIZE,),
output_shape=(PER_REPLICA_BATCH_SIZE, TABLE1_EMBEDDING_SIZE),
)
feature2 = keras_rs.layers.FeatureConfig(
name="feature2",
table=table2,
input_shape=(PER_REPLICA_BATCH_SIZE,),
output_shape=(PER_REPLICA_BATCH_SIZE, TABLE2_EMBEDDING_SIZE),
)
feature_configs = {
"feature1": feature1,
"feature2": feature2,
}
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
DistributedEmbedding
中的每个嵌入表都使用自己的优化器进行训练,该优化器独立于通过 model.compile()
在模型上设置的优化器。
请注意,并非所有优化器都受支持。目前,所有后端和加速器都支持以下优化器:
此外,并非所有优化器参数都受支持(例如 SGD
的 nesterov
选项)。当使用不受支持的优化器或不受支持的优化器参数时,会引发错误。
DistributedEmbedding
中的每个嵌入表都可以放置在 SparseCore 芯片上,或放置在加速器的默认设备上(例如 TPU 上 Tensor Cores 的 HBM)。这由 keras_rs.layers.TableConfig
的 placement
属性控制。
"sparsecore"
放置表示该表应放置在 SparseCore 芯片上。如果选择此选项但没有 SparseCore 芯片,则会引发错误。"default_device"
放置表示即使 SparseCore 可用,该表也不应放置在 SparseCore 上。相反,该表放置在模型通常所在的设备上,即 TPU 和 GPU 上的 HBM。在这种情况下,如果适用,该表使用通过 keras.distribution.set_distribution
设置的方案进行分发。在 GPU、CPU 和不带 SparseCore 的 TPU 上,这是唯一可用的放置方式,也是 "auto"
选择的方式。"auto"
放置表示如果 SparseCore 可用,则使用 "sparsecore"
;否则使用 "default_device"
。这是未指定时的默认值。优化 TPU 性能
"sparsecore"
放置。"default_device"
,并且通常应通过使用 keras.distribution.DataParallel
分发选项在 TPU 上复制。除了 tf.Tensor
,DistributedEmbedding
还接受 tf.RaggedTensor
和 tf.SparseTensor
作为嵌入查找的输入。不规则张量必须在索引为 1 的维度上不规则。请注意,如果传递权重,每个权重张量必须与该特定特征的输入属于同一类,并且对于不规则张量使用完全相同的不规则行长度,对于稀疏张量使用相同的索引。DistributedEmbedding
的所有输出都是密集张量。
要在带 TensorFlow 的 TPU 上使用 DistributedEmbedding
,必须使用 tf.distribute.TPUStrategy
。DistributedEmbedding
层必须在 TPUStrategy
下创建。
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
topology, num_replicas=resolver.get_tpu_system_metadata().num_cores
)
strategy = tf.distribute.TPUStrategy(
resolver, experimental_device_assignment=device_assignment
)
with strategy.scope():
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
要使用 Keras 的 model.fit()
,必须在 TPUStrategy
下编译模型。然后,可以直接调用 model.fit()
、model.evaluate()
或 model.predict()
。Keras 模型负责使用策略运行模型,并自动分发数据集。
with strategy.scope():
embedding = keras_rs.layers.DistributedEmbedding(feature_configs)
model = create_model(embedding)
model.compile(loss=keras.losses.MeanSquaredError(), optimizer="adam")
model.fit(dataset, epochs=10)
DistributedEmbedding
必须通过嵌套在 tf.function
中的 strategy.run
调用来调用。
@tf.function
def embedding_wrapper(tf_fn_inputs, tf_fn_weights=None):
def strategy_fn(st_fn_inputs, st_fn_weights):
return embedding(st_fn_inputs, st_fn_weights)
return strategy.run(strategy_fn, args=(tf_fn_inputs, tf_fn_weights)))
embedding_wrapper(my_inputs, my_weights)
使用数据集时,数据集必须是分布式。然后可以将迭代器传递给使用 strategy.run
的 tf.function
。
dataset = strategy.experimental_distribute_dataset(dataset)
@tf.function
def run_loop(iterator):
def step(data):
(inputs, weights), labels = data
with tf.GradientTape() as tape:
result = embedding(inputs, weights)
loss = keras.losses.mean_squared_error(labels, result)
tape.gradient(loss, embedding.trainable_variables)
return result
for _ in tf.range(4):
result = strategy.run(step, args=(next(iterator),))
run_loop(iter(dataset))
要在带 JAX 的 TPU 上使用 DistributedEmbedding
,必须创建并设置 Keras Distribution
。
distribution = keras.distribution.DataParallel(devices=jax.device("tpu"))
keras.distribution.set_distribution(distribution)
对于 JAX,输入可以是密集张量,也可以是不规则(嵌套)NumPy 数组。要启用 jit_compile = True
,必须显式调用输入上的 layer.preprocess(...)
,然后将预处理后的输出馈送到模型。有关详细信息,请参阅下一节关于预处理。
不规则输入数组必须在索引为 1 的维度上不规则。请注意,如果传递权重,每个权重张量必须与该特定特征的输入属于同一类,并且对于不规则张量使用完全相同的不规则行长度。DistributedEmbedding
的所有输出都是密集张量。
在 JAX 中,SparseCore 的使用需要专门格式化的数据,这取决于可用硬件的属性。这种数据重新格式化目前不支持即时编译,因此必须在将数据传递到模型之前应用。
预处理适用于密集或不规则 NumPy 数组,或可转换为密集或不规则 NumPy 数组的张量,例如 tf.RaggedTensor
。
添加预处理的一种简单方法是使用 python 生成器将函数附加到输入管道。
# Create the embedding layer.
embedding_layer = DistributedEmbedding(feature_configs)
# Add preprocessing to a data input pipeline.
def preprocessed_dataset_generator(dataset):
for (inputs, weights), labels in iter(dataset):
yield embedding_layer.preprocess(
inputs, weights, training=True
), labels
preprocessed_train_dataset = preprocessed_dataset_generator(train_dataset)
这个显式预处理阶段将输入和可选权重组合在一起,因此新数据可以直接传递到层或模型的 inputs
参数中。
注意:在多主机设置中进行数据并行时,数据需要正确地在主机之间分片。如果原始数据集类型为 tf.data.Dataset
,则需要在应用预处理生成器之前手动分片。
# Manually shard the dataset across hosts.
train_dataset = distribution.distribute_dataset(train_dataset)
distribution.auto_shard_dataset = False # Dataset is already sharded.
# Add a preprocessing stage to the distributed data input pipeline.
train_dataset = preprocessed_dataset_generator(train_dataset)
如果原始数据集不是 tf.data.Dataset
,则必须已在主机之间预分片。
一旦设置了全局分发并定义了输入预处理管道,模型训练就可以正常进行。例如:
# Construct, compile, and fit the model using the preprocessed data.
model = keras.Sequential(
[
embedding_layer,
keras.layers.Dense(2),
keras.layers.Dense(3),
keras.layers.Dense(4),
]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(preprocessed_train_dataset, epochs=10)
DistributedEmbedding
层也可以直接调用。当与 JIT 编译一起使用时,需要显式预处理。
# Call the layer directly.
activations = embedding_layer(my_inputs, my_weights)
# Call the layer with JIT compilation and explicitly preprocessed inputs.
embedding_layer_jit = jax.jit(embedding_layer)
preprocessed_inputs = embedding_layer.preprocess(my_inputs, my_weights)
activations = embedding_layer_jit(preprocessed_inputs)
类似地,对于自定义训练循环,必须在将数据传递到 JIT 编译的训练步骤之前应用预处理。
# Create an optimizer and loss function.
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
def loss_and_updates(trainable_variables, non_trainable_variables, x, y):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x, training=True
)
loss = keras.losses.mean_squared_error(y, y_pred)
return loss, non_trainable_variables
grad_fn = jax.value_and_grad(loss_and_updates, has_aux=True)
# Create a JIT-compiled training step.
@jax.jit
def train_step(state, x, y):
(
trainable_variables,
non_trainable_variables,
optimizer_variables,
) = state
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
# Build optimizer variables.
optimizer.build(model.trainable_variables)
# Assemble the training state.
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop.
for (inputs, weights), labels in train_dataset:
# Explicitly preprocess the data.
preprocessed_inputs = embedding_layer.preprocess(inputs, weights)
loss, state = train_step(state, preprocessed_inputs, labels)
参数
keras_rs.layers.FeatureConfig
的嵌套结构。None
表示不堆叠表。"auto"
表示自动堆叠表。表名列表或表名列表的列表表示将内部列表中的表堆叠在一起。请注意,较旧的 TPU 不支持表堆叠,在这种情况下,默认值 "auto"
将解释为不堆叠表。call
方法DistributedEmbedding.call(
inputs: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
],
weights: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
NoneType,
] = None,
training: bool = False,
)
在嵌入表中查找特征并应用缩减。
参数
feature_configs
相同。或者,可以包含已预处理的输入(参见 preprocess
)。inputs
相同,并且形状必须匹配。返回
密集 2D 张量的嵌套结构,它们是来自所传递特征的缩减嵌入。结构与 inputs
相同。
preprocess
方法DistributedEmbedding.preprocess(
inputs: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
],
weights: Union[
Any,
Sequence[Union[Any, ForwardRef("Nested[T]")]],
Mapping[str, Union[Any, ForwardRef("Nested[T]")]],
NoneType,
] = None,
training: bool = False,
)
预处理数据并重新格式化以供模型使用。
对于 JAX 后端,将输入数据转换为与 SparseCore 一起使用所需的依赖于硬件的格式。只有为了启用 jit_compile = True
才需要显式调用 preprocess
。
对于非 JAX 后端,预处理将输入和权重捆绑在一起,并按设备放置分隔输入。此步骤完全是可选的。
参数
返回
可直接馈送到层的 inputs
参数中的预处理输入集。