@@ -14,7 +14,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
 
 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
-    if comfy.model_management.is_device_mps(pos.device):
+    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
         device = torch.device("cpu")
     else:
         device = pos.device
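
The rest of the rope() body is outside this hunk. The CPU fallback exists because the rotary embedding table is normally built in float64, which the MPS backend does not support, and this change extends the same workaround to Intel XPU (presumably for a similar precision or support issue). The sketch below is a minimal, self-contained illustration of how such a fallback plays out; the names and body are assumptions for illustration, not ComfyUI's exact implementation.

```python
import torch
from torch import Tensor


def rope_sketch(pos: Tensor, dim: int, theta: int, force_cpu: bool) -> Tensor:
    """Illustrative rotary-embedding table with an optional CPU fallback."""
    assert dim % 2 == 0
    # Mirror the diff: route the float64 computation to CPU when the
    # tensor's own device can't handle it (e.g. MPS, and per this change XPU).
    device = torch.device("cpu") if force_cpu else pos.device

    # Per-pair rotation frequencies: omega_i = theta^(-2i/dim)
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=device) / dim
    omega = 1.0 / (theta ** scale)                                  # (dim/2,)

    # Outer product of positions and frequencies gives the rotation angles
    angles = pos.to(dtype=torch.float64, device=device).unsqueeze(-1) * omega  # (..., dim/2)

    # Build 2x2 rotation matrices [[cos, -sin], [sin, cos]] per (position, frequency)
    out = torch.stack(
        [torch.cos(angles), -torch.sin(angles), torch.sin(angles), torch.cos(angles)],
        dim=-1,
    ).reshape(*angles.shape, 2, 2)

    # Hand the result back on the caller's device in float32, so downstream
    # attention code never sees the device switch.
    return out.to(dtype=torch.float32, device=pos.device)


# Example: a 16-position table for 64-dim heads, forcing the CPU path.
pe = rope_sketch(torch.arange(16)[None], dim=64, theta=10000, force_cpu=True)
print(pe.shape)  # torch.Size([1, 16, 32, 2, 2])
```

The angles are computed in float64 on the fallback device and the final table is returned in float32 on pos.device, which is why adding is_intel_xpu() to the condition is enough to make the function usable on that backend without touching its callers.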