@wanghuijiao
2021-01-22T13:51:04.000000Z
Experiment Report
# ap_now: E15-1 epoch=4; ap_before: E14 epoch=10
# Test set: AVA_manfu_hik.val
+------------+-------+-------+--------+--------+-----------+
| class      | gts   | dets  | recall | ap_now | ap_before |
+------------+-------+-------+--------+--------+-----------+
| stand      | 18092 | 46136 | 0.804  | 0.276  | 0.259     |
| sit        | 12715 | 32296 | 0.824  | 0.283  | 0.214     |
| squat      | 238   | 1764  | 0.752  | 0.134  | 0.000     |
| lie        | 989   | 12472 | 0.817  | 0.078  | 0.038     |
| half_lie   | 1754  | 23154 | 0.808  | 0.081  | 0.071     |
| other_pose | 43    | 33    | 0.000  | 0.000  | 0.000     |
| person     | 33831 | 69055 | 0.838  | 0.356  | 0.356     |
+------------+-------+-------+--------+--------+-----------+
| mAP        |       |       |        | 0.173  | 0.134     |
+------------+-------+-------+--------+--------+-----------+
/ssd01/wanghuijiao/FairMOT/output/output_video_E15_2/e15_0_hik5_ct0.1_dt0.2_nms0.2.mp4
Conclusion
Proposed solution to try
WeightedRandomSampler
PyTorch's implementation (from torch/utils/data/sampler.py; imports added here for context, and the internal `_int_classes` alias replaced with plain `int` so the excerpt is self-contained):

import torch
from torch import Tensor
from torch.utils.data.sampler import Sampler
from typing import Sequence

class WeightedRandomSampler(Sampler[int]):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).

    Args:
        weights (sequence): a sequence of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
        generator (Generator): Generator used in sampling.

    Example:
        >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
        [4, 4, 1, 4, 5]
        >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
        [0, 1, 4, 3, 2]
    """
    weights: Tensor
    num_samples: int
    replacement: bool

    def __init__(self, weights: Sequence[float], num_samples: int,
                 replacement: bool = True, generator=None) -> None:
        if not isinstance(num_samples, int) or isinstance(num_samples, bool) or \
                num_samples <= 0:
            raise ValueError("num_samples should be a positive integer "
                             "value, but got num_samples={}".format(num_samples))
        if not isinstance(replacement, bool):
            raise ValueError("replacement should be a boolean value, but got "
                             "replacement={}".format(replacement))
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        self.replacement = replacement
        self.generator = generator

    def __iter__(self):
        rand_tensor = torch.multinomial(self.weights, self.num_samples, self.replacement, generator=self.generator)
        return iter(rand_tensor.tolist())

    def __len__(self):
        return self.num_samples
A simplified, annotated version:

class WeightedRandomSampler(Sampler):
    r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights)."""
    def __init__(self, weights, num_samples, replacement=True):
        # ...type checks omitted
        # weights determines the probability of drawing each index
        self.weights = torch.as_tensor(weights, dtype=torch.double)
        self.num_samples = num_samples
        # controls whether sampling is done with replacement
        self.replacement = replacement

    def __iter__(self):
        # return random indices drawn according to the weights
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement).tolist())
In WeightedRandomSampler's __init__(), the replacement parameter controls whether sampling is done with replacement (True means with replacement); num_samples controls how many indices are drawn; and the weights parameter holds the weights of the *samples*, not of the *classes*. __iter__() returns a sequence of random indices, drawn according to the probabilities specified by weights. Test code:
# weight at index [0] is 0, at index [1] is 0.5, all other indices 0.1
weights = torch.Tensor([0, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1])
wei_sampler = sampler.WeightedRandomSampler(weights, 6, True)
for index in wei_sampler:
    print("index:", index)

# Example output (random; index 0 never appears because its weight is 0):
index: 1
index: 2
index: 3
index: 4
index: 1
index: 1
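As a further sanity check that the weights behave as per-index sampling probabilities (a hypothetical demo, not from the report), the empirical draw frequencies over a large number of samples should track the given weights:

```python
from collections import Counter

import torch
from torch.utils.data.sampler import WeightedRandomSampler

torch.manual_seed(0)
weights = [0, 0.5, 0.1, 0.1, 0.1, 0.1, 0.1]
demo_sampler = WeightedRandomSampler(weights, num_samples=10000, replacement=True)

# Count how often each index is drawn
counts = Counter(demo_sampler)

print(counts[0])              # index 0 has weight 0, so it is never drawn
print(counts[1], counts[2])   # index 1 (weight 0.5) is drawn far more often than index 2 (0.1)
```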
The script /ssd01/wanghuijiao/FairMOT/src/weights_generate.py reads the .train file and outputs a .weights file containing each image's weight, in image order.
stand: 83410, sit: 66170, squat: 714, lie: 6004, half: 10063, other: 10
Weight assigned to each class:
stand: 1/stand_num
sit: 1/sit_num
squat: 1/(squat_num*10)
lie: 1/lie_num (a later attempt multiplied the weights of the hik lie images by 100 for finetuning; no obvious improvement)
half_lie: 1/half_lie_num
other: 1/stand_num (other_pose has too few samples, and the model mainly targets detection accuracy on the first five classes, so it is given the same weight as stand)

Sampler
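The per-image weight computation can be sketched as follows. This is a hypothetical reconstruction, not the actual weights_generate.py: the `class_counts` dict, the `image_weight` helper, and the choice of taking the maximum class weight for a multi-pose image are all assumptions.

```python
# Class counts from the report's training set
class_counts = {"stand": 83410, "sit": 66170, "squat": 714,
                "lie": 6004, "half_lie": 10063, "other": 10}

# Per-class weights as described above
class_weights = {
    "stand": 1 / class_counts["stand"],
    "sit": 1 / class_counts["sit"],
    "squat": 1 / (class_counts["squat"] * 10),
    "lie": 1 / class_counts["lie"],
    "half_lie": 1 / class_counts["half_lie"],
    # other_pose is too rare to target, so it reuses stand's weight
    "other": 1 / class_counts["stand"],
}

def image_weight(pose_labels):
    """Weight for one image, given the pose classes it contains.

    Assumption: an image with several poses takes the largest (rarest-class)
    weight, so images containing rare poses are oversampled.
    """
    return max(class_weights[p] for p in pose_labels)

# One weight per image, in image order, as written to the .weights file
weights = [image_weight(["stand", "lie"]), image_weight(["sit"])]
```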
# Load the weights list
with open(opt.weights_dir, 'r') as file:
    line = file.readlines()
weights = line[0][1:-2].split(', ')
weights = list(map(float, weights))
# list -> Tensor
weights = torch.Tensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(
    weights, len(weights), replacement=True)

# Get dataloader (shuffle must be False when a custom sampler is given)
train_loader = torch.utils.data.DataLoader(dataset,
                                           batch_size=opt.batch_size,
                                           shuffle=False,
                                           num_workers=opt.num_workers,
                                           pin_memory=True,
                                           drop_last=True,
                                           sampler=sampler)
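The slicing `line[0][1:-2]` assumes the .weights file is a single line holding a Python-style list literal followed by a newline ("[w1, w2, ...]\n"). A minimal round-trip sketch of that format (the demo file name is hypothetical):

```python
import os
import tempfile

# Write a tiny .weights file in the assumed single-line list-literal format
weights = [1.2e-05, 1.5e-05, 0.0014]
path = os.path.join(tempfile.mkdtemp(), "demo.weights")
with open(path, "w") as f:
    f.write(str(weights) + "\n")

# Re-parse it exactly the way the training code above does:
# strip the leading '[' and the trailing ']\n', then split on ', '
with open(path, "r") as file:
    line = file.readlines()
parsed = list(map(float, line[0][1:-2].split(', ')))
```

The parsed list matches the original weights, confirming the slice offsets.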