代码示例 / 计算机视觉 / 使用基于注意力机制的深度多实例学习(MIL)进行分类。

使用基于注意力机制的深度多实例学习(MIL)进行分类。

作者: Mohamad Jaber
创建日期 2021/08/16
最后修改 2021/11/25
描述: 用于对实例包进行分类并获取其个体实例得分的 MIL 方法。

ⓘ 本示例使用 Keras 3

在 Colab 中查看 GitHub 源代码


介绍

什么是多实例学习 (MIL)?

通常,在使用有监督学习算法时,学习器会收到一组实例的标签。对于 MIL,学习器会收到一组“包”的标签,每个包包含一组实例。如果一个包中包含至少一个阳性实例,则该包被标记为阳性;如果它不包含任何阳性实例,则该包将被标记为阴性。

动机

在图像分类任务中,通常假设每张图像都清晰地代表一个类别标签。在医学影像学(例如计算病理学等)中,整个图像由单个类别标签(癌性/非癌性)或给定的感兴趣区域表示。然而,人们会希望知道图像中哪些模式实际上导致它属于该类别。在这种情况下,图像将被分割,子图像将构成实例包。

因此,目标是:

  1. 学习一个模型来预测实例包的类别标签。
  2. 找出包中哪些实例导致了阳性类别标签预测。

实现

以下步骤描述了模型的工作原理:

  1. 特征提取器层提取特征嵌入。
  2. 将嵌入输入到 MIL 注意力层以获得注意力分数。该层设计为置换不变的。
  3. 将输入特征及其相应的注意力分数相乘。
  4. 将结果输出传递给 softmax 函数进行分类。

参考


设置

import numpy as np
import keras
from keras import layers
from keras import ops
from tqdm import tqdm
from matplotlib import pyplot as plt

plt.style.use("ggplot")

创建数据集

我们将创建一组包,并根据其内容分配标签。如果一个包中至少有一个阳性实例,则该包被视为阳性包。如果它不包含任何阳性实例,则该包将被视为阴性包。

配置参数

  • POSITIVE_CLASS: 希望保留在阳性包中的目标类别。
  • BAG_COUNT: 训练包的数量。
  • VAL_BAG_COUNT: 验证包的数量。
  • BAG_SIZE: 包中实例的数量。
  • PLOT_SIZE: 要绘制的包的数量。
  • ENSEMBLE_AVG_COUNT: 要创建并平均在一起的模型数量。(可选:通常能获得更好的性能 - 设置为 1 表示单个模型)
POSITIVE_CLASS = 1
BAG_COUNT = 1000
VAL_BAG_COUNT = 300
BAG_SIZE = 3
PLOT_SIZE = 3
ENSEMBLE_AVG_COUNT = 1

准备包

由于注意力操作器是置换不变的,具有阳性类别标签的实例被随机放置在阳性包的实例中。

def create_bags(input_data, input_labels, positive_class, bag_count, instance_count):
    # Set up bags.
    bags = []
    bag_labels = []

    # Normalize input data.
    input_data = np.divide(input_data, 255.0)

    # Count positive samples.
    count = 0

    for _ in range(bag_count):
        # Pick a fixed size random subset of samples.
        index = np.random.choice(input_data.shape[0], instance_count, replace=False)
        instances_data = input_data[index]
        instances_labels = input_labels[index]

        # By default, all bags are labeled as 0.
        bag_label = 0

        # Check if there is at least a positive class in the bag.
        if positive_class in instances_labels:
            # Positive bag will be labeled as 1.
            bag_label = 1
            count += 1

        bags.append(instances_data)
        bag_labels.append(np.array([bag_label]))

    print(f"Positive bags: {count}")
    print(f"Negative bags: {bag_count - count}")

    return (list(np.swapaxes(bags, 0, 1)), np.array(bag_labels))


# Load the MNIST dataset.
(x_train, y_train), (x_val, y_val) = keras.datasets.mnist.load_data()

# Create training data.
train_data, train_labels = create_bags(
    x_train, y_train, POSITIVE_CLASS, BAG_COUNT, BAG_SIZE
)

# Create validation data.
val_data, val_labels = create_bags(
    x_val, y_val, POSITIVE_CLASS, VAL_BAG_COUNT, BAG_SIZE
)
Positive bags: 283
Negative bags: 717
Positive bags: 104
Negative bags: 196

创建模型

现在我们将构建注意力层,准备一些实用工具,然后构建和训练整个模型。

注意力操作器实现

该层的输出大小由单个包的大小决定。

注意力机制使用包中实例的加权平均,其中权重的总和必须等于 1(与包大小无关)。

