Overview
With class-imbalanced data, some classes have many images and others few; training on it directly makes the model overfit the majority classes. The simplest remedy is resampling: reassign each sample's sampling weight according to its class count. By default a data loader samples uniformly, so images from the larger classes are naturally drawn more often.
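The inverse-frequency weighting can be sketched in a few lines of plain Python (the class counts are the same illustrative numbers used in the code further down): each sample gets weight 1 / (its class's count), so every class contributes the same total probability mass.

```python
# Inverse-frequency weights: each class's total mass becomes equal,
# so a weighted sampler draws each class with (roughly) equal probability.
class_sample_counts = [150, 200, 300]             # images per class
weights = [1.0 / c for c in class_sample_counts]  # per-sample weight, by class

# Unnormalized mass contributed by each whole class:
class_mass = [c * w for c, w in zip(class_sample_counts, weights)]
print(class_mass)  # each class contributes ~1.0, i.e. equal shares
```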
The magical WeightedRandomSampler
Straight to the code:
import torch
from torch.utils.data import WeightedRandomSampler

# Number of samples in each class of the dataset.
class_sample_counts = [150, 200, 300]
weights = 1. / torch.tensor(class_sample_counts, dtype=torch.float)
# get_classes_for_all_imgs (defined on the Dataset below) is the key:
# indexing weights with it gives every sample its class's weight.
train_targets = train_dataset.get_classes_for_all_imgs()
samples_weights = weights[train_targets]
sampler = WeightedRandomSampler(weights=samples_weights,
                                num_samples=len(samples_weights),
                                replacement=True)
# WeightedRandomSampler already draws in a balanced random order, so DO NOT shuffle again.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=setting.batch_size,
                                           shuffle=False, sampler=sampler,
                                           num_workers=opt.nThreads)
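To see the effect without running a full training loop, the sampler's behavior (weighted draws with replacement) can be emulated with the standard library's random.choices; the targets and counts below are made-up numbers matching the example above, not part of the original code.

```python
import random
from collections import Counter

random.seed(0)  # reproducible demo

# Hypothetical per-sample targets matching the counts above: 150/200/300.
targets = [0] * 150 + [1] * 200 + [2] * 300
counts_per_class = {0: 150, 1: 200, 2: 300}
samples_weights = [1.0 / counts_per_class[t] for t in targets]

# Weighted draw with replacement, as WeightedRandomSampler does.
drawn = random.choices(range(len(targets)), weights=samples_weights,
                       k=len(targets))
class_counts = Counter(targets[i] for i in drawn)
print(class_counts)  # each class drawn roughly 650/3 ≈ 217 times
```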
get_classes_for_all_imgs
This is the key piece. Define it on your own Dataset; it returns the class of every image, in the order the files are read.
An example:
Assume here that a filename containing "A" belongs to class 0, one containing "B" to class 1, and one containing "C" to class 2.
The crucial point is that you must manually assign a label to every sample, at every index of the Dataset; otherwise the sampler has no way of knowing each sample's label, let alone giving each sample an appropriate sampling weight.
import glob

import torch.utils.data as data

class TrainDataset(data.Dataset):
    def __init__(self, data_root):
        super(TrainDataset, self).__init__()
        self.data_root = data_root
        self.img_paths = sorted(glob.glob(self.data_root + '/*'))
        # Class of every image, in the order the files are read.
        self.classes_for_all_imgs = []
        for img_path in self.img_paths:
            # str.find() returns -1 (truthy) when absent, so test with `in`.
            if 'A' in img_path:
                class_id = 0
            elif 'B' in img_path:
                class_id = 1
            elif 'C' in img_path:
                class_id = 2
            else:
                raise ValueError('cannot infer class from ' + img_path)
            self.classes_for_all_imgs.append(class_id)

    def __getitem__(self, index):
        ....

    def get_classes_for_all_imgs(self):
        return self.classes_for_all_imgs
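The filename-to-class rule can be sanity-checked standalone; the helper and filenames below are hypothetical examples, not part of the Dataset above.

```python
# Standalone sketch of the substring-based labeling rule used in
# TrainDataset; the filenames are hypothetical.
def class_from_path(img_path):
    if 'A' in img_path:
        return 0
    elif 'B' in img_path:
        return 1
    elif 'C' in img_path:
        return 2
    raise ValueError('cannot infer class from ' + img_path)

print([class_from_path(p) for p in ['dogA_1.jpg', 'catB_2.jpg', 'birdC_3.jpg']])
# → [0, 1, 2]
```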
One last point: shuffle must be False. The reason:
Random sampling is done using some kind of probability distribution.
If you are shuffling the data, it could be seen as sampling from a uniform distribution without replacement.
Since you are providing the weights to sample each data sample, how should shuffling work in this case?
Copyright notice: this is an original article by Hungryof, released under the CC 4.0 BY-SA license; when reposting, please include a link to the original source and this notice.