dLLM training implementation in pure JAX/Flax (without PyTorch) for Google TPUs (v4/v5e/v6e). #TPUSprint #TRC
Stars: 7
Rank: 2009547