Section courante

A propos

Section administrative du site

Le format étrange de safetensors

Brève explication

Lorsque vous souhaitez charger des dictionnaires d'état depuis le disque :

Explication détaillée

OLMo enregistre les points de contrôle non fragmentés avec torch.save(), écrivant un dictionnaire d'états dans un fichier en utilisant (essentiellement) pickle de la bibliothèque standard Python. Le problème est que pickle est lent, mono-processus léger et nous oblige à lire l'intégralité du fichier avant d'en lire une partie. Pour le modèle 65B, les points de contrôle pèsent environ 700 Go et leur chargement prend plusieurs minutes, ce qui constitue un problème majeur. De plus, sur une machine équipée de 8 GPU, nous ne pouvons pas charger le modèle 8 fois en parallèle, car nous manquerions de mémoire.

Le format safetensors de Huggingface résout ces problèmes. Safetensors nous permet d'entreposer les dictionnaires d'états de telle sorte que, lors de leur chargement, les données des tenseurs soient cartographiées en mémoire depuis le disque. Cela signifie que le tenseur ne sera pas réellement chargé tant que votre code n'y aura pas accédé. Plusieurs processus chargeant le même fichier safetensors sur la même machine (ce qui se produit exactement lorsque OLMo charge un modèle depuis un point de contrôle) ne liront les données qu'une seule fois.

Malheureusement, safetensors ne va pas assez loin. Safetensors peut accomplir sa tâche si vous disposez d'un dictionnaire d'état conforme au type Dict[str, Tensor], c'est-à-dire un dictionnaire Python cartographiant les chaînes aux tenseurs. Cela est valable pour les pondérations du modèle, mais pas pour l'état de l'optimiseur. Nous avons donc placé une couche au-dessus de safetensors cartographiant les types de données nécessaires à OLMo pour le modèle et l'état de l'optimiseur aux types de données nécessaires à safetensors pour accomplir sa tâche. Cette cartographie s'effectue dans safetensors_util.py.

Les fonctions clefs sont :

Il existe un script permettant de convertir un fichier au format standard de PyTorch au format safetensors : convert_pt_to_safetensors.py. Ce script charge lentement le fichier d'origine et stocke l'intégralité de son contenu en mémoire, ce qui consomme beaucoup de mémoire et de temps.

Lors du chargement du fichier, chaque fois qu'OLMo tente de charger un dictionnaire d'état à partir du fichier foo.pt, il vérifie d'abord si le fichier foo.safetensors existe. Si c'est le cas, il le charge à la place.



Dernière mise à jour : Vendredi, le 6 juin 2025