权重矩阵(参数)是 wv。为了包含正负值,使用了逐元素的双曲正切非线性函数。

可以使用门控注意力机制来处理复杂的关系。计算中添加了另一个权重矩阵 u。使用 sigmoid 非线性函数来克服双曲正切非线性函数在 x ∈ [−1, 1] 范围内的近似线性行为。

class MILAttentionLayer(layers.Layer):
    """Implementation of the attention-based Deep MIL layer.

    Args:
      weight_params_dim: Positive Integer. Dimension of the weight matrix.
      kernel_initializer: Initializer for the `kernel` matrix.
      kernel_regularizer: Regularizer function applied to the `kernel` matrix.
      use_gated: Boolean, whether or not to use the gated mechanism.

    Returns:
      List of 2D tensors with BAG_SIZE length.
      The tensors are the attention scores after softmax with shape `(batch_size, 1)`.
    """

    def __init__(
        self,
        weight_params_dim,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        use_gated=False,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.weight_params_dim = weight_params_dim
        self.use_gated = use_gated

        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

        self.v_init = self.kernel_initializer
        self.w_init = self.kernel_initializer
        self.u_init = self.kernel_initializer

        self.v_regularizer = self.kernel_regularizer
        self.w_regularizer = self.kernel_regularizer
        self.u_regularizer = self.kernel_regularizer

    def build(self, input_shape):
        # Input shape.
        # List of 2D tensors with shape: (batch_size, input_dim).
        input_dim = input_shape[0][1]

        self.v_weight_params = self.add_weight(
            shape=(input_dim, self.weight_params_dim),
            initializer=self.v_init,
            name="v",
            regularizer=self.v_regularizer,
            trainable=True,
        )

        self.w_weight_params = self.add_weight(
            shape=(self.weight_params_dim, 1),
            initializer=self.w_init,
            name="w",
            regularizer=self.w_regularizer,
            trainable=True,
        )

        if self.use_gated:
            self.u_weight_params = self.add_weight(
                shape=(input_dim, self.weight_params_dim),
                initializer=self.u_init,
                name="u",
                regularizer=self.u_regularizer,
                trainable=True,
            )
        else:
            self.u_weight_params = None

        self.input_built = True

    def call(self, inputs):
        # Assigning variables from the number of inputs.
        instances = [self.compute_attention_scores(instance) for instance in inputs]

        # Stack instances into a single tensor.
        instances = ops.stack(instances)

        # Apply softmax over instances such that the output summation is equal to 1.
        alpha = ops.softmax(instances, axis=0)

        # Split to recreate the same array of tensors we had as inputs.
        return [alpha[i] for i in range(alpha.shape[0])]

    def compute_attention_scores(self, instance):
        # Reserve in-case "gated mechanism" used.
        original_instance = instance

        # tanh(v*h_k^T)
        instance = ops.tanh(ops.tensordot(instance, self.v_weight_params, axes=1))

        # for learning non-linear relations efficiently.
        if self.use_gated:
            instance = instance * ops.sigmoid(
                ops.tensordot(original_instance, self.u_weight_params, axes=1)
            )

        # w^T*(tanh(v*h_k^T)) / w^T*(tanh(v*h_k^T)*sigmoid(u*h_k^T))
        return ops.tensordot(instance, self.w_weight_params, axes=1)

可视化工具

绘制给定包数量(由 PLOT_SIZE 确定)相对于类别的情况。

此外,如果激活,可以看到每个包的类别标签预测及其相关的实例得分(在模型训练完成后)。

def plot(data, labels, bag_class, predictions=None, attention_weights=None):
    """ "Utility for plotting bags and attention weights.

    Args:
      data: Input data that contains the bags of instances.
      labels: The associated bag labels of the input data.
      bag_class: String name of the desired bag class.
        The options are: "positive" or "negative".
      predictions: Class labels model predictions.
      If you don't specify anything, ground truth labels will be used.
      attention_weights: Attention weights for each instance within the input data.
      If you don't specify anything, the values won't be displayed.
    """
    return  ## TODO
    labels = np.array(labels).reshape(-1)

    if bag_class == "positive":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 1)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

        else:
            labels = np.where(labels == 1)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

    elif bag_class == "negative":
        if predictions is not None:
            labels = np.where(predictions.argmax(1) == 0)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]
        else:
            labels = np.where(labels == 0)[0]
            bags = np.array(data)[:, labels[0:PLOT_SIZE]]

    else:
        print(f"There is no class {bag_class}")
        return

    print(f"The bag class label is {bag_class}")
    for i in range(PLOT_SIZE):
        figure = plt.figure(figsize=(8, 8))
        print(f"Bag number: {labels[i]}")
        for j in range(BAG_SIZE):
            image = bags[j][i]
            figure.add_subplot(1, BAG_SIZE, j + 1)
            plt.grid(False)
            if attention_weights is not None:
                plt.title(np.around(attention_weights[labels[i]][j], 2))
            plt.imshow(image)
        plt.show()


