MedQA: Revolutionizing Clinical AI with AMD ROCm, Without CUDA
MedQA: An Innovative Clinical Model on AMD ROCm
MedQA is a clinical question-answering model fine-tuned with LoRA (Low-Rank Adaptation) that runs entirely on AMD hardware through ROCm, with no CUDA required. It answers multiple-choice medical questions and returns not only the correct option but also a clinical explanation of the reasoning behind it. The whole training pipeline, from data loading to adapter export, runs on an AMD Instinct MI300X, so nothing in it depends on CUDA.
The Hardware: AMD Instinct MI300X
The AMD Instinct MI300X packs 192 GB of HBM3 memory into a single device. When fine-tuning large language models (LLMs), memory capacity constrains batch size and sequence length and decides whether quantization is needed at all. With 192 GB available, Qwen3-1.7B can be fine-tuned with LoRA in 16-bit floating point (fp16), without resorting to 4-bit or 8-bit quantization.
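As a rough sanity check on that claim, the weights alone of a 1.7-billion-parameter model take about 3.4 GB at 2 bytes per parameter. The sketch below counts weights only; activations, LoRA gradients, optimizer state, and framework overhead come on top, so treat the result as a lower bound rather than a measurement.

```python
# Bytes needed to store one parameter in common formats.
BYTES_PER_PARAM = {"fp32": 4, "fp16": 2, "bf16": 2, "int8": 1, "int4": 0.5}

def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Weights-only footprint in GB (1 GB = 1e9 bytes); excludes activations and optimizer state."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

if __name__ == "__main__":
    n = 1.7e9  # nominal parameter count of Qwen3-1.7B
    for dtype in ("fp32", "fp16", "int4"):
        print(f"{dtype}: {weight_memory_gb(n, dtype):.2f} GB")
```

The gap between a few GB of weights and 192 GB of HBM3 is the headroom that goes to batch size, sequence length, and activations.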
The Dataset: MedMCQA
The dataset is MedMCQA, drawn from Indian medical entrance exams and similar in style to the USMLE. Each example contains a clinical question, four answer options (A–D), the index of the correct option, and a free-text explanation. To show how quickly fine-tuning can run, only 2,000 training samples were used; training on the MI300X took about five minutes.
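A minimal sketch of turning one MedMCQA record into the pieces a prompt needs, plus a reproducible 2,000-example subset. The field names (question, opa to opd, cop, exp) follow the Hugging Face copy of MedMCQA as we understand it, and we assume cop is a 0-based index (0 = A); verify both against the dataset version you load. The seed value is an arbitrary choice.

```python
import random

LETTERS = "ABCD"

def parse_record(rec: dict) -> dict:
    """Map an (assumed) MedMCQA-style record to question, options, answer and explanation."""
    options = [rec["opa"], rec["opb"], rec["opc"], rec["opd"]]
    idx = int(rec["cop"])  # assumed 0-based index of the correct option
    if not 0 <= idx < len(options):
        raise ValueError(f"correct-option index out of range: {idx}")
    return {
        "question": rec["question"].strip(),
        "options": options,
        "answer_letter": LETTERS[idx],
        "answer_text": options[idx],
        "explanation": (rec.get("exp") or "").strip(),  # explanation may be missing or None
    }

def seeded_subset(records: list, k: int = 2000, seed: int = 42) -> list:
    """Reproducible random subset of k records (all of them if there are fewer than k)."""
    if len(records) <= k:
        return list(records)
    return random.Random(seed).sample(records, k)
```

Fixing the seed means the same 2,000 questions are drawn on every run, which keeps short experiments comparable.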
Model: Qwen3-1.7B
The base model is Qwen/Qwen3-1.7B, a recent small language model from Alibaba's Qwen team. At 1.7 billion parameters it is compact enough for inexpensive fine-tuning yet capable enough to produce coherent clinical reasoning. It loads through HuggingFace Transformers (here with trust_remote_code=True), which keeps training and deployment straightforward.
Prompt Format
Instruction fine-tuning depends on consistent prompt formatting: every training example and every inference call uses the same template, structured as follows:
{answer_letter}) {answer_text}
### Explanation:
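The template above fixes how the answer and explanation appear. A minimal sketch of a full prompt builder around it: the "### Question:", "### Options:" and "### Answer:" headers are our own assumption, not a confirmed part of the project's template, while the "{answer_letter}) {answer_text}" line and "### Explanation:" header follow the format shown. Training prompts include the answer and explanation; inference prompts stop at the answer header so the model completes the rest.

```python
LETTERS = "ABCD"

def build_prompt(question, options, answer_idx=None, explanation=None):
    """Build a MedQA-style prompt.

    The Question/Options/Answer headers are hypothetical; the answer line and the
    '### Explanation:' header follow the format described in the text.
    """
    lines = ["### Question:", question.strip(), "", "### Options:"]
    lines += [f"{LETTERS[i]}) {opt}" for i, opt in enumerate(options)]
    lines += ["", "### Answer:"]
    if answer_idx is None:
        # Inference: stop here and let the model generate the answer and explanation.
        return "\n".join(lines) + "\n"
    lines.append(f"{LETTERS[answer_idx]}) {options[answer_idx]}")
    lines += ["### Explanation:", (explanation or "").strip()]
    return "\n".join(lines)
```

Routing both training and inference through one function is the simplest way to guarantee the two never drift apart.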
Training with LoRA
Rather than fine-tuning all of the model's roughly 1.7 billion parameters, the project applies LoRA through the PEFT library. LoRA injects small trainable low-rank matrices into the attention layers while the base weights stay frozen. Only about 2.2 million parameters are trained, which keeps memory usage low and training fast.
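Concretely, LoRA leaves a frozen weight matrix W0 (d_out by d_in) untouched and adds a trainable product B·A, with A of shape r by d_in and B of shape d_out by r, scaled by alpha / r. The NumPy sketch below shows that update with toy sizes chosen purely for illustration; because B starts at zero, the adapted layer initially reproduces the base layer exactly.

```python
import numpy as np

def lora_forward(x, W0, A, B, alpha, r):
    """y = W0 x + (alpha / r) * B (A x); W0 stays frozen, only A and B are trained."""
    return W0 @ x + (alpha / r) * (B @ (A @ x))

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 64, 32, 4, 8  # toy sizes, not the project's settings

W0 = rng.standard_normal((d_out, d_in))    # frozen pretrained weight
A = rng.standard_normal((r, d_in)) * 0.01  # trainable, small random init
B = np.zeros((d_out, r))                   # trainable, zero init
x = rng.standard_normal(d_in)

y = lora_forward(x, W0, A, B, alpha, r)
print(np.allclose(y, W0 @ x))  # True: a zero B leaves the output unchanged
print(W0.size, A.size + B.size)  # 2048 384
```

Only A and B receive gradients, so the trainable count per matrix is r × (d_in + d_out) instead of d_in × d_out.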
LoRA Configuration
The LoRA configuration is set up using the following code:
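import torch
from transformers import AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, TaskType
# Base model in fp32; fp16 mixed precision is enabled later via TrainingArguments
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B", torch_dtype=torch.float32, trust_remote_code=True
)
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # adapt only the attention query/value projections
    # rank (r) and scaling (lora_alpha) are not set, so PEFT's defaults apply
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()  # prints trainable vs. total parameter counts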
The call to print_trainable_parameters() makes the split visible: only a tiny fraction of the weights is trained, which keeps memory usage low and speeds up training.
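The trainable count can be estimated before running anything: each adapted projection adds r × (d_in + d_out) parameters per layer. A minimal sketch of that arithmetic, which also converts the count to an approximate adapter file size; the layer count, projection shapes, and rank below are hypothetical placeholders, so substitute the values from the model's config.json and your own LoraConfig. The 4-bytes-per-parameter default assumes the adapter is stored in fp32.

```python
def lora_trainable_params(num_layers: int, proj_shapes: dict, r: int) -> int:
    """Sum r * (d_in + d_out) over every adapted projection, in every layer."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in proj_shapes.values())
    return num_layers * per_layer

def adapter_size_mb(num_params: int, bytes_per_param: int = 4) -> float:
    """Approximate on-disk size in MB (1 MB = 1e6 bytes)."""
    return num_params * bytes_per_param / 1e6

if __name__ == "__main__":
    # Hypothetical (d_in, d_out) shapes for the two targeted projections.
    shapes = {"q_proj": (1024, 1024), "v_proj": (1024, 512)}
    n = lora_trainable_params(num_layers=24, proj_shapes=shapes, r=8)
    print(n, f"{adapter_size_mb(n):.1f} MB")  # 688128 2.8 MB
```

The count grows linearly with r and with the number of targeted modules, which is why restricting LoRA to q_proj and v_proj keeps adapters in the megabyte range.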
Training Arguments
The training arguments are configured as follows:
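from transformers import TrainingArguments
args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch size 4 x 4 = 16
    learning_rate=2e-4,
    fp16=True,   # fp16 mixed precision (bf16 produced NaN losses)
    bf16=False,
    eval_strategy="epoch",  # evaluate and save once per epoch so the
    save_strategy="epoch",  # best checkpoint can be restored at the end
    load_best_model_at_end=True,
    gradient_checkpointing=True,  # trade extra compute for lower activation memory
    optim="adamw_torch",
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    report_to="none",
)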
A few important points to note:
-
fp16=True, bf16=False: training uses fp16 mixed precision. Early experiments with bfloat16 produced NaN losses; switching to fp16 resolved them.
-
gradient_checkpointing=True: recomputes activations during the backward pass instead of storing them, trading extra compute for lower memory use. It is not strictly necessary with the MI300X's 192 GB of VRAM, but it keeps the recipe runnable on GPUs with less memory.
-
gradient_accumulation_steps=4: This allows for an effective batch size of 16 with a physical batch of 4.
-
lr_scheduler_type="cosine" with warmup_ratio=0.05: on short runs, a cosine learning-rate schedule with warmup tends to converge more smoothly than a constant learning rate.
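To make the bullets above concrete, here is the step count and learning-rate curve they imply, assuming a single GPU and all 2,000 samples in the training split. The schedule reproduces the linear-warmup-then-half-cosine-decay shape selected by the cosine scheduler with warmup; the step rounding is our approximation of how Trainer counts optimizer steps, not a copy of its code.

```python
import math

def training_steps(num_samples, per_device_batch, grad_accum, epochs, num_devices=1):
    """Optimizer steps: micro-batches per epoch // accumulation steps, times epochs."""
    micro_batches = math.ceil(num_samples / (per_device_batch * num_devices))
    return max(micro_batches // grad_accum, 1) * epochs

def cosine_with_warmup(step, total_steps, warmup_steps, base_lr):
    """Linear warmup from 0 to base_lr, then half-cosine decay down to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

total = training_steps(2000, per_device_batch=4, grad_accum=4, epochs=2)
warmup = math.ceil(0.05 * total)
print(total, warmup)  # 250 13
for s in (0, warmup, total // 2, total):
    print(s, f"{cosine_with_warmup(s, total, warmup, 2e-4):.2e}")
```

With only 250 optimizer steps, the 13 warmup steps keep the first updates small, which is where short runs are most prone to instability.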
The Complete Training Loop
The complete training loop is set up with the following code:
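from transformers import AutoTokenizer, DataCollatorForSeq2Seq, Trainer
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
# train_ds / val_ds: tokenized datasets with input_ids, attention_mask and labels
# Pads inputs with the pad token and labels with -100 (ignored by the loss)
collator = DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
)
trainer.train()
# On a PEFT model, save_pretrained writes only the LoRA adapter weights
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")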
After training, ./outputs holds the LoRA adapter weights, which take only a few MB compared with several GB for a full checkpoint.
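The collator in the training loop pads each batch before it reaches the model. The pure-Python sketch below mimics its rules for causal-LM data so they are easy to see: input_ids are padded with the pad token, labels with -100 (which the loss skips), and the shared length is rounded up to a multiple of 8. The token IDs and pad ID are made up, and right padding is assumed; the real DataCollatorForSeq2Seq delegates to the tokenizer and returns tensors.

```python
import math

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss

def pad_batch(features, pad_id, multiple_of=8):
    """Right-pad input_ids, attention_mask and labels to a shared length rounded up to multiple_of."""
    longest = max(len(f["input_ids"]) for f in features)
    target = math.ceil(longest / multiple_of) * multiple_of
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        gap = target - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * gap)
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * gap)
        batch["labels"].append(f["labels"] + [IGNORE_INDEX] * gap)
    return batch

feats = [
    {"input_ids": [5, 6, 7], "labels": [5, 6, 7]},
    {"input_ids": list(range(5, 14)), "labels": list(range(5, 14))},
]
out = pad_batch(feats, pad_id=0)  # hypothetical pad ID
print([len(row) for row in out["input_ids"]])  # [16, 16]
```

Rounding lengths to a multiple of 8 tends to give the GPU better-aligned matrix shapes at the cost of a few extra pad tokens.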
Loading from HuggingFace Hub
The fine-tuned adapter is publicly available and can be loaded directly without needing to clone the repository:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B", trust_remote_code=True)
base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B", torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True
)
# Attach the published LoRA adapter, then fold it into the base weights for plain inference
model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
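Once the merged model generates a completion, the answer letter and explanation still have to be extracted from the text. A minimal parser for completions that follow the prompt format described earlier ("{answer_letter}) {answer_text}" followed by "### Explanation:"); the regular expression and the sample completion are our own illustration, and real generations may need more defensive handling.

```python
import re

# Matches the first line shaped like "B) Vitamin C".
ANSWER_RE = re.compile(r"^\s*([A-D])\)\s*(.+)$", re.MULTILINE)

def parse_completion(text: str):
    """Return (answer_letter, answer_text, explanation), with None for any missing part."""
    head, sep, tail = text.partition("### Explanation:")
    match = ANSWER_RE.search(head)
    letter = match.group(1) if match else None
    answer = match.group(2).strip() if match else None
    explanation = tail.strip() if sep else None
    return letter, answer, explanation

sample = "B) Vitamin C\n### Explanation:\nScurvy results from vitamin C deficiency."
print(parse_completion(sample))
# ('B', 'Vitamin C', 'Scurvy results from vitamin C deficiency.')
```

Returning None instead of raising lets an evaluation loop count malformed generations as wrong answers rather than crashing.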
Challenges and Solutions
No write-up of a ROCm project is complete without its list of obstacles. These are the ones we ran into:
-
Mixed precision instability: We had to switch from bfloat16 to fp16 to resolve NaN loss issues.
-
GPU not detected: This was due to missing ROCm environment variables. We configured the variables ROCR_VISIBLE_DEVICES, HIP_VISIBLE_DEVICES, and HSA_OVERRIDE_GFX_VERSION.
-
bitsandbytes not supported: we had no working ROCm build of bitsandbytes, so we dropped quantization. Fortunately, the MI300X has enough VRAM to do without it.
-
Incorrect inference output: this was caused by a misconfigured tokenizer padding setup. We set pad_token equal to eos_token and corrected padding_side.
-
Trainer evaluation errors: these came from an incompatible Transformers version. We pinned Transformers to >=4.40.0 to resolve them.
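When a GPU is not detected, a quick first step is to print the three variables named above and, if PyTorch is installed, ask it whether it is a ROCm build and how many devices it sees. ROCm builds of PyTorch expose AMD GPUs through the torch.cuda API and set torch.version.hip. This is a diagnostic sketch only; the right values for the variables depend on your system.

```python
import os

ROCM_VARS = ("ROCR_VISIBLE_DEVICES", "HIP_VISIBLE_DEVICES", "HSA_OVERRIDE_GFX_VERSION")

def rocm_diagnostics() -> dict:
    """Report ROCm-related environment variables and, if available, PyTorch's view of the GPUs."""
    report = {name: os.environ.get(name) for name in ROCM_VARS}
    try:
        import torch
    except ImportError:
        report["torch"] = None  # PyTorch is not installed in this environment
        return report
    report["torch"] = torch.__version__
    report["hip_version"] = getattr(torch.version, "hip", None)  # None on non-ROCm builds
    report["gpu_available"] = torch.cuda.is_available()  # ROCm builds reuse the torch.cuda API
    report["device_count"] = torch.cuda.device_count()
    return report

if __name__ == "__main__":
    for key, value in rocm_diagnostics().items():
        print(f"{key}: {value}")
```

If hip_version is None, the installed PyTorch wheel is not a ROCm build, and no environment variable will make the GPU visible.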
The bitsandbytes issue deserves special mention. On NVIDIA GPUs with less memory, 4-bit quantization is often needed to fit a model in memory for fine-tuning. With 192 GB of HBM3 on the MI300X, it is simply unnecessary here, which allows cleaner training without quantization artifacts.
Training Time on MI300X
-
Training time: about five minutes.
-
Size of the dataset used: 2,000 samples.
-
Base accuracy of MedMCQA: 6.1.
No GPU? No problem. The live demo on Gradio runs on HuggingFace Spaces.
Next Steps
This project shows that the pipeline works end to end. The next steps are about scale and quality:
-
Train on the complete MedMCQA dataset (~180,000 questions) and add PubMedQA.
-
Add calibrated confidence estimates to the answers.
-
Integrate real-time medical literature retrieval.
-
Conduct appropriate accuracy benchmarking beyond the training split.
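For that benchmarking step, two things matter most: scoring only on questions the model never trained on, and computing accuracy over the parsed answers. A minimal sketch, assuming predictions are answer letters (None when parsing failed) and that normalized question text is a good-enough key for detecting train/eval overlap; near-duplicate questions would need fuzzier matching.

```python
def _normalize(question: str) -> str:
    """Lowercase and collapse whitespace so trivial formatting differences don't hide duplicates."""
    return " ".join(question.lower().split())

def find_overlap(train_questions, eval_questions):
    """Return evaluation questions that also appear in the training set."""
    seen = {_normalize(q) for q in train_questions}
    return [q for q in eval_questions if _normalize(q) in seen]

def accuracy(predictions, gold):
    """Fraction of exact letter matches; unparsed predictions (None) count as wrong."""
    if len(predictions) != len(gold):
        raise ValueError("predictions and gold labels differ in length")
    if not gold:
        return 0.0
    correct = sum(p is not None and p == g for p, g in zip(predictions, gold))
    return correct / len(gold)
```

An empty result from find_overlap is a cheap precondition to check before reporting any accuracy number.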
MedQA shows that building a capable, explainable medical AI on AMD hardware with the open-source ROCm stack is not only possible but straightforward. The HuggingFace ecosystem's ROCm compatibility is genuinely good, the MI300X's memory capacity removes an entire category of engineering problems, and LoRA turns fine-tuning a 1.7-billion-parameter model on 2,000 examples into a five-minute task.
If you are working on AMD ROCm and encounter obstacles, the solutions above should save you hours. And if you are building a medical AI, the emphasis on explanation rather than mere accuracy deserves serious consideration.