La recherche en IA te passionne ?
Les papers et avancées qui comptent, expliqués simplement, chaque soir. Gratuit.
Inclus dès l'inscription : notre sélection des meilleurs guides & comparatifs IA.
Choisis ton rythme
Gratuit · Pas de spam · Désabonnement en 1 clic
MedQA : Un modèle clinique innovant sur AMD ROCm
MedQA représente une avancée significative dans le domaine des modèles de question-réponse clinique. Ce modèle est affiné à l'aide de LoRA (Low-Rank Adaptation) et fonctionne entièrement sur du matériel AMD utilisant ROCm, sans nécessiter CUDA. MedQA est conçu pour traiter des questions médicales à choix multiples et fournit non seulement la réponse correcte, mais aussi une explication clinique détaillée du raisonnement derrière cette réponse. L'ensemble du pipeline d'entraînement, depuis le chargement des données jusqu'à l'exportation de l'adaptateur, est réalisé sur un AMD Instinct MI300X, éliminant ainsi toute dépendance à CUDA.
Le matériel : AMD Instinct MI300X
L'AMD Instinct MI300X est un matériel de pointe, doté de 192 Go de mémoire HBM3 dans un seul appareil. Cette capacité mémoire est cruciale pour l'affinage des modèles de langage de grande taille (LLM), car elle détermine la taille des lots, la longueur des séquences et la nécessité éventuelle de quantification. Grâce à ses 192 Go de mémoire, il est possible d'entraîner le modèle Qwen3-1.7B avec LoRA en pleine précision fp16, sans avoir recours à des astuces de quantification en 4 bits ou 8 bits.
Le Dataset : MedMCQA
Le jeu de données utilisé pour ce projet est MedMCQA, qui est dérivé des examens d'entrée en médecine indiens, similaires au style USMLE. Chaque exemple dans cet ensemble de données comprend une question clinique, quatre options de réponse (A–D), l'index de la bonne réponse et une explication en texte libre. Pour démontrer la rapidité et l'efficacité de l'affinage, seulement 2 000 échantillons d'entraînement ont été utilisés. L'entraînement sur le MI300X a duré environ cinq minutes, illustrant ainsi la puissance du matériel et de la méthode d'affinage.
Modèle : Qwen3-1.7B
Le modèle de base utilisé dans ce projet est le Qwen/Qwen3-1.7B, le dernier modèle de langage à petite échelle développé par Alibaba. Avec ses 1,7 milliard de paramètres, ce modèle est suffisamment compact pour permettre un affinage économique tout en étant assez puissant pour produire un raisonnement clinique cohérent. Il prend en charge l'option trust_remote_code=True et s'intègre parfaitement avec HuggingFace Transformers, facilitant ainsi son utilisation et son déploiement.
Format de l'invite
La cohérence dans le format des invites est essentielle pour l'affinage des instructions. Chaque exemple d'entraînement et chaque appel d'inférence suivent le même modèle de formatage, qui est structuré de la manière suivante :
{answer_letter}) {answer_text}
### Explanation:
Entraînement avec LoRA
Plutôt que d'affiner l'ensemble des 1,5 milliard de paramètres du modèle, la méthode LoRA est utilisée via la bibliothèque PEFT. LoRA permet d'injecter de petites matrices de décomposition de rang entraînables dans les couches d'attention, tout en laissant les poids de base gelés. Cette approche réduit considérablement la quantité de paramètres à entraîner, avec seulement environ 2,2 millions de paramètres modifiés, ce qui maintient une faible utilisation de la mémoire et permet un entraînement rapide.
Configuration de LoRA
La configuration de LoRA est réalisée à l'aide du code suivant :
from peft import LoraConfig, get_peft_model, TaskType
lora_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
lora_dropout=0.05,
target_modules=["q_proj", "v_proj"],
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
Seuls environ 2,2 millions des 1,5 milliard de paramètres du modèle sont entraînés, ce qui optimise l'utilisation de la mémoire et accélère le processus d'entraînement.
Arguments d'entraînement
Les arguments d'entraînement sont configurés comme suit :
from transformers import TrainingArguments
args = TrainingArguments(
output_dir="./outputs",
num_train_epochs=2,
per_device_train_batch_size=4,
gradient_accumulation_steps=4,
learning_rate=2e-4,
eval_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
gradient_checkpointing=True,
optim="adamw_torch",
warmup_ratio=0.05,
lr_scheduler_type="cosine",
report_to="none",
)
Quelques points importants à noter :
-
fp16=True, bf16=False : Le standard fp16 est utilisé. Lors des premières expériences avec bfloat16, des pertes NaN ont été rencontrées, mais le passage à fp16 a résolu ce problème.
-
gradient_checkpointing=True : Cette option permet d'échanger du calcul contre de la mémoire. Bien que cela ne soit pas strictement nécessaire sur le MI300X, compte tenu de ses 192 Go de VRAM, c'est une bonne pratique pour assurer la reproductibilité sur des GPU plus petits.
-
gradient_accumulation_steps=4 : Cela permet d'avoir une taille de lot efficace de 16 avec un lot physique de 4.
-
La planification de LR cosinus avec échauffement offre une convergence plus douce qu'un plan de formation plat pour des entraînements courts.
La boucle d'entraînement complète
La boucle d'entraînement complète est mise en place avec le code suivant :
from transformers import DataCollatorForSeq2Seq, Trainer
collator = DataCollatorForSeq2Seq(pad_to_multiple_of=8)
trainer = Trainer(
train_dataset=train_ds,
eval_dataset=val_ds,
data_collator=collator,
)
model.save_pretrained("./outputs")
[tokenizer](/glossaire/tokenizer).save_pretrained("./outputs")
Après l'entraînement, le répertoire ./outputs contient les poids de l'adaptateur LoRA, qui ne pèsent que quelques Mo, contrairement à un point de contrôle complet de plusieurs Go.
Chargement depuis HuggingFace Hub
L'adaptateur affiné est disponible publiquement et peut être chargé directement sans avoir besoin de cloner le dépôt :
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-1.7B", torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True)
model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
Défis et solutions
Aucun projet utilisant AMD ROCm n'est complet sans une section sur les défis rencontrés. Voici quelques-uns des problèmes auxquels nous avons été confrontés :
-
Instabilité de la précision mixte : Nous avons dû passer de bfloat16 à fp16 pour résoudre les problèmes de pertes NaN.
-
GPU non détecté : Cela était dû à des variables d'environnement ROCm manquantes. Nous avons configuré les variables ROCR_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES et HSA_OVERRIDE_GFX_VERSION.
-
bitsandbytes non pris en charge : Il n'existe pas de version ROCm de bitsandbytes, ce qui nous a amenés à abandonner la quantification. Heureusement, le MI300X dispose de suffisamment de VRAM pour s'en passer.
-
Sortie d'inférence incorrecte : Cela était dû à une configuration incorrecte du remplissage du tokenizer. Nous avons ajusté pad_token pour qu'il soit égal à eos_token et corrigé padding_side.
-
Erreurs d'évaluation du Trainer : Ces erreurs étaient dues à une incompatibilité de version des Transformers. Nous avons fixé la version des Transformers à >=4.40.0 pour résoudre ce problème.
Le problème de bitsandbytes mérite une mention spéciale : sur le matériel NVIDIA, la quantification en 4 bits est souvent nécessaire pour faire tenir un modèle en mémoire. Sur le MI300X avec ses 192 Go de HBM3, cela est tout simplement inutile. C'est un véritable avantage matériel, permettant un entraînement plus propre, sans artefacts de quantification.
Temps d'entraînement sur MI300X
-
Taille de l'ensemble de données utilisé : 2 000 échantillons.
-
Précision de base de MedMCQA : 6.1.
Pas de GPU ? Pas de problème. La démo en direct sur Gradio fonctionne sur HuggingFace Spaces.
Prochaines étapes
Ce projet prouve que le pipeline fonctionne. Les prochaines étapes concernent l'échelle et le renforcement :
-
Entraîner sur l'ensemble complet de MedMCQA (~180 000 questions) et ajouter PubMedQA.
-
Ajouter des estimations de confiance calibrées aux réponses.
-
Intégrer la récupération de littérature médicale en temps réel.
-
Effectuer un benchmarking de précision approprié au-delà de la division d'entraînement.
MedQA montre qu'il est non seulement possible de construire une IA médicale capable et explicable sur du matériel AMD open-source, mais que c'est également simple. La compatibilité ROCm de l'écosystème HuggingFace est réellement bonne. L'espace mémoire du MI300X élimine une catégorie entière de problèmes d'ingénierie. Et LoRA rend l'affinage d'un modèle de 1,7 milliard de paramètres un travail de 5 minutes.
Si vous travaillez sur AMD ROCm et rencontrez des obstacles, les solutions ci-dessus devraient vous faire gagner des heures. Et si vous construisez une IA médicale, l'accent mis sur l'explication plutôt que sur la simple précision mérite d'être pris au sérieux.


