
Flash-Cross-Entropy, using the same overarching ideas as FlashAttention (mapreduce/semigroup-folds)


Apsod/FlashCE


Flash Cross-entropy

A small, largely unoptimized implementation of a memory-efficient cross-entropy loss. The overarching idea is the same as in FlashAttention and related methods, and is in essence the idea behind mapreduce: monoidal (semigroup) folds. See the PDF for details.
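One way to write such a fold down for cross-entropy is sketched below. The notation is ours and not necessarily the PDF's. For logits $z_1, \dots, z_V$ and target class $y$, each block $A$ of logits is summarized by a pair $(m_A, s_A)$, and two blocks are merged with $\oplus$:

```math
\begin{aligned}
\ell &= \log \sum_{j=1}^{V} e^{z_j} - z_y, \\
(m_A, s_A) &= \Bigl(\max_{j \in A} z_j,\; \sum_{j \in A} e^{z_j - m_A}\Bigr), \\
(m_A, s_A) \oplus (m_B, s_B) &= \bigl(m,\; s_A e^{m_A - m} + s_B e^{m_B - m}\bigr), \qquad m = \max(m_A, m_B), \\
\log \sum_{j \in A \cup B} e^{z_j} &= m + \log\bigl(s_A e^{m_A - m} + s_B e^{m_B - m}\bigr).
\end{aligned}
```

Each pair encodes the exact block sum $s\,e^{m}$, so $\oplus$ is addition in disguise. It is therefore associative and commutative, with identity $(-\infty, 0)$. Blocks can be reduced in any grouping and order, and only one block of logits needs to exist at a time.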

A Jupyter notebook with tests, validation, and usage examples is also included.
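The following is a plain-Python sketch in the spirit of that validation. It uses only the standard library and is not the repository's PyTorch code. It folds chunks of logits with an associative (max, sum-of-exp) state and checks the result against the naive formula. The names `combine`, `chunked_cross_entropy`, and `naive_cross_entropy` are hypothetical. The sketch also assumes that memory is saved by computing the logits `W h` one chunk of rows at a time, which is our own illustrative assumption.

```python
import math
from functools import reduce
from typing import Sequence, Tuple

# Fold state for a block of logits: (running max m, sum of exp(z - m)).
# It represents sum(exp(z)) = s * exp(m) without overflowing.
State = Tuple[float, float]
IDENTITY: State = (-math.inf, 0.0)  # state of an empty block


def lift(z: float) -> State:
    """State of a single logit."""
    return (z, 1.0)


def combine(a: State, b: State) -> State:
    """Associative merge: rescale both partial sums to the larger max."""
    (m_a, s_a), (m_b, s_b) = a, b
    m = max(m_a, m_b)
    if m == -math.inf:  # both blocks empty; avoids (-inf) - (-inf) = nan
        return IDENTITY
    return (m, s_a * math.exp(m_a - m) + s_b * math.exp(m_b - m))


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(u, v))


def chunked_cross_entropy(h: Sequence[float], W: Sequence[Sequence[float]],
                          target: int, chunk_size: int) -> float:
    """-log softmax(W h)[target], holding at most `chunk_size` logits at a time."""
    state = IDENTITY
    target_logit = math.nan
    for start in range(0, len(W), chunk_size):
        # Hypothetical setup: logits for this chunk of rows are computed on demand.
        chunk = [dot(row, h) for row in W[start:start + chunk_size]]
        if start <= target < start + len(chunk):
            target_logit = chunk[target - start]
        # Reduce the chunk locally, then fold it into the running state.
        state = combine(state, reduce(combine, map(lift, chunk), IDENTITY))
    m, s = state
    return m + math.log(s) - target_logit


def naive_cross_entropy(h: Sequence[float], W: Sequence[Sequence[float]],
                        target: int) -> float:
    """Reference: materialize every logit, then apply a stable log-sum-exp."""
    logits = [dot(row, h) for row in W]
    m = max(logits)
    return m + math.log(sum(math.exp(z - m) for z in logits)) - logits[target]


if __name__ == "__main__":
    h = [0.5, -1.0]
    W = [[1.0, 2.0], [0.0, 1.0], [3.0, -1.0], [-2.0, 0.5], [1.0, 1.0]]
    # Every chunk size should give the same loss as the naive computation.
    for k in range(1, len(W) + 1):
        assert math.isclose(chunked_cross_entropy(h, W, 2, k),
                            naive_cross_entropy(h, W, 2))
    print("chunked and naive cross-entropy agree")
```

Because `combine` is associative, the chunk size changes only how much memory is in use at once, not the result (up to floating-point rounding). The same property is what allows chunks to be reduced in parallel and merged afterwards.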

The implementation is written directly in PyTorch and has been tested on CPUs; a GPU kernel implementation is left as an exercise for the reader ;)