# Plot some of validation data bags per class.
plot(val_data, val_labels, "positive")
plot(val_data, val_labels, "negative")

创建模型

首先,我们将为每个实例创建一些嵌入,调用注意力操作器,然后使用 softmax 函数输出类别概率。

def create_model(instance_shape):
    # Extract features from inputs.
    inputs, embeddings = [], []
    shared_dense_layer_1 = layers.Dense(128, activation="relu")
    shared_dense_layer_2 = layers.Dense(64, activation="relu")
    for _ in range(BAG_SIZE):
        inp = layers.Input(instance_shape)
        flatten = layers.Flatten()(inp)
        dense_1 = shared_dense_layer_1(flatten)
        dense_2 = shared_dense_layer_2(dense_1)
        inputs.append(inp)
        embeddings.append(dense_2)

    # Invoke the attention layer.
    alpha = MILAttentionLayer(
        weight_params_dim=256,
        kernel_regularizer=keras.regularizers.L2(0.01),
        use_gated=True,
        name="alpha",
    )(embeddings)

    # Multiply attention weights with the input layers.
    multiply_layers = [
        layers.multiply([alpha[i], embeddings[i]]) for i in range(len(alpha))
    ]

    # Concatenate layers.
    concat = layers.concatenate(multiply_layers, axis=1)

    # Classification output node.
    output = layers.Dense(2, activation="softmax")(concat)

    return keras.Model(inputs, output)

类别权重

由于这类问题很容易变成不平衡数据分类问题,因此应考虑类别加权。

假设有 1000 个包。通常情况下,约 90% 的包不包含任何阳性标签,而约 10% 的包包含。这类数据可以称为不平衡数据

使用类别权重,模型会倾向于给予稀有类别更高的权重。

def compute_class_weights(labels):
    # Count number of positive and negative bags.
    negative_count = len(np.where(labels == 0)[0])
    positive_count = len(np.where(labels == 1)[0])
    total_count = negative_count + positive_count

    # Build class weight dictionary.
    return {
        0: (1 / negative_count) * (total_count / 2),
        1: (1 / positive_count) * (total_count / 2),
    }

构建并训练模型

本节构建并训练模型。

def train(train_data, train_labels, val_data, val_labels, model):
    # Train model.
    # Prepare callbacks.
    # Path where to save best weights.

    # Take the file name from the wrapper.
    file_path = "/tmp/best_model.weights.h5"

    # Initialize model checkpoint callback.
    model_checkpoint = keras.callbacks.ModelCheckpoint(
        file_path,
        monitor="val_loss",
        verbose=0,
        mode="min",
        save_best_only=True,
        save_weights_only=True,
    )

    # Initialize early stopping callback.
    # The model performance is monitored across the validation data and stops training
    # when the generalization error cease to decrease.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, mode="min"
    )

    # Compile model.
    model.compile(
        optimizer="adam",
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )

    # Fit model.
    model.fit(
        train_data,
        train_labels,
        validation_data=(val_data, val_labels),
        epochs=20,
        class_weight=compute_class_weights(train_labels),
        batch_size=1,
        callbacks=[early_stopping, model_checkpoint],
        verbose=0,
    )

    # Load best weights.
    model.load_weights(file_path)

    return model


# Building model(s).
instance_shape = train_data[0][0].shape
models = [create_model(instance_shape) for _ in range(ENSEMBLE_AVG_COUNT)]

# Show single model architecture.
print(models[0].summary())

