From ef79fbc4b5066f540ef4ad57bf72ace5fd7e84ea Mon Sep 17 00:00:00 2001
From: "dujiancong.djc"
Date: Tue, 3 Feb 2026 21:46:28 +0800
Subject: [PATCH] init flux2 dit on meta device

---
 diffsynth_engine/models/flux2/flux2_dit.py      | 3 ++-
 diffsynth_engine/pipelines/flux2_klein_image.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/diffsynth_engine/models/flux2/flux2_dit.py b/diffsynth_engine/models/flux2/flux2_dit.py
index a022524..954274b 100644
--- a/diffsynth_engine/models/flux2/flux2_dit.py
+++ b/diffsynth_engine/models/flux2/flux2_dit.py
@@ -1058,7 +1058,8 @@ def from_state_dict(
         dtype: torch.dtype = torch.float32,
         **kwargs,
     ) -> "Flux2DiT":
-        model = cls(device="meta", dtype=dtype, **kwargs)
+        with torch.device("meta"):
+            model = cls(device="meta", dtype=dtype, **kwargs)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
diff --git a/diffsynth_engine/pipelines/flux2_klein_image.py b/diffsynth_engine/pipelines/flux2_klein_image.py
index 8409372..98c300d 100644
--- a/diffsynth_engine/pipelines/flux2_klein_image.py
+++ b/diffsynth_engine/pipelines/flux2_klein_image.py
@@ -202,7 +202,8 @@ def _from_state_dict(cls, state_dicts: Flux2StateDicts, config: Flux2KleinPipeli
         else:
             with open(FLUX2_TEXT_ENCODER_8B_CONF_PATH, "r", encoding="utf-8") as f:
                 qwen3_config = Qwen3Config(**json.load(f))
-            state_dicts.encoder.pop("lm_head.weight")
+            if "lm_head.weight" in state_dicts.encoder:
+                state_dicts.encoder.pop("lm_head.weight")
         dit_config = {"guidance_embeds": False, "joint_attention_dim": 12288, "num_attention_heads": 32, "num_layers": 8, "num_single_layers": 24}
         text_encoder = Qwen3Model.from_state_dict(
             state_dicts.encoder, config=qwen3_config, device=init_device, dtype=config.encoder_dtype
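
Context for reviewers: the first hunk wraps construction in PyTorch's
torch.device("meta") context manager, presumably so that tensors created
inside __init__ without an explicit device= kwarg (e.g. buffers) also land
on the meta device instead of CPU; load_state_dict(assign=True) then swaps
the meta tensors for real ones. The second hunk just makes the
lm_head.weight pop tolerant of encoder checkpoints that already lack that
key. Below is a minimal, self-contained sketch of the meta-device pattern;
ToyBlock is a hypothetical stand-in for Flux2DiT, not code from this
repository, and the state dict is fabricated for the demo.

import torch
import torch.nn as nn

# Hypothetical stand-in for Flux2DiT; the real model's __init__ takes
# device/dtype kwargs plus many config arguments.
class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(64, 64)
        # Under the meta context manager, this buffer also lands on the
        # meta device even though no device= kwarg reaches it; this is the
        # case the patch's wrapper guards against.
        self.register_buffer("scale", torch.ones(64))

    def forward(self, x):
        return self.proj(x) * self.scale

# Step 1: build the skeleton on the meta device (shapes only, no real
# memory allocated). Requires PyTorch >= 2.0 for the context manager.
with torch.device("meta"):
    model = ToyBlock()
assert model.proj.weight.is_meta and model.scale.is_meta

# Step 2: materialize with assign=True (PyTorch >= 2.1), which replaces
# the meta tensors with the loaded ones instead of copying into them
# (copying into meta storage would fail).
real_sd = {k: torch.randn_like(v, device="cpu") for k, v in model.state_dict().items()}
model.requires_grad_(False)
model.load_state_dict(real_sd, assign=True)

# Step 3: the model now holds real weights and can be moved and run as
# usual, mirroring the patch's model.to(device=device, dtype=dtype,
# non_blocking=True).
print(model(torch.randn(2, 64)).shape)  # torch.Size([2, 64])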