以下の内容はhttps://touch-sp.hatenablog.com/entry/2024/12/17/180305より取得しました。


【Diffusers】FLUX.1-devでtorchao(PyTorch Architecture Optimization)を試してみる

PC環境

Windows 11
RTX 3080 Laptop (VRAM 16GB)
CUDA 11.8
Python 3.12

Python環境構築

pip install torch==2.5.1+cu118 --index-url https://download.pytorch.org/whl/cu118
pip install diffusers[torch]
pip install torchao transformers protobuf sentencepiece
diffusers==0.32.1
protobuf==5.29.2
sentencepiece==0.2.0
torch==2.5.1+cu118
torchao==0.7.0
transformers==4.47.1

Pythonスクリプト

import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
from decorator import gpu_monitor, time_monitor
import gc

def flush():
    gc.collect()
    torch.cuda.empty_cache()

@time_monitor
@gpu_monitor(interval=0.5)
def main():
    model_id = "black-forest-labs/Flux.1-Dev"
    dtype = torch.bfloat16

    pipeline = FluxPipeline.from_pretrained(
            model_id,
            transformer=None,
            vae=None,
            torch_dtype=dtype
    ).to("cuda")

    prompt = "A cat holding a sign that says hello world"

    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=prompt,
            prompt_2=None,
        )

    print("text_encoder:")
    print(f"torch.cuda.max_memory_allocated: {torch.cuda.max_memory_allocated()/ 1024**3:.2f} GB")

    del pipeline
    flush()

    quantization_config = TorchAoConfig("int8wo")
    transformer = FluxTransformer2DModel.from_pretrained(
        model_id,
        subfolder="transformer",
        quantization_config=quantization_config,
        torch_dtype=dtype,
    )
    pipeline = FluxPipeline.from_pretrained(
        model_id,
        transformer=transformer,
        text_encoder=None,
        text_encoder_2=None,
        tokenizer=None,
        tokenizer_2=None,
        torch_dtype=dtype
    ).to("cuda")

    image = pipeline(
        prompt_embeds=prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        num_inference_steps=28,
        guidance_scale=3.5
    ).images[0]

    image.save("output.jpg")

    print("transformer:")
    print(f"torch.cuda.max_memory_allocated: {torch.cuda.max_memory_allocated()/ 1024**3:.2f} GB")

if __name__ == "__main__":
    main()

結果

text_encoder:
torch.cuda.max_memory_allocated: 9.33 GB

transformer:
torch.cuda.max_memory_allocated: 13.66 GB

GPU 0 - Used memory: 14.31/16.00 GB

time: 156.65 sec


その他

ベンチマークはこちらで記述したスクリプトで行いました。
touch-sp.hatenablog.com






以上の内容はhttps://touch-sp.hatenablog.com/entry/2024/12/17/180305より取得しました。
このページはhttp://font.textar.tv/のウェブフォントを使用してます

不具合報告/要望等はこちらへお願いします。
モバイルやる夫Viewer Ver0.14