1 TSM paper and GitHub code
2 Reference: https://blog.csdn.net/qq_18644873/article/details/89305928
3 Work on action recognition today largely revolves around how to better describe temporal features. Building on TSN, this paper proposes the Temporal Shift Module (TSM), which keeps a 2D backbone's efficiency while reaching high accuracy. TSM draws on《Shift: A Zero FLOP, Zero Parameter Alternative to Spatial Convolutions》(that paper studies replacing spatial convolutions with shift operations; I have not fully worked through it) and applies the shift along the temporal dimension: in the offline model, 1/8 of the channels are shifted one step forward in time and another 1/8 one step backward; in the online model, 1/4 of the channels are shifted forward only, since future frames are unavailable. Shifting only part of the channels keeps the data-movement cost low, and placing the shift inside the residual branch preserves accuracy.
The reason the paper gives is that each shift enlarges the temporal receptive field, enabling more complex temporal modeling (For each inserted temporal shift module, the temporal receptive field will be enlarged by 2, as if running a convolution with the kernel size of 3 along the temporal dimension. Therefore, our TSM model has a very large temporal receptive field to conduct highly complicated temporal modeling.)
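To make the offline/online difference concrete, here is a minimal sketch of the online uni-directional shift. This is my own illustration, not code from the repo: since future frames are unavailable online, the shifted channels of the previous frame are cached and substituted into the current frame.

import torch

class OnlineShiftCache:
    """Illustrative online (uni-directional) temporal shift with a feature cache."""
    def __init__(self, n_div=4):
        self.n_div = n_div   # fraction of channels shifted forward in time
        self.cache = None    # shifted channels saved from the previous frame

    def __call__(self, x):  # x: (N, C, H, W) features of the current frame
        fold = x.size(1) // self.n_div
        out = x.clone()
        if self.cache is not None:
            out[:, :fold] = self.cache      # inject channels coming from the past frame
        self.cache = x[:, :fold].detach()   # save the current frame's channels for the next step
        return out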
4 For the shift operation, the first hyperparameter is how many channels to shift; the final choice is 1/8 of the channels shifted left and 1/8 shifted right. The variant adopted is residual TSM: in every residual block, the TemporalShift module wraps conv1, so the shift runs on the residual branch right before the block's first convolution (see the wiring sketch after the TemporalShift code below).
5 For the added non-local operations, the placement follows the original Non-local paper: in ResNet-50, a non-local block is inserted after the 1st and 3rd residual blocks of the 4-block stage (res3), and after the 1st, 3rd, and 5th residual blocks of the following 6-block stage (res4).
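A sketch of that insertion pattern (assumptions: torchvision's ResNet stage naming, where res3/res4 are layer2/layer3 with 512/1024 output channels, and nl_block standing in for the repo's non-local module):

import torch.nn as nn

def make_non_local(resnet, nl_block):
    # res3 (layer2, 4 blocks): non-local after the 1st and 3rd blocks
    l2 = resnet.layer2
    resnet.layer2 = nn.Sequential(l2[0], nl_block(512), l2[1],
                                  l2[2], nl_block(512), l2[3])
    # res4 (layer3, 6 blocks): non-local after the 1st, 3rd, and 5th blocks
    l3 = resnet.layer3
    resnet.layer3 = nn.Sequential(l3[0], nl_block(1024), l3[1],
                                  l3[2], nl_block(1024), l3[3],
                                  l3[4], nl_block(1024), l3[5])
    return resnet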
6 Notable pieces of the code:
a. The shift operation blends each frame's features with information from its neighboring frames to enlarge the temporal receptive field; as noted above, the shift sits inside the residual branch.
import torch
import torch.nn as nn


class TemporalShift(nn.Module):
    def __init__(self, net, n_segment=3, n_div=8, inplace=True):
        super(TemporalShift, self).__init__()
        self.net = net            # the conv this module wraps; the shift runs first
        self.n_segment = n_segment
        self.fold_div = n_div     # shift 1/n_div of the channels per direction
        self.inplace = inplace
        if inplace:
            print('=> Using in-place shift...')
        print('=> Using fold div: {}'.format(self.fold_div))

    def forward(self, x):
        x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
        return self.net(x)

    @staticmethod
    def shift(x, n_segment, fold_div=3, inplace=False):
        nt, c, h, w = x.size()
        n_batch = nt // n_segment
        x = x.view(n_batch, n_segment, c, h, w)

        fold = c // fold_div
        if inplace:
            # InplaceShift is a custom autograd Function defined elsewhere in the repo
            out = InplaceShift.apply(x, fold)
        else:
            out = torch.zeros_like(x)
            out[:, :-1, :fold] = x[:, 1:, :fold]                     # shift left
            out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold]   # shift right
            out[:, :, 2 * fold:] = x[:, :, 2 * fold:]                # not shift

        return out.view(nt, c, h, w)
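And the wiring sketch promised above: how TemporalShift ends up on the residual branch. This is a minimal version (assumption: torchvision's ResNet-50; the repo's make_temporal_shift supports more placement modes) that wraps conv1 of every residual block, so the shift runs before the first convolution while the identity path stays untouched.

import torchvision

def insert_tsm(model, n_segment=8, n_div=8):
    for stage in [model.layer1, model.layer2, model.layer3, model.layer4]:
        for block in stage:
            # shift first, then the block's original conv1
            block.conv1 = TemporalShift(block.conv1, n_segment=n_segment, n_div=n_div)
    return model

net = insert_tsm(torchvision.models.resnet50())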
b. Sparse and dense sampling of input frames.
import numpy as np
from numpy.random import randint


def _sample_indices(self, record):
    """
    :param record: VideoRecord
    :return: list
    """
    if self.dense_sample:  # i3d dense sample
        sample_pos = max(1, 1 + record.num_frames - 64)
        t_stride = 64 // self.num_segments
        start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
        offsets = [(idx * t_stride + start_idx) % record.num_frames
                   for idx in range(self.num_segments)]
        return np.array(offsets) + 1
    else:  # normal sample: one frame drawn per equal-length segment
        average_duration = (record.num_frames - self.new_length + 1) // self.num_segments
        if average_duration > 0:
            offsets = np.multiply(list(range(self.num_segments)), average_duration) + \
                      randint(average_duration, size=self.num_segments)
        elif record.num_frames > self.num_segments:
            offsets = np.sort(randint(record.num_frames - self.new_length + 1,
                                      size=self.num_segments))
        else:
            offsets = np.zeros((self.num_segments,))
        return offsets + 1
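A quick worked example of the sparse branch (assumption: an illustrative 100-frame video with num_segments=8 and new_length=1): the video is split into 8 equal chunks of average_duration = 12 frames, and one frame is drawn uniformly inside each chunk.

import numpy as np
from numpy.random import randint

num_segments, new_length, num_frames = 8, 1, 100
average_duration = (num_frames - new_length + 1) // num_segments  # 12
offsets = np.multiply(list(range(num_segments)), average_duration) \
          + randint(average_duration, size=num_segments)
print(offsets + 1)  # 1-indexed frame ids, one per segment, e.g. [ 9 16 27 ...]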
c. Standard image augmentation: the training set uses GroupMultiScaleCrop, while test images are first scaled and then center-cropped.
import random
from PIL import Image


class GroupMultiScaleCrop(object):
    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else [1, .875, .75, .66]
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        im_size = img_group[0].size
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        crop_img_group = [img.crop((offset_w, offset_h, offset_w + crop_w, offset_h + crop_h))
                          for img in img_group]
        ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
                         for img in crop_img_group]
        return ret_img_group

    def _sample_crop_size(self, im_size):
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = []
        for i, h in enumerate(crop_h):
            for j, w in enumerate(crop_w):
                if abs(i - j) <= self.max_distort:
                    pairs.append((w, h))

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))                    # upper left
        ret.append((4 * w_step, 0))           # upper right
        ret.append((0, 4 * h_step))           # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))           # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center
            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret
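Example usage (the 340×256 frame size here is just illustrative): the same randomly chosen crop window is applied to every frame of the clip, then all frames are resized to the network input size.

from PIL import Image

clip = [Image.new('RGB', (340, 256)) for _ in range(8)]  # a dummy 8-frame clip
transform = GroupMultiScaleCrop(input_size=224)
out = transform(clip)
print(out[0].size)  # (224, 224)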
d. When fine-tuning from a pretrained model, partial BN is used: all BatchNorm layers except the first are frozen, and only the first BN layer's parameters keep updating.
def train(self, mode=True):
    """
    Override the default train() to freeze the BN parameters
    :return:
    """
    super(TSN, self).train(mode)
    count = 0
    if self._enable_pbn and mode:
        print("Freezing BatchNorm2D except the first one.")
        for m in self.base_model.modules():
            if isinstance(m, nn.BatchNorm2d):
                count += 1
                if count >= (2 if self._enable_pbn else 1):
                    m.eval()
                    # shutdown update in frozen mode
                    m.weight.requires_grad = False
                    m.bias.requires_grad = False
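A quick sanity check (assuming `model` is a TSN instance built with partial BN enabled): after calling train(), only the first BN remains trainable, and all later BNs stay in eval mode so their running statistics are not updated.

import torch.nn as nn

model.train()
bns = [m for m in model.base_model.modules() if isinstance(m, nn.BatchNorm2d)]
print(bns[0].weight.requires_grad, bns[1].weight.requires_grad)  # True False
print(bns[0].training, bns[1].training)                          # True False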
e. Different parameter groups are trained with different learning rates (see the reference link).
def get_optim_policies(self):
    first_conv_weight = []
    first_conv_bias = []
    normal_weight = []
    normal_bias = []
    lr5_weight = []
    lr10_bias = []
    bn = []
    custom_ops = []

    conv_cnt = 0
    bn_cnt = 0
    for m in self.modules():
        if isinstance(m, (torch.nn.Conv2d, torch.nn.Conv1d, torch.nn.Conv3d)):
            ps = list(m.parameters())
            conv_cnt += 1
            if conv_cnt == 1:
                first_conv_weight.append(ps[0])
                if len(ps) == 2:
                    first_conv_bias.append(ps[1])
            else:
                normal_weight.append(ps[0])
                if len(ps) == 2:
                    normal_bias.append(ps[1])
        elif isinstance(m, torch.nn.Linear):
            ps = list(m.parameters())
            if self.fc_lr5:
                lr5_weight.append(ps[0])
            else:
                normal_weight.append(ps[0])
            if len(ps) == 2:
                if self.fc_lr5:
                    lr10_bias.append(ps[1])
                else:
                    normal_bias.append(ps[1])
        elif isinstance(m, torch.nn.BatchNorm2d):
            bn_cnt += 1
            # later BN's are frozen
            if not self._enable_pbn or bn_cnt == 1:
                bn.extend(list(m.parameters()))
        elif isinstance(m, torch.nn.BatchNorm3d):
            bn_cnt += 1
            # later BN's are frozen
            if not self._enable_pbn or bn_cnt == 1:
                bn.extend(list(m.parameters()))
        elif len(m._modules) == 0:
            if len(list(m.parameters())) > 0:
                raise ValueError("New atomic module type: {}. Need to give it a learning policy".format(type(m)))

    return [
        {'params': first_conv_weight, 'lr_mult': 5 if self.modality == 'Flow' else 1, 'decay_mult': 1,
         'name': "first_conv_weight"},
        {'params': first_conv_bias, 'lr_mult': 10 if self.modality == 'Flow' else 2, 'decay_mult': 0,
         'name': "first_conv_bias"},
        {'params': normal_weight, 'lr_mult': 1, 'decay_mult': 1,
         'name': "normal_weight"},
        {'params': normal_bias, 'lr_mult': 2, 'decay_mult': 0,
         'name': "normal_bias"},
        {'params': bn, 'lr_mult': 1, 'decay_mult': 0,
         'name': "BN scale/shift"},
        {'params': custom_ops, 'lr_mult': 1, 'decay_mult': 1,
         'name': "custom_ops"},
        # for fc
        {'params': lr5_weight, 'lr_mult': 5, 'decay_mult': 1,
         'name': "lr5_weight"},
        {'params': lr10_bias, 'lr_mult': 10, 'decay_mult': 0,
         'name': "lr10_bias"},
    ]
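How the returned policies are typically consumed (a sketch following TSN-style training scripts; the base learning rate and weight decay values here are illustrative): each group's lr_mult and decay_mult scale the base learning rate and weight decay.

import torch

policies = model.get_optim_policies()
optimizer = torch.optim.SGD(
    [{'params': g['params'],
      'lr': 0.01 * g['lr_mult'],              # base lr scaled per group
      'weight_decay': 5e-4 * g['decay_mult']}  # decay disabled for biases/BN
     for g in policies],
    momentum=0.9)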
TRN model
1 TRN's backbone also follows TSN. Taking the code as an example: features are extracted with 8 frames representing one clip, giving an 8×256 feature per clip. The TRN module then builds sub-modules at the scales [8, 7, 6, 5, 4, 3, 2]; for scale 2, it picks 2 of the 8 frames at random (keeping temporal order) as that sub-module's input. Each sub-module's features pass through two fully connected layers (the first maps the channels to 256, the second to num_class), yielding a batch×num_class output, e.g. batch×10 for 10 classes. The outputs of all sub-modules are summed to form the final classification feature; a minimal sketch follows.
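This is my own simplification of that multi-scale relation idea, not the repo's TRNmodule.py (which samples several subsets per scale and uses slightly different layer sizes); all names here are illustrative.

import random
import torch
import torch.nn as nn

class MultiScaleRelationSketch(nn.Module):
    def __init__(self, img_feature_dim=256, num_frames=8, num_class=10):
        super().__init__()
        self.scales = list(range(num_frames, 1, -1))  # [8, 7, ..., 2]
        self.subsets = [sorted(random.sample(range(num_frames), k))
                        for k in self.scales]         # one ordered frame subset per scale
        self.relations = nn.ModuleList(
            nn.Sequential(nn.ReLU(),
                          nn.Linear(k * img_feature_dim, 256),  # channels -> 256
                          nn.ReLU(),
                          nn.Linear(256, num_class))            # 256 -> num_class
            for k in self.scales)

    def forward(self, x):  # x: (batch, num_frames, img_feature_dim)
        out = 0
        for subset, mlp in zip(self.subsets, self.relations):
            out = out + mlp(x[:, subset, :].flatten(1))  # (batch, num_class) per scale
        return out  # sum over all scale sub-modules

feats = torch.randn(4, 8, 256)  # 4 clips x 8 frames x 256-d frame features
print(MultiScaleRelationSketch()(feats).shape)  # torch.Size([4, 10])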
2 This module does not scale well, however: with a large number of input frames, say 64, there would be 63 scales of sub-modules and combinatorially many frame subsets, making the model impractical to train.