第一人称动作识别文献阅读——LaViLa:从大型语言模型中学习视频表征信息

article/2025/7/19 5:31:47

目录

摘要

Abstract

1 引言

2 准备工作

3 LaViLa

3.1 NARRATOR

3.2 REPHRASER

3.3 双编码器训练

总结


摘要

本周阅读的论文题目是《Learning Video Representations from Large Language Models》(《从大型语言模型中学习视频表征信息》)。本文中提出了LaViLa,这是一种通过利用大型语言模型来学习视频-语言表示的新方法。LaViLa将预训练的LLMs重新用于视觉输入,并对其进行微调以创建自动视频叙述者。与传统的视频文本对齐方法相比,自动生成的叙述具有许多优点,包括对长视频的密集覆盖、视觉信息和文本的更好时间同步以及文本的更高多样性。LaViLa通过这些额外的自动生成叙述进行对比学习的视频-文本嵌入在多个第一人称和第三人称视频任务上,无论是在零样本还是微调设置中都优于之前的最佳水平。

Abstract

This week's paper is titled "Learning Video Representations from Large Language Models." In this paper, LaViLa, a new method for learning video-language representations by utilizing large language models, is proposed. LaViLa repurposes pre-trained LLMs for visual input and fine-tunes them to create automated video narrators. Auto-generated narratives have a number of advantages over traditional video-text alignment methods, including dense coverage of long videos, better time synchronization of visual information and text, and greater diversity of text. LaViLa uses these additional auto-generated narratives for contrastive learning video-text embedded on multiple first-person and third-person video tasks, both at zero shot and in fine-tuned settings, which are better than the previous best.

1 引言

使用网络规模的图像文本数据学习视觉表征信息是计算机视觉的一个强大工具,并且视觉语言方法在各种任务中推动了最前沿的发展,包括零样本分类、新颖物体检测甚至图像生成。然而,与数十亿规模的图像文本数据集相比,类似的方法在视频方面受到了配对视频文本语料库规模较小的限制,尽管在过去十年中原始视频数据的获取激增。生成式视觉语言模型(VLM)最初用于图像/视频标题,使用循环网络和基于Transformer的架构,生成式VLM通过在视觉-文本对上训练多模态Transformer统一了多个视觉任务,并且生成式VLM也擅长通过零样本或小样本提示进行多模态任务。所以,在本文中展示了通过利用生成式VLM自动生成此类视频的文本配对,从而充分利用大量视频数据,再使用这些自动生成的注释学习视频语言模型得到更强的表征信息。

本文中使用的方法为LaViLa,即Language-model augmented Video-Language pre-trainin(语言模型增强视频语言预训练),利用预训练的LLMs,其权重中包含丰富的知识宝库和对话能力,来用于“视觉条件叙述者”,并在所有可用的视频-文本配对片段上进行微调。如下图所示,LaViLa利用LLMs对长视频进行密集叙述,并使用这些叙述来训练强大的双编码器模型:

 而先前的工作使用的是人类标注的稀疏文本,或者从语音转录的弱对齐文本,而LaViLa能够利用由LLM生成的密集、多样且对齐良好的文本。一旦训练完成,使用该模型通过生成丰富的文本描述来密集标注数千小时的视频。这种伪监督可以渗透整个视频,包括注释片段之间和之外。与另一个训练用于改写现有叙述的LLM相结合,LaViLa能够为视频-文本对比学习创建更大、更多样化的文本目标。

LaViLa的出色表现可以归因于多个因素:

  • 可以提供时间上的密集监督长视频;
  • 生成的文本与视觉输入很好地对齐;
  • 能够自动对每个视频进行多次且密度更高的叙述,从而在只有少量可用的情况下显著扩展注释学习到更强的表征。

2 准备工作

视频V 是一系列动态图像I 的流,视频的帧数|V| 可以任意长,而视频模型通常在较短的剪辑上操作,这些剪辑通常在几秒钟的范围内。因此,在浏览长篇视频时使用一组N 个短剪辑来表示它,即\texttt{X} ,每个剪辑x_i 由特定的起始和结束帧定义,x_i=\begin{Bmatrix} I_{t_i},\cdots ,I_{e_i} \end{Bmatrix} ,其中0\leq t_i\leq e_i\leq |V| ,并且通常与某些注释y_i 相关联,这种注释可以是类别标签或剪辑的自由形式文本描述。然后用带有相应注释的注释剪辑集来表示视频,即(\texttt{X} ,\texttt{y})=\begin{Bmatrix} (x_1,y_1),\cdots ,(x_N,y_N) \end{Bmatrix} 。但是由于注释成本和视觉冗余,注释剪辑通常不能密集地覆盖整个视频,即\bigcup _i[t_i,e_i]\nsubseteq [0,|V|]

典型的视频模型\texttt{F}(\texttt{X} ,\texttt{y}) 通过标准训练目标(如交叉熵损失)从这些剪辑级别注释中学习,当注释是具有固定词汇的分类标签时。然而,最近基于双编码器的对比方法(如CLIP)已经变得流行。它们使用自由形式的文本注释,这些注释被标记化成离散符号的序列,即y=(s_1,s_2,\cdots s_L)\in \begin{Bmatrix} 1,0 \end{Bmatrix}^{|\mathbb{S}|\times L} 。该模型由一个视觉编码器f_{\textup{v}}:\mathbb{R}^{T\times 3\times H\times W} \mapsto \mathbb{R}^{D_{\textup{v}}} ,一个投影头h_{\textup{v}}:\mathbb{R}^{D_{\textup{v}}} \mapsto \mathbb{R}^{d} 和一个文本编码器f_{\textup{t}}:\begin{Bmatrix} 1,0 \end{Bmatrix} ^{|\mathbb{S}|\times L} \mapsto \mathbb{R}^{D_{\textup{t}}} ,以及一个投影头h_{\textup{t}}:\mathbb{R}^{D_{\textup{t}}} \mapsto \mathbb{R}^{d} 并行获得全局视觉和文本嵌入分别:

