Voici la liste des modules proposé par la bibliothèque JAX pour Python :
| Nom | Description |
|---|---|
| jax.ad_checkpoint | Ce module fournit des outils pour le checkpointing des gradients, réduisant la mémoire utilisée lors de la rétropropagation. |
| jax.debug | Ce module offre des fonctions pour le debugging des programmes JAX, comme le suivi des valeurs ou la vérification des transformations. |
| jax.distributed | Ce module fournit des primitives pour l'entraînement distribué sur plusieurs machines et dispositifs. |
| jax.dlpack | Ce module fournit des fonctions d'interopérabilité avec d'autres bibliothèques utilisant DLPack, pour échanger des tenseurs entre cadres d'applications. |
| jax.dtypes | Ce module gère les types de données numériques, permettant un contrôle précis sur les float32, float64, int32,... |
| jax.example_libraries | Ce module regroupe des bibliothèques d'exemples et extensions illustrant les usages avancés de JAX. |
| jax.experimental | Ce module contient des fonctionnalités expérimentales, souvent instables, permettant d'accéder à des APIs avancées ou en développement. |
| jax.export | Ce module permet d'exporter des fonctions JAX compilées pour utilisation en dehors de Python. |
| jax.extend | Ce module fournit des outils pour étendre JAX avec de nouvelles primitives ou transformations personnalisées. |
| jax.ffi | Ce module offre des interfaces pour interagir avec du code bas niveau en C et des bibliothèques externes via Foreign Function Interface. |
| jax.flatten_util | Ce module permet de convertir des structures de données complexes en vecteurs plats et de les reconstruire, utile pour l'optimisation et l'algèbre linéaire. |
| jax.image | Ce module fournit des outils de traitement d'images, comme le redimensionnement, le filtrage et l'interpolation. |
| jax.lax | Ce module expose les opérations primitives de bas niveau de JAX, permettant des transformations fonctionnelles et un contrôle fin sur les calculs. |
| jax.nn | Ce module contient des fonctions pour les réseaux de neurones, comme les activations (relu, softmax) et la normalisation. |
| jax.numpy | Ce module fournit une API compatible NumPy pour le calcul sur tableaux, avec prise en charge de GPU/TPU et des transformations JAX comme jit et grad. |
| jax.ops | Ce module propose des opérations sur tableaux immuables, telles que index_update et index_add, pour manipuler les DeviceArray. |
| jax.profiler | Ce module fournit des outils pour profiling et suivi des performances, afin d'optimiser les programmes JAX. |
| jax.random | Ce module gère la génération de nombres aléatoires de manière purement fonctionnelle et reproductible pour les simulations et modèles probabilistes. |
| jax.ref | Ce module offre des références mutables dans un environnement normalement fonctionnel, pour gérer certains cas d'état local. |
| jax.scipy | Ce module fournit des fonctions scientifiques avancées, similaires à SciPy, mais optimisées pour l'exécution sur GPU/TPU. |
| jax.sharding | Ce module permet la distribution de tableaux sur plusieurs dispositifs pour un calcul parallèle et optimisé sur GPU/TPU. |
| jax.stages | Ce module est utilisé pour staging du code et la transformation progressive des fonctions avant compilation. |
| jax.test_util | Ce module contient des outils et utilitaires pour les tests, incluant la comparaison de tableaux et la validation des résultats. |
| jax.tree | Ce module fournit des fonctions pour parcourir et transformer des structures arborescentes de données (lists, tuples, dicts imbriqués). |
| jax.tree_util | Ce module complète jax.tree avec des fonctions pour mapper, flatten et unflatten des arbres de données. |
| jax.typing | Ce module contient des définitions de types pour la vérification statique et la documentation, améliorant la sécurité des fonctions JAX. |
Dernière mise à jour : Lundi, le 2 février 2026