L'IA sur plusieurs GPU : ZeRO et FSDP
⚡ Résumé en français par Brief IA
• L'article explique le fonctionnement de l'optimiseur Zero Redundancy et son implémentation dans PyTorch. • ZeRO permet d'optimiser l'utilisation de la mémoire sur plusieurs GPU, améliorant ainsi l'efficacité des modèles d'IA. • Dans un contexte où les modèles deviennent de plus en plus gourmands en ressources, la gestion efficace des GPU est cruciale pour le développement d'IA à grande échelle. 💡 Pourquoi c'est important : L'optimisation des ressources GPU est essentielle pour réduire les coûts et accélérer le déploiement des modèles d'IA avancés.
📄 Article traduit en français
L'IA sur plusieurs GPU : ZeRO et FSDP
Introduction à ZeRO
Dans le précédent article, nous avons vu comment le Distributed Data Parallelism (DDP) accélère l'entraînement en répartissant les lots sur plusieurs GPU. Bien que DDP résolve le problème de débit, il introduit un nouveau défi : la redondance mémoire.
Dans le DDP classique, chaque GPU détient une copie complète des paramètres du modèle, des gradients et des états de l'optimiseur. Pour de grands modèles comme GPT-3 (175 milliards de paramètres), cette redondance devient un gaspillage considérable de VRAM précieuse.
Problème de mémoire dans DDP
Analysons ce qui consomme réellement de la mémoire pendant l'entraînement. Pour un modèle avec N paramètres :
- Paramètres du modèle : valeurs (les poids de votre réseau de neurones)
- Gradients : valeurs (un gradient par paramètre)
- États de l'optimiseur (Adam) : valeurs (premier moment et second moment pour chaque paramètre)
- Activations : sorties intermédiaires stockées pendant le passage avant pour être utilisées dans le passage arrière
Les trois premiers éléments augmentent avec la taille du modèle et sont redondants sur les GPU dans DDP. Les activations, quant à elles, dépendent de la taille du lot, de la longueur de la séquence et du nombre de neurones, et sont uniques à chaque GPU, car chaque GPU traite des données différentes. ZeRO ne touche pas à la mémoire des activations.
Calculons l'utilisation de la mémoire pour un modèle de 7 milliards de paramètres utilisant Adam et FP32 :
- Paramètres : 7 milliards * 4 octets = 28 Go
- Gradients : 7 milliards * 4 octets = 28 Go
- États de l'optimiseur : 7 milliards * 2 * 4 octets = 56 Go
Mémoire par GPU dans DDP : 112 Go
Les activations ajoutent une mémoire significative à cela, mais comme elles sont uniques à chaque GPU, ZeRO ne peut pas les partitionner. Des techniques comme le checkpointing des activations peuvent aider, en supprimant certaines activations et en les recomputant au besoin pendant le passage arrière. Mais cela dépasse le cadre de cet article.
Fonctionnement de ZeRO
Comprenons comment ZeRO fonctionne en l'implémentant depuis le début, en commençant par ZeRO-1 et en progressant vers ZeRO-3.
ZeRO-1 : Partitionnement des états de l'optimiseur
Dans ZeRO-1, seuls les états de l'optimiseur sont partitionnés. Chaque GPU :
- Détient toujours les paramètres complets du modèle et les gradients
- Stocke seulement 1/N des états de l'optimiseur (N = nombre de GPU)
- Met à jour uniquement les 1/N correspondants des paramètres
Voici la séquence d'actions effectuées pendant l'entraînement :
- Passage avant : chaque GPU traite son propre micro-lot
- Passage arrière : calcul des gradients
- All-reduce des gradients : chaque GPU reçoit tous les gradients
- Étape de l'optimiseur : chaque GPU met à jour sa partition de paramètres
- All-gather des paramètres : synchronisation du modèle mis à jour entre les GPU
Implémentation simplifiée
Voici une implémentation simplifiée :
import torch.distributed as dist
def __init__(self, model, optimizer_cls):
self.model = model
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
self.param_shards = list() # chaque rang ne détient que sa part des états de l'optimiseur
self.param_metadata = list() # métadonnées pour reconstruire les parts
for param in self.model.parameters():
original_shape = param.data.shape
flat = param.data.view(-1)
numel = flat.numel()
remainder = numel % self.world_size
pad_size = (self.world_size - remainder) % self.world_size
padded_numel = numel + pad_size
shard_size = padded_numel // self.world_size
shard_start = self.rank * shard_size
shard_end = shard_start + shard_size
self.param_metadata.append({
"original_shape": original_shape,
"padded_numel": padded_numel,
"shard_size": shard_size,
"shard_start": shard_start,
"shard_end": shard_end,
})
if pad_size > 0:
flat_padded = torch.cat([flat, flat.new_zeros(pad_size)])
else:
flat_padded = flat
shard = flat_padded[shard_start:shard_end].clone()
shard.requires_grad_(True)
self.param_shards.append(shard)
self.optimizer = optimizer_cls(self.param_shards)
def training_step(self, inputs, targets, loss_fn):
output = self.model(inputs) # passage avant
loss = loss_fn(output, targets) # calcul de la perte
loss.backward() # passage arrière
self._sync_gradients() # all-reduce des gradients entre les GPU
self.optimizer.step() # mise à jour de la part locale des paramètres
self._sync_params() # synchronisation des paramètres du modèle
# réinitialiser les gradients pour l'étape suivante
for param in self.model.parameters():
param.grad = None
def _sync_gradients(self):
for idx, param in enumerate(self.model.parameters()):
# Code pour synchroniser les gradients
Cette implémentation montre comment ZeRO-1 partitionne les états de l'optimiseur tout en maintenant les paramètres et les gradients complets sur chaque GPU.
Brief IA — Veille IA en français
Toutes les innovations mondiales en IA, traduites et résumées automatiquement. Recevoir les meilleures actus IA chaque jour.