From 4506ddc86a413779cd2274305694c73af02c3892 Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 19 Aug 2024 13:38:03 -0400 Subject: [PATCH] Better subnormal fp8 stochastic rounding. Thanks Ashen. --- comfy/float.py | 28 ++++++++++++++++------------ 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/comfy/float.py b/comfy/float.py index 51ae987..1dbdafd 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -19,25 +19,29 @@ def manual_stochastic_round_to_float8(x, dtype): ) # Combine mantissa calculation and rounding - mantissa = abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0 - mantissa_scaled = mantissa * (2**MANTISSA_BITS) + # min_normal = 2.0 ** (-EXPONENT_BIAS + 1) + # zero_mask = (abs_x == 0) + # subnormal_mask = (exponent == 0) & (abs_x != 0) + normal_mask = ~(exponent == 0) + + mantissa_scaled = torch.where( + normal_mask, + (abs_x / (2.0 ** (exponent - EXPONENT_BIAS)) - 1.0) * (2**MANTISSA_BITS), + (abs_x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) + ) mantissa_floor = mantissa_scaled.floor() mantissa = torch.where( torch.rand_like(mantissa_scaled) < (mantissa_scaled - mantissa_floor), (mantissa_floor + 1) / (2**MANTISSA_BITS), mantissa_floor / (2**MANTISSA_BITS) ) + result = torch.where( + normal_mask, + sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa), + sign * (2.0 ** (-EXPONENT_BIAS + 1)) * mantissa + ) - # Combine final result calculation - result = sign * (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + mantissa) - - # Handle zero case - zero_mask = (abs_x == 0) - result = torch.where(zero_mask, torch.zeros_like(result), result) - - # Handle subnormal numbers - min_normal = 2.0 ** (-EXPONENT_BIAS + 1) - result = torch.where((abs_x < min_normal) & (~zero_mask), torch.round(x / (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS))) * (2.0 ** (-EXPONENT_BIAS + 1 - MANTISSA_BITS)), result) + result = torch.where(abs_x == 0, 0, result) return result.to(dtype=dtype)