FluxBackbone
类keras_hub.models.FluxBackbone(
input_channels,
hidden_size,
mlp_ratio,
num_heads,
depth,
depth_single_blocks,
axes_dim,
theta,
use_bias,
guidance_embed=False,
image_shape=(None, 768, 3072),
text_shape=(None, 768, 3072),
image_ids_shape=(None, 768, 3072),
text_ids_shape=(None, 768, 3072),
y_shape=(None, 128),
**kwargs
)
用于序列流匹配的 Transformer 模型。
该模型处理图像和文本数据,并附带位置和时间步嵌入,可选择应用引导嵌入。双流块处理独立的图像和文本流,而单流块则合并这些流。移植自:https://github.com/black-forest-labs/flux
参数
num_heads
整除。调用参数
引发
hidden_size
不能被 num_heads
整除,或者 sum(axes_dim)
不等于位置嵌入维度。from_preset
方法FluxBackbone.from_preset(preset, load_weights=True, **kwargs)
从模型预设实例化一个 keras_hub.models.Backbone
。
预设是一个包含配置、权重和其他文件资源的目录,用于保存和加载预训练模型。preset
可以作为以下之一传递:
'bert_base_en'
'kaggle://user/bert/keras/bert_base_en'
'hf://user/bert_base_en'
'./bert_base_en'
此构造函数可以通过两种方式调用:从基类如 keras_hub.models.Backbone.from_preset()
调用,或从模型类如 keras_hub.models.GemmaBackbone.from_preset()
调用。如果从基类调用,返回对象的子类将根据预设目录中的配置推断。
对于任何 Backbone
子类,您可以运行 cls.presets.keys()
来列出该类上所有可用的内置预设。
参数
示例
# Load a Gemma backbone with pre-trained weights.
model = keras_hub.models.Backbone.from_preset(
"gemma_2b_en",
)
# Load a Bert backbone with a pre-trained config and random weights.
model = keras_hub.models.Backbone.from_preset(
"bert_base_en",
load_weights=False,
)