はじめに
WeDLMはTencentが公開しているDiffusion Language Modelです。
Gradioを使ってチャットできるようにしてみました。
Pythonコード
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import gradio as gr

# 4-bit quantization keeps the 8B model within a single consumer GPU's memory.
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

tokenizer = AutoTokenizer.from_pretrained("tencent/WeDLM-8B-Instruct", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "tencent/WeDLM-8B-Instruct",
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)


def extract_text_from_content(content):
    """Flatten Gradio 6 structured message content into a plain string.

    Gradio 6 may deliver message content as a plain string or as a list of
    parts (dicts like {"type": "text", "text": ...} or other objects); any
    non-text part is stringified as-is.
    """
    if isinstance(content, str):
        return content
    if not isinstance(content, list):
        return str(content)
    pieces = []
    for part in content:
        if isinstance(part, dict) and part.get("type") == "text":
            pieces.append(part.get("text", ""))
        else:
            pieces.append(str(part))
    return " ".join(pieces)


def chat(message, history):
    """Produce one assistant reply for a Gradio ChatInterface turn.

    `history` is the Gradio "messages" format (list of role/content dicts);
    non-dict entries are ignored. Returns the decoded completion only
    (prompt tokens are sliced off).
    """
    conversation = [
        {
            "role": turn.get("role", "user"),
            "content": extract_text_from_content(turn.get("content", "")),
        }
        for turn in history
        if isinstance(turn, dict)
    ]
    conversation.append({"role": "user", "content": extract_text_from_content(message)})

    prompt = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generated = model.generate(
        **model_inputs,
        max_new_tokens=512,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = model_inputs.input_ids.shape[-1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)


demo = gr.ChatInterface(
    fn=chat,
    title="WeDLM-8B Chat",
    description="Tencent WeDLM-8B-Instruct モデルとチャットできます",
    examples=["Hello! How are you?", "Python について教えてください", "簡単なコードを書いてください"],
)

if __name__ == "__main__":
    demo.launch()
環境構築
pyproject.tomlを載せておきます。
uvを使うとuv syncだけで環境構築できると思います。
[project]
name = "dlm"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
    "accelerate==1.12.0",
    "bitsandbytes==0.49.0",
    "gradio==6.2.0",
    "hf-xet==1.2.0",
    "torch==2.9.1+cu126",
    "transformers==4.57.1",
]

# CUDA 12.6 wheels for torch come from the PyTorch index, not PyPI.
[[tool.uv.index]]
name = "torch-cuda"
url = "https://download.pytorch.org/whl/cu126"
explicit = true

[tool.uv.sources]
torch = [{ index = "torch-cuda" }]
transformersライブラリが4.53.2以上ではエラーが出ました。
TypeError: check_model_inputs.<locals>.wrapped_fn() got an unexpected keyword argument 'input_ids'