Section courante

A propos

Section administrative du site

Voici la liste des modules proposé par la bibliothèque JAX pour Python :

Nom Description
jax.ad_checkpoint Ce module fournit des outils pour le checkpointing des gradients, réduisant la mémoire utilisée lors de la rétropropagation.
jax.debug Ce module offre des fonctions pour le debugging des programmes JAX, comme le suivi des valeurs ou la vérification des transformations.
jax.distributed Ce module fournit des primitives pour l'entraînement distribué sur plusieurs machines et dispositifs.
jax.dlpack Ce module fournit des fonctions d'interopérabilité avec d'autres bibliothèques utilisant DLPack, pour échanger des tenseurs entre cadres d'applications.
jax.dtypes Ce module gère les types de données numériques, permettant un contrôle précis sur les float32, float64, int32,...
jax.example_libraries Ce module regroupe des bibliothèques d'exemples et extensions illustrant les usages avancés de JAX.
jax.experimental Ce module contient des fonctionnalités expérimentales, souvent instables, permettant d'accéder à des APIs avancées ou en développement.
jax.export Ce module permet d'exporter des fonctions JAX compilées pour utilisation en dehors de Python.
jax.extend Ce module fournit des outils pour étendre JAX avec de nouvelles primitives ou transformations personnalisées.
jax.ffi Ce module offre des interfaces pour interagir avec du code bas niveau en C et des bibliothèques externes via Foreign Function Interface.
jax.flatten_util Ce module permet de convertir des structures de données complexes en vecteurs plats et de les reconstruire, utile pour l'optimisation et l'algèbre linéaire.
jax.image Ce module fournit des outils de traitement d'images, comme le redimensionnement, le filtrage et l'interpolation.
jax.lax Ce module expose les opérations primitives de bas niveau de JAX, permettant des transformations fonctionnelles et un contrôle fin sur les calculs.
jax.nn Ce module contient des fonctions pour les réseaux de neurones, comme les activations (relu, softmax) et la normalisation.
jax.numpy Ce module fournit une API compatible NumPy pour le calcul sur tableaux, avec prise en charge de GPU/TPU et des transformations JAX comme jit et grad.
jax.ops Ce module propose des opérations sur tableaux immuables, telles que index_update et index_add, pour manipuler les DeviceArray.
jax.profiler Ce module fournit des outils pour profiling et suivi des performances, afin d'optimiser les programmes JAX.
jax.random Ce module gère la génération de nombres aléatoires de manière purement fonctionnelle et reproductible pour les simulations et modèles probabilistes.
jax.ref Ce module offre des références mutables dans un environnement normalement fonctionnel, pour gérer certains cas d'état local.
jax.scipy Ce module fournit des fonctions scientifiques avancées, similaires à SciPy, mais optimisées pour l'exécution sur GPU/TPU.
jax.sharding Ce module permet la distribution de tableaux sur plusieurs dispositifs pour un calcul parallèle et optimisé sur GPU/TPU.
jax.stages Ce module est utilisé pour staging du code et la transformation progressive des fonctions avant compilation.
jax.test_util Ce module contient des outils et utilitaires pour les tests, incluant la comparaison de tableaux et la validation des résultats.
jax.tree Ce module fournit des fonctions pour parcourir et transformer des structures arborescentes de données (lists, tuples, dicts imbriqués).
jax.tree_util Ce module complète jax.tree avec des fonctions pour mapper, flatten et unflatten des arbres de données.
jax.typing Ce module contient des définitions de types pour la vérification statique et la documentation, améliorant la sécurité des fonctions JAX.


Dernière mise à jour : Lundi, le 2 février 2026