# Training model(s).
trained_models = [
    train(train_data, train_labels, val_data, val_labels, model)
    for model in tqdm(models)
]
Model: "functional_1"
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┓
┃ Layer (type)         Output Shape       Param #  Connected to         ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━┩
│ input_layer         │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_layer_1       │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ input_layer_2       │ (None, 28, 28)    │       0 │ -                    │
│ (InputLayer)        │                   │         │                      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten (Flatten)   │ (None, 784)       │       0 │ input_layer[0][0]    │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten_1 (Flatten) │ (None, 784)       │       0 │ input_layer_1[0][0]  │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ flatten_2 (Flatten) │ (None, 784)       │       0 │ input_layer_2[0][0]  │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense (Dense)       │ (None, 128)       │ 100,480 │ flatten[0][0],       │
│                     │                   │         │ flatten_1[0][0],     │
│                     │                   │         │ flatten_2[0][0]      │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_1 (Dense)     │ (None, 64)        │   8,256 │ dense[0][0],         │
│                     │                   │         │ dense[1][0],         │
│                     │                   │         │ dense[2][0]          │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ alpha               │ [(None, 1),       │  33,024 │ dense_1[0][0],       │
│ (MILAttentionLayer) │ (None, 1), (None, │         │ dense_1[1][0],       │
│                     │ 1)]               │         │ dense_1[2][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply (Multiply) │ (None, 64)        │       0 │ alpha[0][0],         │
│                     │                   │         │ dense_1[0][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply_1          │ (None, 64)        │       0 │ alpha[0][1],         │
│ (Multiply)          │                   │         │ dense_1[1][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ multiply_2          │ (None, 64)        │       0 │ alpha[0][2],         │
│ (Multiply)          │                   │         │ dense_1[2][0]        │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ concatenate         │ (None, 192)       │       0 │ multiply[0][0],      │
│ (Concatenate)       │                   │         │ multiply_1[0][0],    │
│                     │                   │         │ multiply_2[0][0]     │
├─────────────────────┼───────────────────┼─────────┼──────────────────────┤
│ dense_2 (Dense)     │ (None, 2)         │     386 │ concatenate[0][0]    │
└─────────────────────┴───────────────────┴─────────┴──────────────────────┘
 Total params: 142,146 (555.26 KB)
 Trainable params: 142,146 (555.26 KB)
 Non-trainable params: 0 (0.00 B)
None

100%|██████████████████████████████████████████████████████████████████████████████████| 1/1 [00:36<00:00, 36.67s/it]

模型评估

模型现在准备进行评估。对于每个模型,我们还创建一个相关的中间模型来获取注意力层的权重。

我们将计算每个 ENSEMBLE_AVG_COUNT 模型的预测结果,并将它们平均以获得最终预测。

def predict(data, labels, trained_models):
    # Collect info per model.
    models_predictions = []
    models_attention_weights = []
    models_losses = []
    models_accuracies = []

    for model in trained_models:
        # Predict output classes on data.
        predictions = model.predict(data)
        models_predictions.append(predictions)

        # Create intermediate model to get MIL attention layer weights.
        intermediate_model = keras.Model(model.input, model.get_layer("alpha").output)

        # Predict MIL attention layer weights.
        intermediate_predictions = intermediate_model.predict(data)

        attention_weights = np.squeeze(np.swapaxes(intermediate_predictions, 1, 0))
        models_attention_weights.append(attention_weights)

        loss, accuracy = model.evaluate(data, labels, verbose=0)
        models_losses.append(loss)
        models_accuracies.append(accuracy)

    print(
        f"The average loss and accuracy are {np.sum(models_losses, axis=0) / ENSEMBLE_AVG_COUNT:.2f}"
        f" and {100 * np.sum(models_accuracies, axis=0) / ENSEMBLE_AVG_COUNT:.2f} % resp."
    )

    return (
        np.sum(models_predictions, axis=0) / ENSEMBLE_AVG_COUNT,
        np.sum(models_attention_weights, axis=0) / ENSEMBLE_AVG_COUNT,
    )


# Evaluate and predict classes and attention scores on validation data.
class_predictions, attention_params = predict(val_data, val_labels, trained_models)

# Plot some results from our validation data.
plot(
    val_data,
    val_labels,
    "positive",
    predictions=class_predictions,
    attention_weights=attention_params,
)
plot(
    val_data,
    val_labels,
    "negative",
    predictions=class_predictions,
    attention_weights=attention_params,
)
 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step
 10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 39ms/step
The average loss and accuracy are 0.03 and 99.00 % resp.

结论

从上面的图可以看出,权重总和为 1。在一个预测为阳性的包中,导致阳性标记的实例将比包中其余实例具有明显更高的注意力分数。然而,在一个预测为阴性的包中,有两种情况:

  • 所有实例将具有近似相似的分数。
  • 一个实例将具有相对较高的分数(但不如阳性实例高)。这是因为该实例的特征空间与阳性实例的特征空间接近。

注意事项

  • 如果模型过拟合,权重将平均分配给所有包。因此,正则化技术是必要的。
  • 在论文中,不同包的大小可能不同。为简单起见,此处固定了包的大小。
  • 为了不依赖于单个模型的随机初始权重,应考虑平均集成方法。