|
|
|
@ -170,8 +170,8 @@ class DoubleStreamBlock(nn.Module):
|
|
|
|
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
|
|
|
|
|
|
|
|
|
# calculate the img bloks
|
|
|
|
|
img += img_mod1.gate * self.img_attn.proj(img_attn)
|
|
|
|
|
img += img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
|
|
|
|
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
|
|
|
|
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
|
|
|
|
|
|
|
|
|
# calculate the txt bloks
|
|
|
|
|
txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
|
|
|
|