莱州网站建设青岛华夏商务网/seo整站优化一年价格多少
完整代码总结
这段代码的目的是通过构建一个部分标签学习(Partial Label Learning, PLL)框架来生成一个包含部分标签的数据集,并且支持根据给定的标签列表对数据集进行筛选和过滤。代码包含了多个类和函数,主要分为以下几部分:
- 数据预处理与加载:使用 PyTorch 和 torchvision 来加载 CIFAR-10 数据集,并对其进行标准化处理。
- 部分标签数据集的生成:为每个样本生成多个候选标签,并模拟部分标签学习中的标签不确定性。
- 数据集筛选:根据用户提供的标签列表来过滤掉包含特定标签的样本,生成一个新的数据集。
- DataLoader 设置:通过 DataLoader 对数据集进行批量加载,并在训练时进行处理。
各方法与类的解释
1. PartialLabelDataset 类
该类用于生成一个部分标签数据集,每个样本会被赋予一个候选标签集,其中可能包含真实标签以及一些随机标签。
__init__(self, dataset, candidate_size)
:初始化数据集,将输入的原始数据集与候选标签集大小保存为类的属性。candidate_size
表示每个样本的候选标签数量。generate_partial_labels(self)
:为每个样本生成部分标签。每个样本会从真实标签开始,然后添加若干个随机的标签,直到候选标签集的大小为candidate_size
。生成的标签会被打乱顺序,以模拟标签不确定性。__getitem__(self, index)
:获取索引index
对应样本的图像数据、部分标签和真实标签。真实标签是从数据集中直接获取的,部分标签是根据generate_partial_labels()
方法生成的。__len__(self)
:返回数据集中样本的数量。
2. FilteredPartialLabelDataset 类
该类用于过滤掉原始部分标签数据集中的特定标签样本,并根据过滤后的数据生成新的数据集。
__init__(self, dataset, partial_labels, filtered_indices)
:初始化该类时,需要输入原始数据集、完整的部分标签列表以及要保留的样本索引列表(即不包含过滤标签的样本)。__getitem__(self, index)
:根据过滤后的索引,从原始数据集中获取图像和标签数据。__len__(self)
:返回筛选后的样本数量。
3. filter_partial_label_dataset 函数
这个函数用于对原始部分标签数据集进行标签筛选,去掉包含特定标签的样本,并返回过滤后的数据集和 DataLoader。
dataset
:原始数据集(如 CIFAR-10)。partial_labels
:包含完整部分标签的列表,函数会基于此生成新的部分标签数据集。candidate_size
:每个样本的候选标签集大小。filtered_labels
:一个标签列表,表示需要从部分标签中排除的标签。batch_size
:DataLoader 的批次大小。shuffle
:是否在 DataLoader 中打乱数据。num_workers
:DataLoader 的工作线程数。
函数首先根据 filtered_labels
过滤掉部分标签中包含这些标签的样本,接着根据过滤后的样本索引创建一个新的 FilteredPartialLabelDataset
。最终返回该新的数据集和对应的 DataLoader。
4. main 函数
该函数是代码的入口,负责生成部分标签数据集并创建 DataLoader。
- 通过
PartialLabelDataset
类生成一个包含部分标签的数据集(候选标签集大小为3)。 - 创建一个 DataLoader,用于批量加载部分标签数据集。
- 打印出部分标签数据集的一个批次样本的形状和标签信息。
在 main()
函数中,partial_label_dataset
被用来生成部分标签数据集,并且通过 filter_partial_label_dataset
函数对数据集进行标签过滤,排除包含标签 [5, 6, 7, 8, 9]
的样本。
代码流程图
-
数据加载与预处理:
- 使用
torchvision.datasets.CIFAR10
下载并加载 CIFAR-10 数据集。 - 对图像进行标准化处理(均值和标准差为0.5)。
- 使用
-
生成部分标签数据集:
- 在
PartialLabelDataset
中为每个样本生成多个候选标签(候选标签数为3),这些标签包括真实标签及随机标签。 - 使用
generate_partial_labels()
方法生成候选标签,并打乱顺序。
- 在
-
数据筛选:
- 使用
filter_partial_label_dataset
函数,根据用户提供的标签列表(如[5, 6, 7, 8, 9]
)过滤掉部分标签中包含这些标签的样本,创建新的数据集。
- 使用
-
数据加载器:
- 通过
DataLoader
创建数据加载器,使得在训练过程中可以批量读取数据。
- 通过
-
输出样本信息:
- 在
main()
函数中打印出部分标签的一个批次示例,包括图像的形状、部分标签和真实标签。
- 在
优点和可扩展性
- 部分标签学习:这段代码模拟了部分标签学习的场景,其中每个样本都有多个候选标签,这为部分标签学习任务提供了一个基础框架。
- 灵活的标签过滤:通过
filter_partial_label_dataset
函数,用户可以方便地过滤掉特定标签的样本。 - 可扩展性:可以将这个框架扩展到其他数据集(如 CIFAR-100、ImageNet 等),并灵活调整候选标签大小和过滤标签。
总结
这段代码提供了一个部分标签学习框架,可以用来处理具有部分标签的不完整数据集,并提供了一种方法来筛选数据集中的特定标签。通过生成候选标签和对数据进行过滤,代码实现了部分标签学习任务的数据预处理与加载,为相关研究和应用提供了有效支持。
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 下载 CIFAR-10 数据集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)# 定义合并后的部分标签数据集类
class PartialLabelDataset(Dataset):def __init__(self, dataset, candidate_size):"""初始化部分标签数据集:param dataset: 原始数据集对象(如 CIFAR-10):param candidate_size: 候选标签集的大小:param filtered_labels: 不得存在于部分标签中的标签列表(可选)"""self.dataset = datasetself.candidate_size = candidate_sizeself.num_classes = len(dataset.classes)self.targets = dataset.targetsself.partial_labels = self.generate_partial_labels()def generate_partial_labels(self):"""为每个图像生成部分标签:param filtered_labels: 不得存在于部分标签中的标签列表(可选):return: 部分标签列表"""partial_labels = []for target in self.targets:candidates = [target]while len(candidates) < self.candidate_size:random_label = np.random.randint(0, self.num_classes)if random_label not in candidates :candidates.append(random_label)#打乱候选标签np.random.shuffle(candidates)partial_labels.append(candidates)return partial_labelsdef __getitem__(self, index):image, _ = self.dataset[index]partial_label = torch.tensor(self.partial_labels[index], dtype=torch.long)true_label = torch.tensor(self.targets[index], dtype=torch.long) # 真实标签return image, partial_label, true_labeldef __len__(self):return len(self.dataset)
class FilteredPartialLabelDataset(Dataset):def __init__(self, dataset, partial_labels, filtered_indices):"""初始化筛选后的部分标签数据集:param dataset: 原始数据集对象:param partial_labels: 完整部分标签列表:param filtered_indices: 筛选后的样本索引列表"""self.dataset = datasetself.partial_labels = [partial_labels[i] for i in filtered_indices]self.indices = filtered_indicesdef __getitem__(self, index):original_index = self.indices[index] # image, _ = self.dataset[original_index]partial_label = torch.tensor(self.partial_labels[index], dtype=torch.long)true_label = torch.tensor(self.dataset.targets[original_index], dtype=torch.long) # 真实标签return image, partial_label, true_label #表示这个类实例化之后,返回的就是这个样本的图像和部分标签def __len__(self):return len(self.indices)
def filter_partial_label_dataset(dataset, partial_labels, candidate_size=3, filtered_labels=None, batch_size=64, shuffle=True, num_workers=2):"""过滤数据集以排除部分标签中含有任何 filtered_labels 的样本。:param dataset: 原始数据集(例如 CIFAR-10):param candidate_size: 候选标签集的大小(默认:3):param filtered_labels: 不得存在于部分标签中的标签列表:param batch_size: DataLoader 的批次大小(默认:4):param shuffle: 是否在 DataLoader 中打乱数据(默认:True):param num_workers: DataLoader 的工作线程数(默认:2):return: (过滤后的数据集, DataLoader) 元组"""if filtered_labels is None:raise ValueError("Filtered labels must be specified.")# 将部分标签转换为 NumPy 数组以进行高效过滤partial_labels_np = np.array(partial_labels)# 创建样本中不包含任何 filtered_labels 的掩码filtered_labels_mask = np.any(np.isin(partial_labels_np, filtered_labels), axis=1)final_mask = ~filtered_labels_mask # 这个索引列中,只有不含要过滤的标签的样本才为 True# 获取过滤后的索引filtered_indices = np.where(final_mask)[0] # 过滤后的样本的索引,每个值对是该样本在原始数据集中的索引,可以据此得到该样本的真实标签# 创建过滤后的部分标签数据集new_partial_label_dataset = FilteredPartialLabelDataset(dataset, partial_labels, filtered_indices)# 创建 DataLoadernew_partial_label_loader = DataLoader(new_partial_label_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)# 打印过滤后样本的信息print("过滤后的样本数量:", len(filtered_indices))# 可选:打印一个批次的示例for images, partial_labels_batch , true_labels_batch in new_partial_label_loader:print("新图像的形状:", images.shape)print("新部分标签:", partial_labels_batch)print("新真实标签:", true_labels_batch)breakreturn new_partial_label_dataset, new_partial_label_loader# 主函数:生成部分标签数据集并过滤
def main():# 生成部分标签数据集,不包含标签5、6、7、8、9partial_label_dataset = PartialLabelDataset(trainset, candidate_size=3)# 创建 DataLoadertrainloader = DataLoader(partial_label_dataset, batch_size=4, shuffle=True, num_workers=2)# 打印部分标签示例for images, partial_labels, true_labels in trainloader:print("图像的形状:", images.shape)print("部分标签:", partial_labels)print("真实标签:", true_labels)breakif __name__ == '__main__':main()partial_label_dataset = PartialLabelDataset(trainset, candidate_size=3)partial_labels = partial_label_dataset.generate_partial_labels()filter_partial_label_dataset(trainset, partial_labels, candidate_size=3, filtered_labels=[5, 6, 7, 8, 9], batch_size=4, shuffle=True, num_workers=2)