import math

import torch
import torch.nn as nn

# NOTE: Blip2Base is provided by the surrounding MiniGPT-4 codebase; the exact
# import path below is assumed.
from minigpt4.models.blip2 import Blip2Base


class MiniGPT4(Blip2Base):
    def __init__(self, ...):  # remaining constructor arguments, including mm_projector_type, elided in the original excerpt
        ...  # vision encoder, Q-Former, and LLaMA initialization elided
        self.mm_projector_type = mm_projector_type
        if mm_projector_type == "vib":
            # Variational information bottleneck (VIB) projector: one linear head for
            # the mean and one for the (pre-softplus) std of the posterior q(z | image).
            self.llama_proj = nn.Linear(
                self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
            )
            self.llama_proj_std = nn.Linear(
                self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
            )
            # Learnable mean and (pre-softplus) std of the prior p(z).
            self.mu_p = nn.Parameter(torch.randn(self.llama_model.config.hidden_size))
            self.std_p = nn.Parameter(torch.randn(self.llama_model.config.hidden_size))
        elif mm_projector_type == "linear":
            # Plain linear projection from Q-Former features to the LLaMA embedding space.
            self.llama_proj = nn.Linear(
                self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
            )
    def universal_sentence_embedding(self, sentences, mask, sqrt=True):
        '''
        Masked sum pooling over the sequence dimension, normalized by the
        (square root of the) number of unmasked positions.

        :param sentences: [batch_size, seq_len, hidden_size]
        :param mask: [batch_size, seq_len]
        :param sqrt: if True, divide by sqrt(length) instead of length
        :return: [batch_size, hidden_size]
        '''
        sentence_sums = torch.bmm(
            sentences.permute(0, 2, 1), mask.float().unsqueeze(-1)
        ).squeeze(-1)
        divisor = mask.sum(dim=1).view(-1, 1).float()
        if sqrt:
            divisor = divisor.sqrt()
        sentence_sums /= divisor
        return sentence_sums
    def _emb_entropy(self, inputs_llama, llama_emb_weights):
        # Entropy of the pooled visual prompt's similarity distribution over the
        # LLaMA vocabulary embeddings (computed without gradients).
        batch_size = inputs_llama.size(0)
        prompt_len = inputs_llama.size(1)
        with torch.no_grad():
            _mask = torch.ones([batch_size, prompt_len]).to(inputs_llama.device)
            _pooling_states = self.universal_sentence_embedding(inputs_llama, mask=_mask)
            _prob = torch.matmul(
                _pooling_states, llama_emb_weights.transpose(0, 1)
            ).softmax(-1)
            _entropy = -(_prob * (_prob + 1e-8).log()).sum(-1)
        return _entropy, _prob
    def estimate(self, emb, emb2mu, emb2std):
        """Estimates mu and std from the given input embeddings."""
        mean = emb2mu(emb)
        std = torch.nn.functional.softplus(emb2std(emb))
        return mean, std
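    # The next method computes the closed-form KL divergence between two
    # k-dimensional diagonal Gaussians q = N(mu_q, diag(std_q^2)) and
    # p = N(mu_p, diag(std_p^2)):
    #
    #   KL(q || p) = 0.5 * [ sum_i std_q_i^2 / std_p_i^2
    #                        + sum_i (mu_p_i - mu_q_i)^2 / std_p_i^2
    #                        - k + sum_i log std_p_i^2 - sum_i log std_q_i^2 ]
    #
    # The std terms are clamped at 1e-8 before the log for numerical stability.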
    def kl_div(self, mu_q, std_q, mu_p, std_p):
        k = mu_q.size(-1)
        mu_diff = mu_p - mu_q
        mu_diff_sq = torch.mul(mu_diff, mu_diff)
        logdet_std_q = torch.sum(2 * torch.log(torch.clamp(std_q, min=1e-8)), dim=-1)
        logdet_std_p = torch.sum(2 * torch.log(torch.clamp(std_p, min=1e-8)), dim=-1)
        fs = torch.sum(torch.div(std_q ** 2, std_p ** 2), dim=-1) + torch.sum(
            torch.div(mu_diff_sq, std_p ** 2), dim=-1
        )
        kl_divergence = (fs - k + logdet_std_p - logdet_std_q) * 0.5
        return kl_divergence
    def reparameterize(self, mu, std, sample_size):
        # Reparameterization trick: draw `sample_size` samples per example, giving a
        # tensor of shape [sample_size, batch_size, prompt_len, hidden_size].
        batch_size = mu.size(0)
        z = torch.randn(sample_size, batch_size, mu.size(1), mu.size(2)).to(mu.device)
        return mu + std * z
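    # vib_layer wires the pieces above together: it estimates the posterior mean/std
    # from the Q-Former output, pools them over the prompt, penalizes the KL against
    # the learnable prior (self.mu_p, self.std_p), and, during training, feeds sampled
    # latents through the model instead of the deterministic mean.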
    def vib_layer(self, query_output_state, is_training, ib_sample_size):
        batch_size = query_output_state.size(0)
        prompt_len = query_output_state.size(1)
        mu, std = self.estimate(query_output_state, self.llama_proj, self.llama_proj_std)

        _mask = torch.ones([batch_size, prompt_len], requires_grad=False).to(mu.device)
        mu_pooling = self.universal_sentence_embedding(mu, _mask)
        std_pooling = self.universal_sentence_embedding(std, _mask)

        # KL between the pooled posterior and the learnable prior p(z).
        mu_p = self.mu_p.view(1, -1).expand(batch_size, -1)
        std_p = torch.nn.functional.softplus(self.std_p.view(1, -1).expand(batch_size, -1))
        kl_loss = self.kl_div(mu_pooling, std_pooling, mu_p, std_p)

        if is_training:
            # Sample ib_sample_size latent codes via the reparameterization trick.
            z = self.reparameterize(mu, std, sample_size=ib_sample_size)
            sampled_logits = self.get_logits(z)  # get_logits is not shown in this excerpt
            logits = sampled_logits
        else:
            # At inference time, use the deterministic mean.
            logits = mu

        return logits, kl_loss
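    # Self-adaptive weight for the KL term: the entropy is normalized by log|V|,
    # its maximum over a vocabulary of size |V|, so the ratio lies in (0, 1] and
    # alpha = -log(entropy / log|V|) is non-negative, growing as the entropy drops.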
    def _alpha_fn(self, _entropy, v_size):
        return -(_entropy / math.log(v_size)).log()
    def encode_img(self, image, is_training, beta=1.0, self_adaptive=False, ib_sample_size=3):
        ...  # vision encoder and Q-Former forward pass (producing query_output) elided in the original excerpt
        if self.mm_projector_type == "vib":
            inputs_llama, kl_loss = self.vib_layer(
                query_output.last_hidden_state, is_training, ib_sample_size
            )
            llama_emb_weight = self.llama_model.model.embed_tokens.weight
            _entropy, _prob = self._emb_entropy(inputs_llama, llama_emb_weight)

            if self_adaptive and is_training:
                # Average alpha over the ib_sample_size samples drawn per example,
                # then use it to weight each example's KL term.
                _alpha = self._alpha_fn(_entropy, v_size=llama_emb_weight.size(0))
                sample_size = _alpha.size(0) // kl_loss.size(0)
                _alpha = _alpha.reshape(sample_size, -1).mean(0)
                kl_loss = beta * (kl_loss * _alpha).mean()
            else:
                kl_loss = beta * kl_loss.mean()
        ...  # remainder of encode_img elided in the original excerpt
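# ---------------------------------------------------------------------------
# Hypothetical standalone sketch (not part of the original file): a minimal
# sanity check of the diagonal-Gaussian KL formula used by MiniGPT4.kl_div,
# cross-checked against torch.distributions on toy tensors.
# ---------------------------------------------------------------------------
def _kl_diag_gaussians(mu_q, std_q, mu_p, std_p):
    """Same closed form as MiniGPT4.kl_div, written as a free function."""
    k = mu_q.size(-1)
    fs = ((std_q ** 2) / (std_p ** 2)).sum(-1) + (((mu_p - mu_q) ** 2) / (std_p ** 2)).sum(-1)
    logdet = (2 * std_p.clamp(min=1e-8).log()).sum(-1) - (2 * std_q.clamp(min=1e-8).log()).sum(-1)
    return 0.5 * (fs - k + logdet)


if __name__ == "__main__":
    torch.manual_seed(0)
    mu_q, mu_p = torch.randn(4, 8), torch.randn(4, 8)
    std_q, std_p = torch.rand(4, 8) + 0.1, torch.rand(4, 8) + 0.1
    ours = _kl_diag_gaussians(mu_q, std_q, mu_p, std_p)
    ref = torch.distributions.kl_divergence(
        torch.distributions.Normal(mu_q, std_q),
        torch.distributions.Normal(mu_p, std_p),
    ).sum(-1)
    assert torch.allclose(ours, ref, atol=1e-5)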