diff --git a/scripts/animatediff_ui.py b/scripts/animatediff_ui.py index 01001121..83d53000 100644 --- a/scripts/animatediff_ui.py +++ b/scripts/animatediff_ui.py @@ -186,7 +186,7 @@ def set_p(self, p: StableDiffusionProcessing): cn_unit.batch_mask_dir = self.mask_path # find minimun control images in CN batch - cn_unit_batch_params = cn_unit.batch_images.split('\n') + cn_unit_batch_params = cn_unit.batch_images.split('\n') if cn_unit.batch_images is not None else [] if cn_unit.input_mode.name == 'BATCH': cn_unit.animatediff_batch = True # for A1111 sd-webui-controlnet if not any([cn_param.startswith("keyframe:") for cn_param in cn_unit_batch_params[1:]]): diff --git a/scripts/animatediff_utils.py b/scripts/animatediff_utils.py index 0889e828..87b96823 100644 --- a/scripts/animatediff_utils.py +++ b/scripts/animatediff_utils.py @@ -57,14 +57,28 @@ def get_controlnet_units(p: StableDiffusionProcessing): cn_units = p.script_args[script.args_from:script.args_to] if p.is_api and len(cn_units) > 0 and isinstance(cn_units[0], dict): - from scripts import external_code - from scripts.batch_hijack import InputMode - cn_units_dataclass = external_code.get_all_units_in_processing(p) - for cn_unit_dict, cn_unit_dataclass in zip(cn_units, cn_units_dataclass): + from scripts import external_code + from scripts.batch_hijack import InputMode + cn_units_dataclass = external_code.get_all_units_in_processing(p) + for cn_unit_dict, cn_unit_dataclass in zip(cn_units, cn_units_dataclass): + # NB: Unfortunately this setattr section is required because those attributes don't exist + # in the default ControlNetUnit class defined in sd-webui-controlnet library. + # So we have to use this hack to append extra batch processing related attributes to the object + # until sd-webui-controlnet makes an update. + setattr(cn_unit_dataclass, "input_mode", InputMode.SIMPLE) + setattr(cn_unit_dataclass, "batch_images", None) + setattr(cn_unit_dataclass, "batch_mask_dir", None) + setattr(cn_unit_dataclass, "batch_input_gallery", None) + setattr(cn_unit_dataclass, "batch_modifiers", []) + setattr(cn_unit_dataclass, "animatediff_batch", False) + if cn_unit_dataclass.image is None: cn_unit_dataclass.input_mode = InputMode.BATCH - cn_unit_dataclass.batch_images = cn_unit_dict.get("batch_images", None) - p.script_args[script.args_from:script.args_to] = cn_units_dataclass + cn_unit_dataclass.batch_images = getattr(cn_unit_dict, "batch_images", None) + cn_unit_dataclass.animatediff_batch = True + + p.script_args[script.args_from:script.args_to] = cn_units_dataclass + cn_units = cn_units_dataclass return [x for x in cn_units if x.enabled] if not p.is_api else cn_units