Compare commits


10 Commits

Author SHA1 Message Date
comfyanonymous 935ae153e1 Cleanup. 6 months ago
Chenlei Hu e91662e784
Get logs endpoint & system_stats additions (#4690)
* Add route for getting output logs

* Include ComfyUI version

* Move to own function

* Changed to memory logger

* Unify logger setup logic

* Fix get version git fallback

---------

Co-authored-by: pythongosssss <125205205+pythongosssss@users.noreply.github.com>
6 months ago
comfyanonymous 63fafaef45 Fix potential issue with hydit controlnets. 6 months ago
Alex "mcmonkey" Goodwin ec28cd9136
swap legacy sdv15 link (#4682)
* swap legacy sdv15 link

* swap v15 ckpt examples to safetensors

* link the fp16 copy of the model by default
6 months ago
comfyanonymous 6eb5d64522 Fix glora lowvram issue. 6 months ago
comfyanonymous 10a79e9898 Implement model part of flux union controlnet. 6 months ago
comfyanonymous ea3f39bd69 InstantX depth flux controlnet. 6 months ago
comfyanonymous b33cd61070 InstantX canny controlnet. 6 months ago
Dr.Lt.Data 34eda0f853
fix: remove redundant useless loop (#4656)
fix: potential error of undefined variable

https://github.com/comfyanonymous/ComfyUI/discussions/4650
6 months ago
comfyanonymous d31e226650 Unify RMSNorm code. 6 months ago

@@ -14,7 +14,7 @@ run_cpu.bat
 IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
-You can download the stable diffusion 1.5 one from: https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt
+You can download the stable diffusion 1.5 one from: https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/blob/main/v1-5-pruned-emaonly-fp16.safetensors

 RECOMMENDED WAY TO UPDATE:

@@ -2,6 +2,7 @@ from aiohttp import web
 from typing import Optional
 from folder_paths import models_dir, user_directory, output_directory
 from api_server.services.file_service import FileService
+import app.logger

 class InternalRoutes:
     '''
@@ -31,6 +32,9 @@ class InternalRoutes:
             except Exception as e:
                 return web.json_response({"error": str(e)}, status=500)

+        @self.routes.get('/logs')
+        async def get_logs(request):
+            return web.json_response(app.logger.get_logs())

     def get_app(self):
         if self._app is None:

@@ -0,0 +1,31 @@
+import logging
+from logging.handlers import MemoryHandler
+from collections import deque
+
+logs = None
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+
+def get_logs():
+    return "\n".join([formatter.format(x) for x in logs])
+
+def setup_logger(verbose: bool = False, capacity: int = 300):
+    global logs
+    if logs:
+        return
+
+    # Setup default global logger
+    logger = logging.getLogger()
+    logger.setLevel(logging.DEBUG if verbose else logging.INFO)
+
+    stream_handler = logging.StreamHandler()
+    stream_handler.setFormatter(logging.Formatter("%(message)s"))
+    logger.addHandler(stream_handler)
+
+    # Create a memory handler with a deque as its buffer
+    logs = deque(maxlen=capacity)
+    memory_handler = MemoryHandler(capacity, flushLevel=logging.INFO)
+    memory_handler.buffer = logs
+    memory_handler.setFormatter(formatter)
+    logger.addHandler(memory_handler)
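
For reference, the new app/logger.py keeps records in a bounded deque: the MemoryHandler is never given a flush target, so flushing is a no-op and records simply accumulate until the deque evicts the oldest. A minimal standalone sketch of the same pattern (stdlib only; names mirror the file above):

import logging
from collections import deque
from logging.handlers import MemoryHandler

logs = deque(maxlen=300)
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

logger = logging.getLogger()
logger.setLevel(logging.INFO)
memory_handler = MemoryHandler(300, flushLevel=logging.INFO)
memory_handler.buffer = logs  # swap the handler's list buffer for the bounded deque
memory_handler.setFormatter(formatter)
logger.addHandler(memory_handler)

logging.info("hello from the memory logger")
print("\n".join(formatter.format(r) for r in logs))  # what get_logs() returns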

@@ -179,10 +179,3 @@ if args.windows_standalone_build:
 if args.disable_auto_launch:
     args.auto_launch = False
-import logging
-logging_level = logging.INFO
-if args.verbose:
-    logging_level = logging.DEBUG
-logging.basicConfig(format="%(message)s", level=logging_level)

@@ -34,7 +34,7 @@ import comfy.t2i_adapter.adapter
 import comfy.ldm.cascade.controlnet
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
-import comfy.ldm.flux.controlnet_xlabs
+import comfy.ldm.flux.controlnet

 def broadcast_image_to(tensor, target_batch_size, batched_number):
@@ -148,7 +148,7 @@ class ControlBase:
                 elif self.strength_type == StrengthType.LINEAR_UP:
                     x *= (self.strength ** float(len(control_output) - i))

-                if x.dtype != output_dtype:
+                if output_dtype is not None and x.dtype != output_dtype:
                     x = x.to(output_dtype)

                 out[key].append(x)
@@ -206,7 +206,6 @@ class ControlNet(ControlBase):
         if self.manual_cast_dtype is not None:
             dtype = self.manual_cast_dtype

-        output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * self.compression_ratio != self.cond_hint.shape[2] or x_noisy.shape[3] * self.compression_ratio != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
@@ -236,7 +235,7 @@ class ControlNet(ControlBase):
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)

         control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.to(dtype), context=context.to(dtype), **extra)
-        return self.control_merge(control, control_prev, output_dtype)
+        return self.control_merge(control, control_prev, output_dtype=None)

     def copy(self):
         c = ControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
@@ -433,12 +432,30 @@ def load_controlnet_hunyuandit(controlnet_data):
 def load_controlnet_flux_xlabs(sd):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd)
-    control_model = comfy.ldm.flux.controlnet_xlabs.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     extra_conds = ['y', 'guidance']
     control = ControlNet(control_model, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control

+def load_controlnet_flux_instantx(sd):
+    new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
+    model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd)
+    for k in sd:
+        new_sd[k] = sd[k]
+
+    num_union_modes = 0
+    union_cnet = "controlnet_mode_embedder.weight"
+    if union_cnet in new_sd:
+        num_union_modes = new_sd[union_cnet].shape[0]
+
+    control_model = comfy.ldm.flux.controlnet.ControlNetFlux(latent_input=True, num_union_modes=num_union_modes, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_model = controlnet_load_state_dict(control_model, new_sd)
+
+    latent_format = comfy.latent_formats.Flux()
+    extra_conds = ['y', 'guidance']
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    return control
+
 def load_controlnet(ckpt_path, model=None):
     controlnet_data = comfy.utils.load_torch_file(ckpt_path, safe_load=True)
@@ -504,8 +521,10 @@ def load_controlnet(ckpt_path, model=None):
     elif "controlnet_blocks.0.weight" in controlnet_data: #SD3 diffusers format
         if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
             return load_controlnet_flux_xlabs(controlnet_data)
-        else:
+        elif "pos_embed_input.proj.weight" in controlnet_data:
             return load_controlnet_mmdit(controlnet_data)
+        elif "controlnet_x_embedder.weight" in controlnet_data:
+            return load_controlnet_flux_instantx(controlnet_data)

     pth_key = 'control_model.zero_convs.0.0.weight'
     pth = False
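
The new branches above dispatch on tell-tale state-dict keys. A toy illustration of the same probing (the helper name and dummy dict are illustrative, not part of the codebase; only the key names come from the diff):

def detect_controlnet_family(sd):
    # each key below is unique to one checkpoint family
    if "double_blocks.0.img_attn.norm.key_norm.scale" in sd:
        return "flux-xlabs"       # handled by load_controlnet_flux_xlabs
    if "pos_embed_input.proj.weight" in sd:
        return "sd3-mmdit"        # handled by load_controlnet_mmdit
    if "controlnet_x_embedder.weight" in sd:
        return "flux-instantx"    # handled by load_controlnet_flux_instantx
    return "unknown"

print(detect_controlnet_family({"controlnet_x_embedder.weight": 0}))  # flux-instantx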

@@ -1,4 +1,5 @@
 import torch
+import comfy.ops

 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
@@ -6,3 +7,15 @@ def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
     pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
     pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
     return torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), mode=padding_mode)
+
+try:
+    rms_norm_torch = torch.nn.functional.rms_norm
+except:
+    rms_norm_torch = None
+
+def rms_norm(x, weight, eps=1e-6):
+    if rms_norm_torch is not None:
+        return rms_norm_torch(x, weight.shape, weight=comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
+    else:
+        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
+        return (x * rrms) * comfy.ops.cast_to(weight, dtype=x.dtype, device=x.device)
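
The try/except exists because torch.nn.functional.rms_norm is only available in recent PyTorch (roughly 2.4 and later); on older versions the manual path runs instead. A quick sanity check that the two paths agree (a sketch with arbitrary shapes):

import torch

def rms_norm_manual(x, weight, eps=1e-6):
    # same math as the fallback branch above
    rrms = torch.rsqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + eps)
    return (x * rrms) * weight

x = torch.randn(2, 8, 16)
w = torch.randn(16)
out = rms_norm_manual(x, w)
if hasattr(torch.nn.functional, "rms_norm"):
    fused = torch.nn.functional.rms_norm(x, w.shape, weight=w, eps=1e-6)
    print(torch.allclose(out, fused, atol=1e-5))  # True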

@@ -0,0 +1,151 @@
+#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
+import torch
+import math
+from torch import Tensor, nn
+from einops import rearrange, repeat
+
+from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
+                     MLPEmbedder, SingleStreamBlock,
+                     timestep_embedding)
+
+from .model import Flux
+import comfy.ldm.common_dit
+
+class ControlNetFlux(Flux):
+    def __init__(self, latent_input=False, num_union_modes=0, image_model=None, dtype=None, device=None, operations=None, **kwargs):
+        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
+
+        self.main_model_double = 19
+        self.main_model_single = 38
+
+        # add ControlNet blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth):
+            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+            self.controlnet_blocks.append(controlnet_block)
+
+        self.controlnet_single_blocks = nn.ModuleList([])
+        for _ in range(self.params.depth_single_blocks):
+            self.controlnet_single_blocks.append(operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device))
+
+        self.num_union_modes = num_union_modes
+        self.controlnet_mode_embedder = None
+        if self.num_union_modes > 0:
+            self.controlnet_mode_embedder = operations.Embedding(self.num_union_modes, self.hidden_size, dtype=dtype, device=device)
+
+        self.gradient_checkpointing = False
+        self.latent_input = latent_input
+        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
+        if not self.latent_input:
+            self.input_hint_block = nn.Sequential(
+                operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
+                nn.SiLU(),
+                operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
+            )
+
+    def forward_orig(
+        self,
+        img: Tensor,
+        img_ids: Tensor,
+        controlnet_cond: Tensor,
+        txt: Tensor,
+        txt_ids: Tensor,
+        timesteps: Tensor,
+        y: Tensor,
+        guidance: Tensor = None,
+        control_type: Tensor = None,
+    ) -> Tensor:
+        if img.ndim != 3 or txt.ndim != 3:
+            raise ValueError("Input img and txt tensors must have 3 dimensions.")
+
+        # running on sequences img
+        img = self.img_in(img)
+        if not self.latent_input:
+            controlnet_cond = self.input_hint_block(controlnet_cond)
+            controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
+        controlnet_cond = self.pos_embed_input(controlnet_cond)
+        img = img + controlnet_cond
+        vec = self.time_in(timestep_embedding(timesteps, 256))
+        if self.params.guidance_embed:
+            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
+        vec = vec + self.vector_in(y)
+        txt = self.txt_in(txt)
+
+        if self.controlnet_mode_embedder is not None and len(control_type) > 0:
+            control_cond = self.controlnet_mode_embedder(torch.tensor(control_type, device=img.device), out_dtype=img.dtype).unsqueeze(0).repeat((txt.shape[0], 1, 1))
+            txt = torch.cat([control_cond, txt], dim=1)
+            txt_ids = torch.cat([txt_ids[:,:1], txt_ids], dim=1)
+
+        ids = torch.cat((txt_ids, img_ids), dim=1)
+        pe = self.pe_embedder(ids)
+
+        controlnet_double = ()
+
+        for i in range(len(self.double_blocks)):
+            img, txt = self.double_blocks[i](img=img, txt=txt, vec=vec, pe=pe)
+            controlnet_double = controlnet_double + (self.controlnet_blocks[i](img),)
+
+        img = torch.cat((txt, img), 1)
+
+        controlnet_single = ()
+
+        for i in range(len(self.single_blocks)):
+            img = self.single_blocks[i](img, vec=vec, pe=pe)
+            controlnet_single = controlnet_single + (self.controlnet_single_blocks[i](img[:, txt.shape[1] :, ...]),)
+
+        repeat = math.ceil(self.main_model_double / len(controlnet_double))
+        if self.latent_input:
+            out_input = ()
+            for x in controlnet_double:
+                out_input += (x,) * repeat
+        else:
+            out_input = (controlnet_double * repeat)
+
+        out = {"input": out_input[:self.main_model_double]}
+        if len(controlnet_single) > 0:
+            repeat = math.ceil(self.main_model_single / len(controlnet_single))
+            out_output = ()
+            if self.latent_input:
+                for x in controlnet_single:
+                    out_output += (x,) * repeat
+            else:
+                out_output = (controlnet_single * repeat)
+            out["output"] = out_output[:self.main_model_single]
+        return out
+
+    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
+        patch_size = 2
+        if self.latent_input:
+            hint = comfy.ldm.common_dit.pad_to_patch_size(hint, (patch_size, patch_size))
+            hint = rearrange(hint, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+        else:
+            hint = hint * 2.0 - 1.0
+
+        bs, c, h, w = x.shape
+        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+
+        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
+
+        h_len = ((h + (patch_size // 2)) // patch_size)
+        w_len = ((w + (patch_size // 2)) // patch_size)
+        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
+        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
+        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
+        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
+
+        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
+        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance, control_type=kwargs.get("control_type", []))
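
The union ("mode") support above lets one controlnet serve several control types: each mode index selects a learned embedding that is prepended to the text tokens, with txt_ids padded to match. A toy sketch of just that prepend step, with made-up dimensions and an illustrative mode index:

import torch

num_union_modes, hidden_size, batch, txt_len = 10, 64, 2, 5
mode_embedder = torch.nn.Embedding(num_union_modes, hidden_size)
txt = torch.randn(batch, txt_len, hidden_size)

control_type = [3]  # some mode index (e.g. depth); the index-to-mode mapping is illustrative
control_cond = mode_embedder(torch.tensor(control_type)).unsqueeze(0).repeat(batch, 1, 1)
txt = torch.cat([control_cond, txt], dim=1)
print(txt.shape)  # torch.Size([2, 6, 64]) -- one extra conditioning token per sequence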

@@ -1,104 +0,0 @@
-#Original code can be found on: https://github.com/XLabs-AI/x-flux/blob/main/src/flux/controlnet.py
-import torch
-from torch import Tensor, nn
-from einops import rearrange, repeat
-
-from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
-                     MLPEmbedder, SingleStreamBlock,
-                     timestep_embedding)
-
-from .model import Flux
-import comfy.ldm.common_dit
-
-class ControlNetFlux(Flux):
-    def __init__(self, image_model=None, dtype=None, device=None, operations=None, **kwargs):
-        super().__init__(final_layer=False, dtype=dtype, device=device, operations=operations, **kwargs)
-
-        # add ControlNet blocks
-        self.controlnet_blocks = nn.ModuleList([])
-        for _ in range(self.params.depth):
-            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
-            # controlnet_block = zero_module(controlnet_block)
-            self.controlnet_blocks.append(controlnet_block)
-
-        self.pos_embed_input = operations.Linear(self.in_channels, self.hidden_size, bias=True, dtype=dtype, device=device)
-        self.gradient_checkpointing = False
-        self.input_hint_block = nn.Sequential(
-            operations.Conv2d(3, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, stride=2, dtype=dtype, device=device),
-            nn.SiLU(),
-            operations.Conv2d(16, 16, 3, padding=1, dtype=dtype, device=device)
-        )
-
-    def forward_orig(
-        self,
-        img: Tensor,
-        img_ids: Tensor,
-        controlnet_cond: Tensor,
-        txt: Tensor,
-        txt_ids: Tensor,
-        timesteps: Tensor,
-        y: Tensor,
-        guidance: Tensor = None,
-    ) -> Tensor:
-        if img.ndim != 3 or txt.ndim != 3:
-            raise ValueError("Input img and txt tensors must have 3 dimensions.")
-
-        # running on sequences img
-        img = self.img_in(img)
-        controlnet_cond = self.input_hint_block(controlnet_cond)
-        controlnet_cond = rearrange(controlnet_cond, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
-        controlnet_cond = self.pos_embed_input(controlnet_cond)
-        img = img + controlnet_cond
-        vec = self.time_in(timestep_embedding(timesteps, 256))
-        if self.params.guidance_embed:
-            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
-        vec = vec + self.vector_in(y)
-        txt = self.txt_in(txt)
-
-        ids = torch.cat((txt_ids, img_ids), dim=1)
-        pe = self.pe_embedder(ids)
-
-        block_res_samples = ()
-
-        for block in self.double_blocks:
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
-            block_res_samples = block_res_samples + (img,)
-
-        controlnet_block_res_samples = ()
-        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
-            block_res_sample = controlnet_block(block_res_sample)
-            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
-
-        return {"input": (controlnet_block_res_samples * 10)[:19]}
-
-    def forward(self, x, timesteps, context, y, guidance=None, hint=None, **kwargs):
-        hint = hint * 2.0 - 1.0
-
-        bs, c, h, w = x.shape
-        patch_size = 2
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
-
-        img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
-
-        h_len = ((h + (patch_size // 2)) // patch_size)
-        w_len = ((w + (patch_size // 2)) // patch_size)
-        img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
-        img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
-        img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
-        img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
-
-        txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        return self.forward_orig(img, img_ids, hint, context, txt_ids, timesteps, y, guidance)

@@ -6,6 +6,7 @@ from torch import Tensor, nn
 from .math import attention, rope
 import comfy.ops
+import comfy.ldm.common_dit

 class EmbedND(nn.Module):
@@ -63,8 +64,7 @@ class RMSNorm(torch.nn.Module):
         self.scale = nn.Parameter(torch.empty((dim), dtype=dtype, device=device))

     def forward(self, x: Tensor):
-        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
-        return (x * rrms) * comfy.ops.cast_to(self.scale, dtype=x.dtype, device=x.device)
+        return comfy.ldm.common_dit.rms_norm(x, self.scale, 1e-6)

 class QKNorm(torch.nn.Module):

@@ -372,7 +372,7 @@ class HunYuanDiT(nn.Module):
         for layer, block in enumerate(self.blocks):
             if layer > self.depth // 2:
                 if controls is not None:
-                    skip = skips.pop() + controls.pop()
+                    skip = skips.pop() + controls.pop().to(dtype=x.dtype)
                 else:
                     skip = skips.pop()
             x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)

@@ -355,29 +355,9 @@ class RMSNorm(torch.nn.Module):
         else:
             self.register_parameter("weight", None)

-    def _norm(self, x):
-        """
-        Apply the RMSNorm normalization to the input tensor.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The normalized tensor.
-        """
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
-
     def forward(self, x):
-        """
-        Forward pass through the RMSNorm layer.
-        Args:
-            x (torch.Tensor): The input tensor.
-        Returns:
-            torch.Tensor: The output tensor after applying RMSNorm.
-        """
-        x = self._norm(x)
-        if self.learnable_scale:
-            return x * self.weight.to(device=x.device, dtype=x.dtype)
-        else:
-            return x
+        return comfy.ldm.common_dit.rms_norm(x, self.weight, self.eps)

 class SwiGLUFeedForward(nn.Module):

@@ -540,7 +540,7 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32):
                 b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)

                 try:
-                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1), a2), a1)).reshape(weight.shape)
+                    lora_diff = (torch.mm(b2, b1) + torch.mm(torch.mm(weight.flatten(start_dim=1).to(dtype=intermediate_dtype), a2), a1)).reshape(weight.shape)
                     if dora_scale is not None:
                         weight = function(weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype))
                     else:
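
The added .to(dtype=intermediate_dtype) is what fixes the glora lowvram issue: under lowvram the stored weight can be fp16 while the glora factors were already cast to the float32 intermediate dtype, and torch.mm refuses mixed dtypes. A minimal reproduction of the failure mode (illustrative shapes):

import torch

w = torch.randn(4, 4, dtype=torch.float16)  # weight left in its storage dtype
a = torch.randn(4, 4, dtype=torch.float32)  # factor cast to intermediate_dtype
try:
    torch.mm(w, a)                           # RuntimeError: mixed dtypes
except RuntimeError as e:
    print("fails:", e)
print(torch.mm(w.to(dtype=torch.float32), a).dtype)  # the fix: cast first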

@@ -45,6 +45,7 @@ cpu_state = CPUState.GPU
 total_vram = 0

 xpu_available = False
+torch_version = ""
 try:
     torch_version = torch.version.__version__
     xpu_available = (int(torch_version[0]) < 2 or (int(torch_version[0]) == 2 and int(torch_version[2]) <= 4)) and torch.xpu.is_available()

@@ -528,6 +528,8 @@ def flux_to_diffusers(mmdit_config, output_prefix=""):
         ("guidance_in.out_layer.weight", "time_text_embed.guidance_embedder.linear_2.weight"),
         ("final_layer.adaLN_modulation.1.bias", "norm_out.linear.bias", swap_scale_shift),
         ("final_layer.adaLN_modulation.1.weight", "norm_out.linear.weight", swap_scale_shift),
+        ("pos_embed_input.bias", "controlnet_x_embedder.bias"),
+        ("pos_embed_input.weight", "controlnet_x_embedder.weight"),
     }

     for k in MAP_BASIC:

@@ -6,6 +6,10 @@ import importlib.util
 import folder_paths
 import time
 from comfy.cli_args import args
+from app.logger import setup_logger
+
+setup_logger(verbose=args.verbose)
+

 def execute_prestartup_script():

@@ -79,7 +79,7 @@
 "#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
 "\n",
 "# SD1.5\n",
-"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
+"!wget -c https://huggingface.co/Comfy-Org/stable-diffusion-v1-5-archive/resolve/main/v1-5-pruned-emaonly-fp16.safetensors -P ./models/checkpoints/\n",
 "\n",
 "# SD2\n",
 "#!wget -c https://huggingface.co/stabilityai/stable-diffusion-2-1-base/resolve/main/v2-1_512-ema-pruned.safetensors -P ./models/checkpoints/\n",

@@ -43,7 +43,7 @@ prompt_text = """
     "4": {
         "class_type": "CheckpointLoaderSimple",
         "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
         }
     },
     "5": {

@@ -41,15 +41,14 @@ def get_images(ws, prompt):
             continue #previews are binary data

     history = get_history(prompt_id)[prompt_id]
-    for o in history['outputs']:
-        for node_id in history['outputs']:
-            node_output = history['outputs'][node_id]
-            if 'images' in node_output:
-                images_output = []
-                for image in node_output['images']:
-                    image_data = get_image(image['filename'], image['subfolder'], image['type'])
-                    images_output.append(image_data)
-                output_images[node_id] = images_output
+    for node_id in history['outputs']:
+        node_output = history['outputs'][node_id]
+        images_output = []
+        if 'images' in node_output:
+            for image in node_output['images']:
+                image_data = get_image(image['filename'], image['subfolder'], image['type'])
+                images_output.append(image_data)
+        output_images[node_id] = images_output

     return output_images
@@ -85,7 +84,7 @@ prompt_text = """
     "4": {
         "class_type": "CheckpointLoaderSimple",
         "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
         }
     },
     "5": {

@@ -81,7 +81,7 @@ prompt_text = """
     "4": {
         "class_type": "CheckpointLoaderSimple",
         "inputs": {
-            "ckpt_name": "v1-5-pruned-emaonly.ckpt"
+            "ckpt_name": "v1-5-pruned-emaonly.safetensors"
        }
     },
     "5": {

@@ -31,7 +31,6 @@ from model_filemanager import download_model, DownloadModelStatus
 from typing import Optional
 from api_server.routes.internal.internal_routes import InternalRoutes

 class BinaryEventTypes:
     PREVIEW_IMAGE = 1
     UNENCODED_PREVIEW_IMAGE = 2
@@ -42,6 +41,21 @@ async def send_socket_catch_exception(function, message):
     except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err:
         logging.warning("send error: {}".format(err))

+def get_comfyui_version():
+    comfyui_version = "unknown"
+    repo_path = os.path.dirname(os.path.realpath(__file__))
+    try:
+        import pygit2
+        repo = pygit2.Repository(repo_path)
+        comfyui_version = repo.describe(describe_strategy=pygit2.GIT_DESCRIBE_TAGS)
+    except Exception:
+        try:
+            import subprocess
+            comfyui_version = subprocess.check_output(["git", "describe", "--tags"], cwd=repo_path).decode('utf-8')
+        except Exception as e:
+            logging.warning(f"Failed to get ComfyUI version: {e}")
+    return comfyui_version.strip()

 @web.middleware
 async def cache_control(request: web.Request, handler):
     response: web.Response = await handler(request)
@@ -401,16 +415,20 @@ class PromptServer():
             return web.json_response(dt["__metadata__"])

         @routes.get("/system_stats")
-        async def get_queue(request):
+        async def system_stats(request):
             device = comfy.model_management.get_torch_device()
             device_name = comfy.model_management.get_torch_device_name(device)
             vram_total, torch_vram_total = comfy.model_management.get_total_memory(device, torch_total_too=True)
             vram_free, torch_vram_free = comfy.model_management.get_free_memory(device, torch_free_too=True)

             system_stats = {
                 "system": {
                     "os": os.name,
+                    "comfyui_version": get_comfyui_version(),
                     "python_version": sys.version,
-                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded"
+                    "pytorch_version": comfy.model_management.torch_version,
+                    "embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
+                    "argv": sys.argv
                 },
                 "devices": [
                     {
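
A hedged sketch of exercising the two endpoints this change set touches, assuming a local ComfyUI on the default port; the /internal prefix for the logs route is an assumption about where the InternalRoutes sub-app is mounted:

import json
import urllib.request

base = "http://127.0.0.1:8188"   # assumed local instance, default port

stats = json.loads(urllib.request.urlopen(base + "/system_stats").read())
print(stats["system"]["comfyui_version"], stats["system"]["pytorch_version"])

logs = json.loads(urllib.request.urlopen(base + "/internal/logs").read())
print(logs[-500:])   # get_logs() returns one newline-joined string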

@@ -95,17 +95,16 @@ class ComfyClient:
                     pass # Probably want to store this off for testing

         history = self.get_history(prompt_id)[prompt_id]
-        for o in history['outputs']:
-            for node_id in history['outputs']:
-                node_output = history['outputs'][node_id]
-                result.outputs[node_id] = node_output
-                if 'images' in node_output:
-                    images_output = []
-                    for image in node_output['images']:
-                        image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
-                        image_obj = Image.open(BytesIO(image_data))
-                        images_output.append(image_obj)
-                    node_output['image_objects'] = images_output
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            result.outputs[node_id] = node_output
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    image_obj = Image.open(BytesIO(image_data))
+                    images_output.append(image_obj)
+                node_output['image_objects'] = images_output

         return result

@@ -109,15 +109,14 @@ class ComfyClient:
                 continue #previews are binary data

         history = self.get_history(prompt_id)[prompt_id]
-        for o in history['outputs']:
-            for node_id in history['outputs']:
-                node_output = history['outputs'][node_id]
-                if 'images' in node_output:
-                    images_output = []
-                    for image in node_output['images']:
-                        image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
-                        images_output.append(image_data)
-                    output_images[node_id] = images_output
+        for node_id in history['outputs']:
+            node_output = history['outputs'][node_id]
+            images_output = []
+            if 'images' in node_output:
+                for image in node_output['images']:
+                    image_data = self.get_image(image['filename'], image['subfolder'], image['type'])
+                    images_output.append(image_data)
+            output_images[node_id] = images_output

         return output_images
