|
|
|
@ -22,10 +22,17 @@ class V_PREDICTION(EPS):
|
|
|
|
|
class ModelSamplingDiscrete(torch.nn.Module):
|
|
|
|
|
def __init__(self, model_config=None):
|
|
|
|
|
super().__init__()
|
|
|
|
|
beta_schedule = "linear"
|
|
|
|
|
|
|
|
|
|
if model_config is not None:
|
|
|
|
|
beta_schedule = model_config.sampling_settings.get("beta_schedule", beta_schedule)
|
|
|
|
|
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=0.00085, linear_end=0.012, cosine_s=8e-3)
|
|
|
|
|
sampling_settings = model_config.sampling_settings
|
|
|
|
|
else:
|
|
|
|
|
sampling_settings = {}
|
|
|
|
|
|
|
|
|
|
beta_schedule = sampling_settings.get("beta_schedule", "linear")
|
|
|
|
|
linear_start = sampling_settings.get("linear_start", 0.00085)
|
|
|
|
|
linear_end = sampling_settings.get("linear_end", 0.012)
|
|
|
|
|
|
|
|
|
|
self._register_schedule(given_betas=None, beta_schedule=beta_schedule, timesteps=1000, linear_start=linear_start, linear_end=linear_end, cosine_s=8e-3)
|
|
|
|
|
self.sigma_data = 1.0
|
|
|
|
|
|
|
|
|
|
def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps=1000,
|
|
|
|
|