\textup{v}=h_{\textup{v}}(f_{\textup{v}}(x)) , \textup{u}=h_{\textup{t}}(f_{\textup{t}}(y)) .

CLIP代码如下:

class CLIP(nn.Module):def __init__(self,embed_dim: int,# visionvision_width: int,vision_model: nn.Module,# textcontext_length: int,vocab_size: int,transformer_width: int,transformer_heads: int,transformer_layers: int,tempearture_init=0.07,**kwargs,):super().__init__()self.context_length = context_lengthself.vision_width = vision_widthself.visual = vision_modelself.transformer = Transformer(width=transformer_width,layers=transformer_layers,heads=transformer_heads,attn_mask=self.build_attention_mask(),)self.vocab_size = vocab_sizeself.token_embedding = nn.Embedding(vocab_size, transformer_width)self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))self.ln_final = nn.LayerNorm(transformer_width)  # used to be `models.transformer.LayerNorm``self.image_projection = nn.Parameter(torch.empty(vision_width, embed_dim))self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))print("=> initialize initial temperature with {}".format(tempearture_init))self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / tempearture_init))self.initialize_parameters()def initialize_parameters(self):nn.init.normal_(self.token_embedding.weight, std=0.02)nn.init.normal_(self.positional_embedding, std=0.01)proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)attn_std = self.transformer.width ** -0.5fc_std = (2 * self.transformer.width) ** -0.5for block in self.transformer.resblocks:nn.init.normal_(block.attn.in_proj_weight, std=attn_std)nn.init.normal_(block.attn.out_proj.weight, std=proj_std)nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)nn.init.normal_(self.image_projection, std=self.vision_width ** -0.5)nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)def build_attention_mask(self):# lazily create causal attention mask, with full attention between the vision tokens# pytorch uses additive attention mask; fill with -infmask = torch.empty(self.context_length, self.context_length)mask.fill_(float("-inf"))mask.triu_(1)  # zero out the lower diagonalreturn maskdef encode_image(self, image, use_checkpoint=False, apply_project=True):x = self.visual(image, use_checkpoint=use_checkpoint)if isinstance(x, list):assert len(x) == 1x = x[0]if not apply_project:return xx = x @ self.image_projectionreturn xdef encode_text(self, text, use_checkpoint=False):x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]x = x + self.positional_embeddingx = x.permute(1, 0, 2)  # NLD -> LNDx = self.transformer(x, use_checkpoint=use_checkpoint)x = x.permute(1, 0, 2)  # LND -> NLDx = self.ln_final(x)# x.shape = [batch_size, n_ctx, transformer.width]# take features from the eot embedding (eot_token is the highest number in each sequence)x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projectionreturn xdef forward(self, image, text, use_checkpoint=False, norm_embed=False):image_embed = self.encode_image(image, use_checkpoint=use_checkpoint)text_embed = self.encode_text(text, use_checkpoint=use_checkpoint)if norm_embed:image_embed = F.normalize(image_embed, dim=-1)text_embed = F.normalize(text_embed, dim=-1)return {'image_embed': image_embed,'text_embed': text_embed,'logit_scale': self.logit_scale.exp()}

一个对比损失(如InfoNCE)学习全局嵌入,将批次样本\mathfrak{B} 中的对应视频和文本嵌入关联起来,得到:

\frac{1}{| \mathfrak{B} |}\sum _{(x,y)\in \mathfrak{B} }(\textup{InfoNCE}(\textup{v,u})+\textup{InfoNCE}(\textup{u,v})) .

损失函数代码如下:

class CLIPLoss(nn.Module):def __init__(self,use_vissl=False,local_loss=False,gather_with_grad=False,cache_labels=False,rank=0,world_size=1,):super().__init__()self.use_vissl = use_visslself.local_loss = local_lossself.gather_with_grad = gather_with_gradself.cache_labels = cache_labelsself.rank = rankself.world_size = world_size# cache stateself.prev_num_logits = 0self.labels = {}def forward(self, outputs):image_features = outputs['image_embed']text_features = outputs['text_embed']logit_scale = outputs['logit_scale']device = image_features.deviceif self.world_size > 1:if self.use_vissl:all_image_features = gather_from_all(image_features)all_text_features = gather_from_all(text_features)logits_per_image = logit_scale * all_image_features @ all_text_features.Tlogits_per_text = logits_per_image.Telse:all_image_features, all_text_features = gather_features(image_features, text_features,self.local_loss, self.gather_with_grad, self.rank, self.world_size)if self.local_loss:logits_per_image = logit_scale * image_features @ all_text_features.Tlogits_per_text = logit_scale * text_features @ all_image_features.Telse:logits_per_image = logit_scale * all_image_features @ all_text_features.Tlogits_per_text = logits_per_image.Telse:logits_per_image = logit_scale * image_features @ text_features.Tlogits_per_text = logit_scale * text_features @ image_features.T# calculated ground-truth and cache if enablednum_logits = logits_per_image.shape[0]if self.prev_num_logits != num_logits or device not in self.labels:labels = torch.arange(num_logits, device=device, dtype=torch.long)if self.world_size > 1 and self.local_loss:labels = labels + num_logits * self.rankif self.cache_labels:self.labels[device] = labelsself.prev_num_logits = num_logitselse:labels = self.labels[device]loss = (F.cross_entropy(logits_per_image, labels) +F.cross_entropy(logits_per_text, labels)) / 2# compute accuracywith torch.no_grad():pred = torch.argmax(logits_per_image, dim=-1)correct = pred.eq(labels).sum()acc = 100 * correct / logits_per_image.size(0)return {'loss': loss, 'clip_loss': loss, 'clip_acc': acc}

3 LaViLa

在LaViLa 中,使用LLMs作为监督来训练双编码器模型,其中LLMs作为视觉条件叙述者自动从视频片段生成文本描述。LaViLa使用了来自两个LLMs的监督:

  1. NARRATOR:一种视觉条件下的LLM,用叙述对现有和新剪辑进行伪标签,生成新的注释(\texttt{X}' ,\texttt{y}') ;
  2. REPHRASER:一种标准的LLM,对现有剪辑中的叙述进行释义,增强这些注释到(\texttt{X} ,\texttt{y}'') 。

如下图所示,NARRATOR生成对正在发生的行为的新描述,可能关注其他与之交互的对象;REPHRASER用于增强文本输入,例如改变人类叙述的词序,并替换常见的动词或名词,使注释更加多样化:

最后,在所有这些注释的组合上训练双编码器,即(\texttt{X} ,\texttt{y})\cup (\texttt{X}' ,\texttt{y}')\cup (\texttt{X} ,\texttt{y}'')

3.1 NARRATOR

传统LLMs被训练从头开始生成一系列文本标记(s_1,\cdots ,s_L) ,通过建模给定所有已看到的标记的下一个标记的概率:p(s_l|s_{<l}) 。NARRATOR将现有的LLMs重新用于条件化视觉输入,并在原始注释(\texttt{X} ,\texttt{y}) 上训练,然后在完整视频上产生密集的新注释(\texttt{X}' ,\texttt{y}') 。遵循语言模型中因子概率的公式,建模视觉条件化文本似然:

p_{\textup{NARRATOR}}(y'|x')=\prod _{l=1}^{L}p(s_l'|s'_{<l},x') .

如下图所示,NARRATOR紧密遵循标准LLMs的架构以视频帧为输入,通过视频编码器获得视觉嵌入,然后通过注意力池化获得。仅添加了几个额外的交叉注意力模块以提供视觉约束,文本解码器自回归地为这些新帧生成新的叙述:

这使得NARRATOR能够从预训练的权重初始化,从而使得用于训练NARRATOR(与视频片段相关的叙述)的数据规模远小于通常用于训练LLMs的大规模文本语料库。

此外,视频叙述的多样性较低且噪声较大,因为它们要么是由少数注释员收集的,要么是从语音自动转录的。由此,在多模态少样本自适应中使用“冻结-LM”方法。在冻结的预训练模型LLM中:

  • 在每个Transformer解码器层之前添加一个交叉注意力模块,以便文本输入可以关注视觉信息;
  • 然后,通过残差连接将交叉注意力输出与输入文本特征相加,并进入 Transformer 解码器层;
  • 每个交叉注意力模块包含一个交叉注意力层,该层将文本标记作为查询,将视觉嵌入作为键和值;
  • 之后是一个前馈网络(FFN);
  • 在交叉注意力和 FFN 的开始处应用层归一化;
  • 还添加了tanh门,初始值为零,使得新模型的输出在开始时与原始语言模型相同。

虽然任何视频模型的特征都适用于条件化,但为了方便,本文中采用上一节中的视频编码器\texttt{F} ,该编码器在真实数据(\texttt{X} ,\texttt{y}) 上进行了对比训练。然后使用全局池化之前的特征,以便LLM能够利用细粒度的时空信息。

在训练过程中,训练NARRATOR非所有或真实标注的子集(\texttt{X} ,\texttt{y}) 。对于每一对(x ,y) ,字幕损失是每个步骤中正确单词的负对数似然之和:

\pounds _{\textup{NARRATOR}}(x,y)=\sum _{l=1}^{L}\textup{log}\; p(s_l|s_{<l},x) .

在推理时,通过输入视觉输入x 加上一个特殊的句子起始标记<s>来查询NARRATOR ,然后递归地从分布中进行采样,即\tilde{s}_{l}\sim p(s|[<s>,\cdots ,\tilde{s}_{l-1}],x) 直到到达句尾标记</s> 。在每一步中,从包含绝大多数概率质量的词元子集中采样,这被称为原子核采样,核采样效果有两方面:

  • 一方面,它比基于最大似然的方法生成更多样化、开放式和类似人类的文本;
  • 另一方面,由于没有基于句子级似然的后处理采样,生成的文本可能包含无关或噪声信息。

为了解决这个问题,在相同的视觉输入上重复采样过程K 次,对比预训练目标对采样引起的噪声具有鲁棒性,最终性能得益于更多样化的叙述集。

为了对视频片段进行字幕采样:

  1. 首先,简单地重新为数据集\texttt{X} 中标记的现有片段添加字幕,从而扩展了注释;
  2. 此外,长视频通常叙述稀疏,这意味着所有标记片段的时间总和无法覆盖整个视频;
  3. 因此,使用NARRATOR来注释视频的剩余部分,通过伪字幕获得额外的注释;
  4. 在简单地假设视频是平稳过程的前提下,从视频中均匀地采样片段未标记的区间,剪辑时长等于所有真实剪辑的平均值,即\Delta =\frac{1}{N}\sum _{i=1}^{N}(e_i-t_i) ,当采样步长同样计算后;
  5. 最后,通过结合重新配对的和伪配对的叙述,称NARRATOR为生成的最终标注集(\texttt{X}' ,\texttt{y}') 。

最后进行后处理,来彻底的伪配对可能包含一些无信息的视觉片段,并生成无用的文本。因此,添加一个过滤过程以消除低质量片段及其相关描述。使用在真实配对片段上训练的基线双编码器模型\texttt{F} ,计算伪标签对的视觉和文本嵌入,并根据相似度分数进行过滤,即\textup{Filter}(f_v(x'_j)^{\top }\cdot f_t(y'_j)) ,其中\textup{Filter}(\cdot ) 可以是所有生成文本的前百分之几或阈值过滤。在本实验中,使用 0.5 的阈值。

代码如下:

class VCLM_HF(nn.Module):def __init__(self,# visionvision_width: int,vision_model: nn.Module,# texttext_width: int,text_decoder: nn.Module,num_img_queries=256,dim_head=64,heads=8,**kwargs,):super().__init__()self.vision_width = vision_widthself.visual = vision_modelself.text_width = text_widthself.text_decoder = text_decoderself.img_queries = nn.Parameter(torch.empty(num_img_queries, text_width))self.img_attn_pool = CrossAttention(dim=text_width, context_dim=vision_width,dim_head=dim_head, heads=heads,norm_context=True)self.img_attn_pool_norm = LayerNorm(text_width)self.initialize_parameters()def initialize_parameters(self):nn.init.normal_(self.img_queries, std=self.text_width ** -0.5)def encode_image(self, image, use_checkpoint=False):if isinstance(self.visual, VisionTransformer):# openai_model.VisionTransformer accepts (N, C, H, W) instead of (N, C, T, H, W)image = image.permute(0, 2, 1, 3, 4)  # BCTHW -> BTCHWbb, tt, _, _, _ = image.shapex = self.visual(image.reshape(-1, *image.shape[2:]), use_checkpoint=use_checkpoint, cls_at_last=False)  # NLDx = x.view(bb, tt, *x.shape[1:])x = x.permute(0, 3, 1, 2)elif isinstance(self.visual, SpaceTimeTransformer):image = image.permute(0, 2, 1, 3, 4).contiguous()  # BCTHW -> BTCHWbb, tt, _, _, _ = image.shapex = self.visual.forward_features(image, use_checkpoint=use_checkpoint, cls_at_last=False)  # NLDx = x.permute(0, 2, 1)else:x = self.visual(image, use_checkpoint=use_checkpoint, mean_at_last=False)if isinstance(x, list):assert len(x) == 1x = x[0]x = x.flatten(start_dim=2)  # BDTHW -> BD(THW)x = x.permute(0, 2, 1)      # BDN -> BNDimg_queries = repeat(self.img_queries, 'n d -> b n d', b=x.shape[0])img_queries = self.img_attn_pool(img_queries, x)img_queries = self.img_attn_pool_norm(img_queries)return img_queriesdef forward(self, image, text, mask=None, use_checkpoint=False, norm_embed=False):if use_checkpoint:self.text_decoder.gradient_checkpointing_enable()else:self.text_decoder.gradient_checkpointing_disable()text, labels = text[:, :-1], text[:, 1:]# mask = mask[:, :-1]image_tokens = self.encode_image(image, use_checkpoint=use_checkpoint)output_decoder = self.text_decoder(text.contiguous(), encoder_hidden_states=image_tokens)text_tokens_logits = output_decoder.logitstext_tokens_logits = rearrange(text_tokens_logits, 'b n c -> b c n')return {'text_tokens_logits': text_tokens_logits,'labels': labels}def generate(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,num_return_sequences=1, temperature=1.0, teacher_forcing=False, early_stopping=False):image_tokens = image_tokens.repeat_interleave(num_return_sequences, dim=0)device = image_tokens.devicegenerated_text_ids = torch.LongTensor([[tokenizer.bos_token_id]] * image_tokens.shape[0]).to(device)condition_text_ids = generated_text_ids.clone()logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=1)nlls, num_tokens = torch.zeros(image_tokens.shape[0]).to(device), torch.zeros(image_tokens.shape[0]).to(device)is_reach_eos = torch.zeros(image_tokens.shape[0]).bool().to(device)with torch.no_grad():for i in range(max_text_length - 1):output_decoder = self.text_decoder(condition_text_ids, encoder_hidden_states=image_tokens)decoded_token_logits = output_decoder.logitsnext_token_logits = decoded_token_logits[:, -1, :]if target is not None:nll = F.cross_entropy(next_token_logits, target[:, i+1], ignore_index=tokenizer.pad_token_id, reduction='none')nlls += nllnum_tokens += target[:, i+1].ne(tokenizer.pad_token_id)else:nll = torch.special.entr(F.softmax(next_token_logits, dim=1)).sum(dim=1)nlls += nll * (~is_reach_eos)num_tokens += (~is_reach_eos)# filtered_p = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p, device=device)next_token_logits = logits_warper(generated_text_ids, next_token_logits)filtered_p = F.softmax(next_token_logits, dim=-1)next_token = torch.multinomial(filtered_p, num_samples=1)is_reach_eos = is_reach_eos | (next_token[:, 0] == tokenizer.eos_token_id)if early_stopping and torch.all(is_reach_eos):breakif teacher_forcing:condition_text_ids = target[:, :i+2]else:condition_text_ids = torch.cat((generated_text_ids, next_token), dim=1)generated_text_ids = torch.cat((generated_text_ids, next_token), dim=1)if target is not None:return generated_text_ids, torch.exp(nlls / num_tokens)else:return generated_text_ids, torch.exp(nlls / num_tokens)def beam_sample(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,temperature=1.0, length_penalty=1.,num_beams=3, num_return_sequences=1, teacher_forcing=False, early_stopping=False):batch_size = image_tokens.shape[0]device = image_tokens.deviceinput_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long)input_ids = input_ids * tokenizer.bos_token_idexpanded_return_idx = (torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams * num_return_sequences).view(-1).to(device))input_ids = input_ids.index_select(0, expanded_return_idx)batch_beam_size, cur_len = input_ids.shapelogits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams)beam_scorer = BeamSearchScorer(batch_size=batch_size * num_return_sequences, num_beams=num_beams,device=device,length_penalty=length_penalty,)batch_size = len(beam_scorer._beam_hyps)num_beams = beam_scorer.num_beamsbeam_scores = torch.zeros((batch_size, num_beams)).to(device)beam_scores = beam_scores.view((batch_size * num_beams,))is_reach_eos = torch.zeros(batch_beam_size).bool().to(device)with torch.no_grad():for i in range(max_text_length - 1):output_decoder = self.text_decoder(input_ids,encoder_hidden_states=image_tokens.repeat_interleave(num_beams * num_return_sequences, dim=0))decoded_token_logits = output_decoder.logitsnext_token_logits = decoded_token_logits[:, -1, :]next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)# supposed to be the line below, but ignore temporarily# next_token_scores_processed = logits_processor(input_ids, next_token_scores)next_token_scores_processed = next_token_scoresnext_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)# supposed to be the line below, but do a simple top_k+top_p temporarilynext_token_scores = logits_warper(input_ids, next_token_scores)# next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device)vocab_size = next_token_scores.shape[-1]next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)probs = F.softmax(next_token_scores, dim=-1)next_tokens = torch.multinomial(probs, num_samples=2 * num_beams)next_token_scores = torch.gather(next_token_scores, -1, next_tokens)next_token_scores, _indices = torch.sort(next_token_scores, descending=True, dim=1)next_tokens = torch.gather(next_tokens, -1, _indices)next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")next_tokens = next_tokens % vocab_size# statelessbeam_outputs = beam_scorer.process(input_ids,next_token_scores,next_tokens,next_indices,pad_token_id=tokenizer.pad_token_id,eos_token_id=tokenizer.eos_token_id,)beam_scores = beam_outputs["next_beam_scores"]beam_next_tokens = beam_outputs["next_beam_tokens"]beam_idx = beam_outputs["next_beam_indices"]input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id)if beam_scorer.is_done or torch.all(is_reach_eos):breaksequence_outputs = beam_scorer.finalize(input_ids,beam_scores,next_tokens,next_indices,pad_token_id=tokenizer.pad_token_id,eos_token_id=tokenizer.eos_token_id,max_length=max_text_length,)sequences = sequence_outputs["sequences"]sequence_scores = sequence_outputs["sequence_scores"]return sequences, sequence_scoresdef group_beam_search(self, image_tokens, tokenizer, target=None, max_text_length=77, top_k=None, top_p=None,temperature=1.0, length_penalty=1.,num_beams=6, num_beam_groups=3,num_return_sequences=1, teacher_forcing=False, early_stopping=False):batch_size = image_tokens.shape[0]device = image_tokens.deviceinput_ids = torch.ones((batch_size, 1), device=device, dtype=torch.long)input_ids = input_ids * tokenizer.bos_token_idexpanded_return_idx = (torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1, num_beams).view(-1).to(device))input_ids = input_ids.index_select(0, expanded_return_idx)batch_beam_size, cur_len = input_ids.shapelogits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, typical_p=None, temperature=temperature, num_beams=num_beams)beam_scorer = BeamSearchScorer(batch_size=batch_size, num_beams=num_beams,num_beam_groups=num_beam_groups,num_beam_hyps_to_keep=num_return_sequences, device=device,length_penalty=length_penalty,)num_sub_beams = num_beams // num_beam_groupsbeam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)beam_scores[:, ::num_sub_beams] = 0beam_scores = beam_scores.view((batch_size * num_beams,))is_reach_eos = torch.zeros(batch_beam_size).bool().to(device)with torch.no_grad():# predicted tokens in cur_len stepcurrent_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)# indices which will form the beams in the next time stepreordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)for i in range(max_text_length - 1):output_decoder = self.text_decoder(input_ids,encoder_hidden_states=image_tokens.repeat_interleave(num_beams, dim=0))decoded_token_logits = output_decoder.logitsfor beam_group_idx in range(num_beam_groups):group_start_idx = beam_group_idx * num_sub_beamsgroup_end_idx = min(group_start_idx + num_sub_beams, num_beams)group_size = group_end_idx - group_start_idx# indices of beams of current group among all sentences in batchbatch_group_indices = []for batch_idx in range(batch_size):batch_group_indices.extend([batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)])group_input_ids = input_ids[batch_group_indices]# select outputs of beams of current group onlynext_token_logits = decoded_token_logits[batch_group_indices, -1, :]next_token_scores = F.log_softmax(next_token_logits, dim=-1)  # (batch_size * num_beams, vocab_size)vocab_size = next_token_scores.shape[-1]# supposed to be the line below, but ignore temporarily# next_token_scores_processed = logits_processor(input_ids, next_token_scores)next_token_scores_processed = next_token_scoresnext_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)next_token_scores = next_token_scores.expand_as(next_token_scores_processed)next_token_scores = logits_warper(input_ids, next_token_scores)# next_token_scores = top_k_top_p_filtering(next_token_scores, top_k=top_k, top_p=top_p, device=device)# reshape for beam searchnext_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True)next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")next_tokens = next_tokens % vocab_size# statelessbeam_outputs = beam_scorer.process(group_input_ids,next_token_scores,next_tokens,next_indices,pad_token_id=tokenizer.pad_token_id,eos_token_id=tokenizer.eos_token_id,beam_indices=None)beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]beam_next_tokens = beam_outputs["next_beam_tokens"]beam_idx = beam_outputs["next_beam_indices"]input_ids[batch_group_indices] = group_input_ids[beam_idx]group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)current_tokens[batch_group_indices] = group_input_ids[:, -1]reordering_indices[batch_group_indices] = (num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size))input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)is_reach_eos = is_reach_eos | (input_ids[:, -1] == tokenizer.eos_token_id)if beam_scorer.is_done or torch.all(is_reach_eos):breaksequence_outputs = beam_scorer.finalize(input_ids,beam_scores,next_tokens,next_indices,pad_token_id=tokenizer.pad_token_id,eos_token_id=tokenizer.eos_token_id,max_length=max_text_length,beam_indices=None,)sequences = sequence_outputs["sequences"]sequence_scores = sequence_outputs["sequence_scores"]return sequences, sequence_scoresdef _get_logits_warper(self, top_k=None, top_p=None, typical_p=None,temperature=None, num_beams=None, renormalize_logits=None,):top_k = top_k if top_k is not None else 0top_p = top_p if top_p is not None else 1.0typical_p = typical_p if typical_p is not None else 1.temperature = temperature if temperature is not None else 1.warpers = LogitsProcessorList()if temperature is not None and temperature != 1.0:warpers.append(TemperatureLogitsWarper(temperature))if top_k is not None and top_k != 0:warpers.append(TopKLogitsWarper(top_k=top_k, min_tokens_to_keep=(2 if num_beams > 1 else 1)))if top_p is not None and top_p < 1.0:warpers.append(TopPLogitsWarper(top_p=top_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))if typical_p is not None and typical_p < 1.0:warpers.append(TypicalLogitsWarper(mass=typical_p, min_tokens_to_keep=(2 if num_beams > 1 else 1)))# `LogitNormalization` should always be the last logit processor, when presentif renormalize_logits is True:warpers.append(LogitNormalization())return warpers

3.2 REPHRASER

NARRATOR生成数据比真实配对大几倍,为确保不会过度拟合伪标签数据,通过释义增加真实叙述的数量,特别是使用文本到文本的LLM,该模型建模条件文本似然:

p_{\textup{REPHRASER}}(y''|y)=\prod _{l=1}^{L}p(s_l''|s''_{<l},y)  .

如下图所示,REPHRASER输入叙述,通过文本编码器传递,并使用文本解码器自回归地生成重述的输出:

文本到文本模型通过编码器-解码器架构实现,例如T5,根据原始句子自动回归生成新句子。REPHRASER能够进行基本的操作,如替换同义词或改变词序,这作为自动数据增强的有效方式。生成的标注被称为(\texttt{X} ,\texttt{y}'') 。

3.3 双编码器训练

在每次迭代中,首先采样一个视频片段批次\mathfrak{B} 。它包括带有标记时间戳和叙述的片段子集\mathfrak{B}_l ,以及从无叙述的视频中随机采样的片段子集\mathfrak{B}_u 。对于片段x_i \in \mathfrak{B}_u ,通过查询NARRATORy'_i \sim p_{\textup{REPHRASER}}(y'|x) 获得伪字幕y'_i ,从而得到带有LLM生成的叙述的片段集\tilde{\mathfrak{B}}_u 。对于片段(x_i,y_i) \in \mathfrak{B}_l ,文本监督来自REPHRASER或NARRATOR,概率为0.5,称由此产生的对为\tilde{\mathfrak{B}}_l 。类似地,遵循CLIP,在批次中样本的相似度得分上使用对称交叉熵损失\tilde{\mathfrak{B}}_l\cup \tilde{\mathfrak{B}}_u 。

在实践中,提前运行REPHRASER和NARRATOR并缓存生成的视频-叙述对,以便在预训练期间没有计算开销,因此,在LaViLa中训练双编码器与训练标准双编码器对比模型一样快。

def VCLM_OPENAI_TIMESFORMER_LARGE_336PX_GPT2_XL(gated_xattn=False,freeze_lm_vclm=False,freeze_visual_vclm=False,freeze_visual_vclm_temporal=False,num_frames=4,timesformer_gated_xattn=False,**kwargs,
):vision_model = SpaceTimeTransformer(img_size=336, patch_size=14,embed_dim=1024, depth=24, num_heads=16,num_frames=num_frames,time_init='zeros',attention_style='frozen-in-time',ln_pre=True,act_layer=QuickGELU,is_tanh_gating=timesformer_gated_xattn,)clip_model, _ = load_openai_clip('ViT-L/14@336px', 'cpu')print("=> Loading CLIP (ViT-L/14@336px) weights")remapped_state_dict = remap_keys(clip_model.visual.state_dict(), transformer_layers=24)res = vision_model.load_state_dict(remapped_state_dict, strict=False)print(res)vision_model.head = nn.Identity()vision_model.pre_logits = nn.Identity()vision_model.fc = nn.Identity()gpt2 = GPT2LMHeadModel.from_pretrained("gpt2-xl",use_cache=False,)new_config = augment_gpt2_config(gpt2.config, cross_attn_freq=3, gated_xattn=gated_xattn)text_decoder = GatedGPT2LMHeadModel(new_config)for n, p in gpt2.named_parameters():rsetattr(text_decoder, n + '.data', p.data)if freeze_lm_vclm:print('Freeze the LM part of TextDecoder of VCLM')text_decoder.freeze_lm_weights()if freeze_visual_vclm:print('Freeze the spatial part of VideoEncoder of VCLM')vision_model.freeze_spatial_weights()if freeze_visual_vclm_temporal:print('Freeze the temporal part of VideoEncoder of VCLM')vision_model.freeze_temporal_weights()model = VCLM_HF(vision_width=1024,vision_model=vision_model,text_width=1600,text_decoder=text_decoder,num_img_queries=256,dim_head=64,heads=25,**kwargs,)return model

总结

在本文中,提出了LaViLa,这是一种通过使用LLMs自动叙述长视频的新方法,用于视频-语言表示学习,为视频理解提供了一种高效的数据增强范式。LaViLa利用LLM生成伪标签,减少对人工标注的依赖。LaViLa的双编码器架构,即视频编码器与文本编码器通过对比学习对齐特征,支持跨模态检索和生成任务。LaViLa在使用相同数量的人叙述视频的基线模型上取得了显著的改进,并在六个流行的基准任务上取得了新的最先进成果,这些任务涵盖了第一人称和第三人称视频理解基准。当增加更多训练叙述、使用更大的视觉骨干网络和使用更强的LLMs时,LaViLa也显示出积极的缩放行为和性能。


http://www.hkcw.cn/article/rSLdiJvleI.shtml

相关文章

外网访问内网海康威视监控视频的方案:WebRTC + Coturn 搭建

外网访问内网海康威视监控视频的方案&#xff1a;WebRTC Coturn 需求背景 在仓库中有海康威视的监控摄像头&#xff0c;内网中是可以直接访问到监控摄像的画面&#xff0c;由于项目的需求&#xff0c;需要在外网中也能看到监控画面。 实现这个功能的意义在于远程操控设备的…

基于PyQt5的UI界面开发——图像与视频的加载与显示

介绍 这里我们的主要目标是实现一个基于PyQt5和OpenCV的图像浏览和视频播放应用。用户可以选择本地的图像或视频文件夹&#xff0c;进行图像自动播放和图像切换以及视频播放和调用摄像头等操作&#xff0c;并且支持图像保存功能。项目的核心设计包括文件路径选择、图像或视频的…

浙江3名高中生深夜被困深山 成功获救未受伤

5月31日端午节晚上,三名高中生因降雨失温被困在浙江台州的大雷山。救援人员接到通知后迅速展开搜救行动,最终找到三人并提供了保温毯和雨衣等物资。所幸没有人员伤亡。网络视频显示,当晚天空下着大雨,民警、消防人员及救援队在集结搜寻。大雷山位于浙江省台州市中西部,是永…

西藏那曲双湖县发生3.6级地震 震源深度10公里

据中国地震台网正式测定,6月2日12时32分在西藏那曲市双湖县发生3.6级地震,震源深度10公里,震中位于北纬33.63度,东经89.36度。震中5公里范围内平均海拔约5344米。根据中国地震台网速报目录,震中周边200公里内近5年来共发生了70次3级以上地震,其中最大一次是2021年3月30日…

太原一路虎车在酒吧门口横冲直撞 警方已介入调查

6月2日凌晨4时许,有网友发帖称山西省太原市发生了一起越野车冲撞事件。当天上午,太原警方表示已接到报警并正在调查。根据网友发布的视频,这起事件发生在太原市小店区一商家门口。视频中一辆路虎越野车在前进和倒退时有冲撞物品的行为,路人纷纷避让,有人试图拉开路虎车门。…

苏州一医院医生被停工 误操作引发患者投诉

近日,苏州市立医院东区发生了一起B超检查项目的误操作事件。5月30日,患者王女士因身体不适前往医院就诊,医生开具了腹部B超检查单。然而,在实际操作中,超声科医生却误将其操作成阴道B超。这一错误在王女士察觉后才被发现。她询问项目名称时,实习医生仅背对回应“看错了”…

舟山多个海岛游客“被困” 当地回应 天气影响航班调整

6月1日和6月2日,一些在舟山海岛的游客发帖称,由于没有航班离岛,自己被困在海岛上。这些发帖者包括东极岛、枸杞岛等海岛的游客。据网友发布的图片显示,6月1日,东极海运发布提示:因受海面风浪影响,船舶无法航行,当天庙子湖至沈家门9:20、10:00航班停航,已购买该时段船票…

广铁计划加开列车321列 应对返程高峰

6月2日,端午小长假最后一天,广铁迎来返程客流高峰,旅客纷纷踏上归途。当天预计发送旅客237.3万人次,较去年增长9.3%。整个假期期间,旅客运输总体平稳有序。为满足出行需求,广铁集团优化调整运力,通过加开图外列车、动车组重联、增加夜间高铁等方式提升运力。6月2日计划加…

热刺决定解雇波斯特科格鲁,新主帅人选曝光 弗兰克成热门接班人

北京时间6月1日,热刺决定解雇波斯特科格鲁。托马斯-弗兰克成为热刺新教练的接班人选之一,马尔科-席尔瓦也在考虑名单上。此前法国媒体报道称,托特纳姆热刺本周与托马斯-弗兰克进行了直接接触,双方讨论了夏季转会计划和一些转会目标。责任编辑:zhangxiaohua

北京铁路抵京旅客75.8万 端午假期返程高峰

6月2日是端午假期的最后一天,中国铁路北京局预计发送旅客137万人次。其中,北京地区预计发送51.4万人次、到达75.8万人次。中午11点半,北京站迎来了一波出站客流高峰,旅客出站后迅速前往地铁站和出租车调度站。尽管短时间内出现客流高峰,但因地铁进站闸机全面开启,容纳能力…

章子怡晒照祝儿女节日快乐 陪伴是最好的礼物

6月1日,章子怡在社交平台上晒出女儿和儿子的照片,祝他们儿童节快乐。她写道:“陪伴孩子们的每一天都是上天的恩赐……陪伴就是给孩子们最好的礼物。节日快乐我的孩子,愿你们的童年如彩虹般绚烂,健康快乐地成长。”网友们纷纷留言表示,醒醒越来越像妈妈了。2023年10月23日…

郑钦文闯进法网8强 拿下359万奖金 鏖战三盘胜出

北京时间6月1日晚,法国网球公开赛1/8决赛中,郑钦文以7-6(5)、1-6、6-3战胜萨姆索诺娃,首次闯入法网8强,并获得430积分和44万欧元奖金。首盘比赛中,双方表现平稳,比分交替上升。进入中段后,两人互相破发,比赛变得越来越激烈,最终进入抢七局。在抢七局中,郑钦文表现出…

为中国高速列车发展护航 中南大学团队的创新与坚守

高铁已成为许多旅客出行的首选,但鲜为人知的是,高速列车流线型外形及碰撞吸能结构的设计背后,是中南大学轨道交通空气动力与碰撞安全技术创新团队的辛勤付出。我国首个准高速列车项目设立之初,长沙铁道学院(现中南大学)的几名青年教师敏锐地意识到空气动力学在高速列车发…

鸿蒙仓颉语言开发教程:自定义弹窗

假期第一天&#xff0c;祝大家端午节快乐。昨天观看了时代旗舰尊界S800的发布&#xff0c;不得不感慨这车真好啊&#xff5e; 放假闲来无事&#xff0c;继续跟大家分享仓颉语言的开发教程&#xff0c;今天介绍一下自定义弹窗。 仓颉语言中的自定义弹窗和ArkTs类似&#xff0c…

shp转3d tiles在cesium渲染楼宇白膜

shp文件一般做gis的人都知道它是干嘛的&#xff0c;它是一种地理信息系统矢量数据格式&#xff0c;主要用于存储地理空间数据。但是在cesium中&#xff0c;通过Cesium3DTileset渲染白膜只能渲染3d tiles文件格式。所以我们需要工具去将shp文件转换成3d tiles格式。 我是使用的…

郑钦文请球童用帽子将蜜蜂送离场地 法网8强之路

北京时间6月1日晚,法国网球公开赛1/8决赛中,郑钦文以7-6(5)、1-6、6-3战胜萨姆索诺娃,首次闯入法网8强。她还获得了430积分和44万欧元奖金(约合人民币359万元)。首盘比赛中,双方开局平稳,比分交替上升。进入中段后,两人互相破发,比赛变得越来越激烈,最终进入抢七局…

俄版珍珠港?俄军事博主呼吁报复 乌无人机袭击引发紧张局势

就在俄乌定于6月2日举行的第二轮直接谈判前夕,俄罗斯境内发生了一系列袭击事件。当地时间6月1日,俄罗斯境内有五个空军基地遭遇大规模无人机袭击,乌克兰安全局宣称对此负责。这是乌军自俄乌冲突爆发以来对俄领土发动的最具渗透性的袭击之一。俄罗斯国防部认定这是一次“恐怖…

HarmonyOS鸿蒙开发,Text组件作为容器使用(ImageSpan/Span)快速掌握

Text作为容器使用的时候&#xff0c;里面可以使用ImageSpan存放图片&#xff0c;Span用来存放文字 文本显示 (Text/Span)-使用文本-UI开发 (ArkTS声明式开发范式)-ArkUI&#xff08;方舟UI框架&#xff09;-应用框架 - 华为HarmonyOS开发者 例如&#xff1a;&#xff08;提供给…

鸿蒙HarmonyOS 5.0开发实战:长列表滑动到指定列表项动效实现案例

往期鸿蒙5.0全套实战文章必看&#xff1a;&#xff08;文中附带鸿蒙5.0全栈学习资料&#xff09; 鸿蒙开发核心知识点&#xff0c;看这篇文章就够了 最新版&#xff01;鸿蒙HarmonyOS Next应用开发实战学习路线 鸿蒙HarmonyOS NEXT开发技术最全学习路线指南 鸿蒙应用开发实战…

【HarmonyOS 5】App Linking 应用间跳转详解

目录 什么是 App Linking 使用场景 工作原理 如何开发 1.开通 App Linking 2.确定域名 3.服务端部署 applinking.json 文件 4.AGC绑定域名 5.项目配置 6.组装聚合链接 7.解析聚合链接中的参数 其他 如何获取应用ID 如何在应用未安装时点击链接跳转至应用市场 什…