@wanghuijiao 2021-01-22T13:51:04.000000Z

Tackling Class Imbalance: Adding WeightedRandomSampler to FairMOT

Experiment report


0. Introduction

1. Results after adding it to FairMOT

```
# ap_now: E15-1 (with WeightedRandomSampler), epoch=4; ap_before: E14 (before adding the sampler), epoch=10
# Test set: AVA_manfu_hik.val
+------------+-------+-------+--------+--------+-----------+
| class      | gts   | dets  | recall | ap_now | ap_before |
+------------+-------+-------+--------+--------+-----------+
| stand      | 18092 | 46136 | 0.804  | 0.276  | 0.259     |
| sit        | 12715 | 32296 | 0.824  | 0.283  | 0.214     |
| squat      | 238   | 1764  | 0.752  | 0.134  | 0.000     |
| lie        | 989   | 12472 | 0.817  | 0.078  | 0.038     |
| half_lie   | 1754  | 23154 | 0.808  | 0.081  | 0.071     |
| other_pose | 43    | 33    | 0.000  | 0.000  | 0.000     |
| person     | 33831 | 69055 | 0.838  | 0.356  | 0.356     |
+------------+-------+-------+--------+--------+-----------+
| mAP        |       |       |        | 0.173  | 0.134     |
+------------+-------+-------+--------+--------+-----------+
```
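The mAP row matches an unweighted mean of the seven per-class APs, with `person` included. Below is a quick arithmetic check that uses only the numbers from the table. The helper name `mean_ap` is illustrative.

```python
# Per-class AP values copied from the table above, in row order:
# stand, sit, squat, lie, half_lie, other_pose, person
ap_now = [0.276, 0.283, 0.134, 0.078, 0.081, 0.000, 0.356]
ap_before = [0.259, 0.214, 0.000, 0.038, 0.071, 0.000, 0.356]


def mean_ap(aps):
    """Unweighted mean over classes, rounded to 3 decimals like the table."""
    return round(sum(aps) / len(aps), 3)


print(mean_ap(ap_now), mean_ap(ap_before))  # 0.173 0.134
```

Because every class counts equally in this mean, gains on rare classes such as `squat` (238 ground truths) move mAP as much as gains on `stand` (18092 ground truths).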

2. Conclusions

3. WeightedRandomSampler: source code walkthrough and usage example

```python
from typing import Sequence

import torch
from torch import Tensor
from torch.utils.data import Sampler


class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence)   : a sequence of weights, not necessary summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """
    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        # The quoted source checked against _int_classes (torch._six.int_classes),
        # which is simply int on Python 3
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self):
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        return iter(rand_tensor.tolist())

    def __len__(self):
        return self.num_samples
```
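The argument checks in `__init__` can be exercised directly. This sketch assumes a PyTorch version whose `WeightedRandomSampler` validates its arguments as in the source above. `WEIGHTS` and `is_rejected` are illustrative names, not part of PyTorch.

```python
from torch.utils.data import WeightedRandomSampler

WEIGHTS = [0.1, 0.9, 0.4]


def is_rejected(num_samples, replacement=True):
    """Return True if the constructor's argument checks raise ValueError."""
    try:
        WeightedRandomSampler(WEIGHTS, num_samples, replacement=replacement)
    except ValueError:
        return True
    return False


# num_samples must be a positive int; bools and floats are rejected
for value in (0, -3, True, 2.0):
    print("num_samples =", value, "rejected:", is_rejected(value))

# replacement must be an actual bool, not merely a truthy value
print("replacement = 1 rejected:", is_rejected(2, replacement=1))

# __len__ just reports num_samples, which may exceed len(weights) with replacement
print(len(WeightedRandomSampler(WEIGHTS, 10)))  # 10
```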
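A simplified, annotated version:

```python
import torch
from torch.utils.data import Sampler


class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights)."""

    def __init__(self, weights, num_samples, replacement=True):
        # ... (type checks omitted)
        # weights: relative probability of drawing each index
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        # replacement: whether the same index may be drawn more than once
        self.replacement = replacement

    def __iter__(self):
        # Return random indices drawn according to the weights
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())

    def __len__(self):
        # Number of indices per pass (needed, e.g., by len(DataLoader))
        return self.num_samples
```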
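As the code shows, `__iter__` is a thin wrapper around `torch.multinomial`. With the same seeded generator, the sampler and a direct `torch.multinomial` call give identical indices. This sketch assumes a PyTorch version that accepts the `generator` argument, as in the source above. The helper names are illustrative.

```python
import torch
from torch.utils.data import WeightedRandomSampler

weights = [0.1, 0.9, 0.4, 0.7, 3.0, 0.6]


def sampler_draws(seed, num_samples, replacement):
    """Indices yielded by WeightedRandomSampler with a seeded generator."""
    g = torch.Generator().manual_seed(seed)
    return list(WeightedRandomSampler(weights, num_samples, replacement, generator=g))


def multinomial_draws(seed, num_samples, replacement):
    """The same draw made directly with torch.multinomial."""
    g = torch.Generator().manual_seed(seed)
    w = torch.as_tensor(weights, dtype=torch.double)
    return torch.multinomial(w, num_samples, replacement, generator=g).tolist()


# Same seed, same indices
print(sampler_draws(0, 5, True) == multinomial_draws(0, 5, True))  # True

# Without replacement each index appears at most once, so drawing all 6 is a permutation
print(sorted(sampler_draws(0, 6, False)))  # [0, 1, 2, 3, 4, 5]
```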

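In WeightedRandomSampler.__init__(), the arguments work as follows:

- replacement controls whether sampling is done with replacement (True means the same index can be drawn more than once).
- num_samples sets how many indices are drawn.
- weights holds one weight per *sample*, not per class.

__iter__() returns a sequence of random indices drawn according to those weights. With replacement, each draw picks index i with probability weights[i] / sum(weights). Test code: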

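```python
import torch
from torch.utils.data import sampler

# Index 0 has weight 0 (never drawn), index 1 has weight 0.5,
# and indices 2-6 have weight 0.1 each
weights = torch.Tensor([0, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1])
wei_sampler = sampler.WeightedRandomSampler(weights, 6, True)
for index in wei_sampler:
    print("index:", index)
# Output from one run (random, differs between runs):
# index: 1
# index: 2
# index: 3
# index: 4
# index: 1
# index: 1
```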
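Six draws are too few to show the weighting reliably. Drawing many indices and comparing their frequencies with the normalized weights makes it visible. The sample count and seed below are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Same weights as the test above: index 0 can never be drawn,
# index 1 should come up about half the time
weights = torch.Tensor([0, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1])
n = 100_000
g = torch.Generator().manual_seed(0)
draws = torch.tensor(list(WeightedRandomSampler(weights, n, True, generator=g)))

# Observed frequency of each index vs. weights / sum(weights)
freq = torch.bincount(draws, minlength=len(weights)).double() / n
expected = weights.double() / weights.sum()
for i, (f, p) in enumerate(zip(freq.tolist(), expected.tolist())):
    print(f"index {i}: observed {f:.3f}, expected {p:.3f}")
```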

4. What to prepare before adding it to FairMOT
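The loader in section 5 expects a text file whose first line is a Python-style list of per-sample weights. Because WeightedRandomSampler yields indices into the dataset, weights[i] must belong to sample i in dataset order.

The note does not say how the weights were computed. The sketch below assumes a hypothetical inverse-class-frequency scheme with exactly one label per sample. For frames that contain several classes, some extra rule would be needed (for example, using the rarest class in the frame); that rule is also an assumption. All function names and labels are illustrative.

```python
import tempfile
from collections import Counter
from pathlib import Path


def inverse_frequency_weights(labels):
    """Hypothetical scheme: weight of sample i = 1 / (number of samples with its label)."""
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


def save_weights(weights, path):
    """Write the list on the first line as '[w0, w1, ...]' (str() of a Python list)."""
    Path(path).write_text(str(weights) + "\n")


def load_weights(path):
    """Parse the first line back into floats, mirroring the section 5 loader."""
    first_line = Path(path).read_text().splitlines()[0]
    return [float(x) for x in first_line.strip()[1:-1].split(', ')]


# Made-up toy labels: 6 'stand', 3 'sit', 1 'squat'
labels = ["stand"] * 6 + ["sit"] * 3 + ["squat"]
weights = inverse_frequency_weights(labels)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "weights.txt"
    save_weights(weights, path)
    assert load_weights(path) == weights  # round-trip is exact

# Each class now carries the same total weight, so each is drawn equally often on average
per_class = Counter()
for label, w in zip(labels, weights):
    per_class[label] += w
print(dict(per_class))
```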

5. What to change in FairMOT

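```python
import torch

# opt and dataset come from FairMOT's training script;
# opt.weights_dir is the path to the per-sample weights file.
# Load the weights list: the file's first line looks like "[w0, w1, w2, ...]"
with open(opt.weights_dir, 'r') as file:
    line = file.readlines()
# strip() drops the trailing newline, so [1:-1] removes exactly "[" and "]"
weights = line[0].strip()[1:-1].split(', ')
weights = list(map(float, weights))
# list -> Tensor
weights = torch.Tensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(
    weights, len(weights), replacement=True)

# Get dataloader; shuffle must stay False because DataLoader
# does not allow shuffle=True together with a sampler
train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size=opt.batch_size,
                                           shuffle=False,
                                           num_workers=opt.num_workers,
                                           pin_memory=True,
                                           drop_last=True,
                                           sampler=sampler)
```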
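The sampler and DataLoader wiring above can be tried end to end on a hypothetical toy dataset, outside FairMOT. The dataset uses a 9:1 class ratio, and per-sample weights of 1 / (class size) should make both classes show up about equally often in one epoch. The dataset, batch size, and seed are illustrative assumptions.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # without a generator, the sampler draws from the global RNG

# Hypothetical imbalanced dataset: 9000 samples of class 0, 1000 of class 1
labels = torch.cat([torch.zeros(9000, dtype=torch.long), torch.ones(1000, dtype=torch.long)])
features = torch.randn(len(labels), 4)
dataset = TensorDataset(features, labels)

# Per-sample weight = 1 / (size of the sample's class)
class_counts = torch.bincount(labels)
weights = 1.0 / class_counts[labels].double()
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

# shuffle stays False: DataLoader rejects shuffle=True combined with a sampler
loader = DataLoader(dataset, batch_size=100, shuffle=False, drop_last=True, sampler=sampler)

seen = torch.zeros(2, dtype=torch.long)
for _, batch_labels in loader:
    seen += torch.bincount(batch_labels, minlength=2)
print("class counts seen in one epoch:", seen.tolist())  # roughly balanced
```

Because replacement=True, minority-class samples are repeated within an epoch, and some majority-class samples are skipped in that epoch.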