Overview

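When a dataset is class-imbalanced, with many images in some classes and few in others, training on it directly biases the model toward the classes with more data. The simplest remedy is resampling: reassign sampling weights according to the number of samples in each class. Think about it: the default is uniform random sampling, so images from the larger classes are naturally drawn more often.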
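
To make the arithmetic concrete, here is a small sketch using the class sizes that appear later in this post (150, 200, 300). It only illustrates the probabilities involved and uses exact fractions from the standard library; the variable names are my own.

```python
from fractions import Fraction

# Class sizes used later in this post (class 0, 1, 2).
class_sample_counts = [150, 200, 300]
total = sum(class_sample_counts)

# Uniform sampling: every image is equally likely, so a class is drawn
# in proportion to its size.
uniform_class_prob = [Fraction(n, total) for n in class_sample_counts]

# Inverse-frequency weighting: each image in class c gets weight 1 / n_c.
per_image_weight = [Fraction(1, n) for n in class_sample_counts]
# Total weight of a class = n_c * (1 / n_c) = 1, so after normalizing,
# every class is drawn with the same probability.
class_mass = [n * w for n, w in zip(class_sample_counts, per_image_weight)]
weighted_class_prob = [m / sum(class_mass) for m in class_mass]

print([str(p) for p in uniform_class_prob])   # ['3/13', '4/13', '6/13']
print([str(p) for p in weighted_class_prob])  # ['1/3', '1/3', '1/3']
```

With uniform sampling the largest class is drawn twice as often as the smallest; with weights of 1 / class size, all three classes are equally likely.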

The handy WeightedRandomSampler

Straight to the code:

import torch
from torch.utils.data import WeightedRandomSampler

# Number of samples in each class of the dataset.
class_sample_counts = [150, 200, 300]
# Per-class weight: the inverse of the class size.
weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)
# get_classes_for_all_imgs is the key piece; it is defined on the dataset.
train_targets = train_dataset.get_classes_for_all_imgs()
# Look up one weight per sample, in dataset index order.
samples_weights = weights[torch.tensor(train_targets)]

sampler = WeightedRandomSampler(weights=samples_weights, num_samples=len(samples_weights), replacement=True)
# The sampler already draws indices at random (and balanced), so do NOT shuffle again.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=setting.batch_size, shuffle=False,
                                           sampler=sampler, num_workers=opt.nThreads)
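
If you want to convince yourself that the sampler really balances the classes, a toy check along these lines may help. It is only a sketch: the labels are synthetic (no real dataset is involved), `num_draws` is an arbitrary number I picked so the frequencies are easy to read, and the seed is there just to make the run repeatable.

```python
from collections import Counter

import torch
from torch.utils.data import WeightedRandomSampler

torch.manual_seed(0)  # only to make the toy run repeatable

# Hypothetical toy labels: 150 samples of class 0, 200 of class 1, 300 of class 2.
class_sample_counts = [150, 200, 300]
train_targets = torch.tensor(
    [c for c, n in enumerate(class_sample_counts) for _ in range(n)]
)

# One weight per sample: 1 / (size of that sample's class).
weights = 1.0 / torch.tensor(class_sample_counts, dtype=torch.float)
samples_weights = weights[train_targets]

# Draw many indices with replacement so the class frequencies are easy to read.
num_draws = 30000
sampler = WeightedRandomSampler(
    weights=samples_weights, num_samples=num_draws, replacement=True
)
drawn_indices = torch.tensor(list(sampler))
drawn_counts = Counter(train_targets[drawn_indices].tolist())
print(drawn_counts)  # each class should appear roughly num_draws / 3 times
```

Because sampling is done with replacement, images from the small classes are repeated within an epoch, while some images from the large classes may not be drawn at all.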

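get_classes_for_all_imgs is the key piece. Define it in your own Dataset class. It returns the class of every image, in the order in which the dataset reads them.
For example,
suppose images whose names contain "A" belong to class 0 and those whose names contain "B" belong to class 1.
The crucial point is that you must assign a label yourself to the sample at every index of the dataset. Otherwise the sampler has no way of knowing which class each sample belongs to, let alone giving each sample an appropriate sampling weight.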

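import glob
import os

from torch.utils import data


class TrainDataset(data.Dataset):
    # data_root is assumed to be passed in; the original does not show where it is set.
    def __init__(self, data_root):
        super(TrainDataset, self).__init__()
        self.data_root = data_root
        # Sort so the label list lines up with the dataset's index order.
        img_paths = sorted(glob.glob(self.data_root + '/*'))
        self.classes_for_all_imgs = []
        for img_path in img_paths:
            # Use `in` on the file name: str.find returns -1 (truthy) when the
            # substring is missing and 0 (falsy) when it sits at position 0.
            img_name = os.path.basename(img_path)
            if 'A' in img_name:
                class_id = 0
            elif 'B' in img_name:
                class_id = 1
            elif 'C' in img_name:
                class_id = 2
            else:
                raise ValueError('cannot infer class from ' + img_path)
            self.classes_for_all_imgs.append(class_id)

    def __getitem__(self, index):
        ...  # loading of the sample is omitted here

    def get_classes_for_all_imgs(self):
        # One class id per image, in dataset index order.
        return self.classes_for_all_imgs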
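
The first snippet hard-codes class_sample_counts. Once the per-image labels are available in index order, the counts and the per-sample weights can be derived from them instead. Below is a minimal sketch in plain Python; the file names and the helper `class_id_from_name` are made up for illustration, and the naming rule ("A", "B", "C" in the file name) follows the example above.

```python
from collections import Counter


def class_id_from_name(file_name):
    """Hypothetical helper: map a file name to a class id by the letter it contains."""
    for class_id, tag in enumerate(["A", "B", "C"]):
        if tag in file_name:
            return class_id
    raise ValueError(f"cannot infer class from {file_name!r}")


# Hypothetical file names, sorted the same way the dataset would read them.
file_names = sorted([
    "img_C_2.png", "img_A_1.png", "img_B_1.png",
    "img_A_3.png", "img_C_1.png", "img_A_2.png",
])

# One label per image, in index order (what get_classes_for_all_imgs returns).
classes_for_all_imgs = [class_id_from_name(name) for name in file_names]

# Derive the class sizes from the labels instead of typing them in by hand.
counts = Counter(classes_for_all_imgs)
class_sample_counts = [counts[c] for c in range(3)]

# One weight per image: the inverse of its class size.
sample_weights = [1.0 / class_sample_counts[c] for c in classes_for_all_imgs]

print(classes_for_all_imgs)  # [0, 0, 0, 1, 2, 2]
print(class_sample_counts)   # [3, 1, 2]
```

Deriving the counts this way keeps them in sync with the files on disk, so the weights cannot silently drift from the real class sizes.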

One last point: shuffle must stay False when a sampler is passed. The reason:

Random sampling is done using some kind of probability distribution.
If you are shuffling the data, it could be seen as sampling from a uniform distribution without replacement.
Since you are providing the weights to sample each data sample, how should shuffling work in this case?
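
In practice, DataLoader enforces this: passing shuffle=True together with a sampler raises a ValueError. The sketch below shows that on a tiny synthetic TensorDataset (the data and variable names are made up), and then builds the loader without shuffle, in which case the sampler alone decides the order.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Hypothetical toy dataset: 6 samples, four of class 0 and two of class 1.
features = torch.arange(6, dtype=torch.float).unsqueeze(1)
labels = torch.tensor([0, 0, 0, 0, 1, 1])
dataset = TensorDataset(features, labels)

# Inverse-frequency weight for every sample.
class_counts = torch.bincount(labels).float()
sampler = WeightedRandomSampler(
    weights=(1.0 / class_counts)[labels], num_samples=len(labels), replacement=True
)

# shuffle=True and a sampler are mutually exclusive.
try:
    DataLoader(dataset, batch_size=2, shuffle=True, sampler=sampler)
    error_message = None
except ValueError as err:
    error_message = str(err)
print(error_message)

# Without shuffle, the sampler supplies the (random, weighted) order.
loader = DataLoader(dataset, batch_size=2, sampler=sampler)
num_batches = len(list(loader))
print(num_batches)
```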


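Copyright notice: this is an original article by Hungryof, released under the CC 4.0 BY-SA license. When reposting, please include a link to the original source and this notice.
Original link: https://blog.csdn.net/Hungryof/article/details/107609877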