|
|
|
@ -319,13 +319,15 @@ class HunYuanDiT(nn.Module):
|
|
|
|
|
text_states_mask = text_embedding_mask.bool() # 2,77
|
|
|
|
|
text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256
|
|
|
|
|
b_t5, l_t5, c_t5 = text_states_t5.shape
|
|
|
|
|
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5))
|
|
|
|
|
text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024
|
|
|
|
|
text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)).view(b_t5, l_t5, -1)
|
|
|
|
|
|
|
|
|
|
clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
|
|
|
|
|
padding = self.text_embedding_padding.to(text_states)
|
|
|
|
|
|
|
|
|
|
clip_t5_mask = clip_t5_mask
|
|
|
|
|
text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states))
|
|
|
|
|
text_states[:,-self.text_len:] = torch.where(text_states_mask[:,-self.text_len:].unsqueeze(2), text_states[:,-self.text_len:], padding[:self.text_len])
|
|
|
|
|
text_states_t5[:,-self.text_len_t5:] = torch.where(text_states_t5_mask[:,-self.text_len_t5:].unsqueeze(2), text_states_t5[:,-self.text_len_t5:], padding[self.text_len:])
|
|
|
|
|
|
|
|
|
|
text_states = torch.cat([text_states, text_states_t5], dim=1) # 2,205,1024
|
|
|
|
|
# clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1)
|
|
|
|
|
|
|
|
|
|
_, _, oh, ow = x.shape
|
|
|
|
|
th, tw = (oh + (self.patch_size // 2)) // self.patch_size, (ow + (self.patch_size // 2)) // self.patch_size
|
|
|
|
|