RetinaNetObjectDetector
类keras_hub.models.RetinaNetObjectDetector(
backbone,
num_classes,
bounding_box_format="yxyx",
anchor_generator=None,
label_encoder=None,
use_prediction_head_norm=False,
classification_head_prior_probability=0.01,
pre_logits_num_conv_layers=4,
preprocessor=None,
activation=None,
dtype=None,
prediction_decoder=None,
**kwargs
)
RetinaNet 目标检测模型。
该类实现了 RetinaNet 目标检测架构。它由特征提取主干网络、特征金字塔网络 (FPN) 以及两个预测头(用于分类和边界框回归)组成。
参数
keras.Model
。一个 keras.models.RetinaNetBackbone
类,定义了主干网络架构。提供用于检测的特征图。keras_hub.layers.AnchorGenerator
实例。在图像的不同尺度和纵横比下生成锚框。如果为 None,则创建一个具有以下参数的默认 AnchorGenerator
:- bounding_box_format
:与模型的 bounding_box_format
相同。- min_level
:主干网络的 min_level
。- max_level
:主干网络的 max_level
。- num_scales
:3。- aspect_ratios
:[0.5, 1.0, 2.0]。- anchor_size
:4.0。您可以通过实例化 keras_hub.layers.AnchorGenerator
类并传递所需的参数来创建自定义 AnchorGenerator
。yxyx
。RetinaNetLabelEncoder
实例。将真实框和类别编码为训练目标。它根据 IoU 将真实框与锚点匹配,并将框坐标编码为偏移量。如果为 None
,则创建默认编码器。有关详细信息,请参阅 RetinaNetLabelEncoder
类。如果为 None,则创建一个具有标准参数的默认编码器。- anchor_generator
:与模型的相同。- bounding_box_format
:与模型的 bounding_box_format
相同。- positive_threshold
:0.5 - negative_threshold
:0.4 - encoding_format
:“center_xywh” - box_variance
:[1.0, 1.0, 1.0, 1.0] - background_class
:-1 - ignore_class
:-2False
。RetinaNetObjectDetectorPreprocessor
的实例或自定义预处理器。负责在输入主干网络之前处理图像。keras.layers.Layer
实例,负责将 RetinaNet 预测(框回归和分类)转换为最终的边界框和具有置信度分数的类别。默认为 NonMaxSuppression
实例。from_preset
方法RetinaNetObjectDetector.from_preset(preset, load_weights=True, **kwargs)
从模型预设实例化一个 keras_hub.models.Task
。
预设是一个包含配置、权重和其他文件资产的目录,用于保存和加载预训练模型。preset
可以作为以下之一传递:
'bert_base_en'
'kaggle://user/bert/keras/bert_base_en'
'hf://user/bert_base_en'
'./bert_base_en'
对于任何 Task
子类,您都可以运行 cls.presets.keys()
来列出该类上所有可用的内置预设。
此构造函数可以通过两种方式调用。一种是通过任务特定的基类,例如 keras_hub.models.CausalLM.from_preset()
,另一种是通过模型类,例如 keras_hub.models.BertTextClassifier.from_preset()
。如果从基类调用,返回对象的子类将从预设目录中的配置推断出来。
参数
True
,已保存的权重将被加载到模型架构中。如果为 False
,所有权重将被随机初始化。示例
# Load a Gemma generative task.
causal_lm = keras_hub.models.CausalLM.from_preset(
"gemma_2b_en",
)
# Load a Bert classification task.
model = keras_hub.models.TextClassifier.from_preset(
"bert_base_en",
num_classes=2,
)
预设 | 参数 | 描述 |
---|---|---|
retinanet_resnet50_fpn_v2_coco | 31.56M | RetinaNet 模型,带有 ResNet50 骨干网络,在 800x800 分辨率的 COCO 数据集上进行微调,FPN 特征从 P5 级别创建。 |
retinanet_resnet50_fpn_coco | 34.12M | RetinaNet 模型,带有 ResNet50 骨干网络,在 800x800 分辨率的 COCO 数据集上进行微调。 |
backbone
属性keras_hub.models.RetinaNetObjectDetector.backbone
一个具有核心架构的 keras_hub.models.Backbone
模型。
preprocessor
属性keras_hub.models.RetinaNetObjectDetector.preprocessor
用于预处理输入的 keras_hub.models.Preprocessor
层。