На чем программируют суровый ML в Гугле

В 2015 году Гугл выпустил TensorFlow - супер-современный фреймворк для машинного обучения, созданный с участием самого Jeff Dean.

Но если почитать статьи за последние несколько лет из Google/DeepMind, то можно заметить что эксперименты реализованы с использованием совсем другого малоизвестного фреймворка JAX и лишь в самом конце портированы на TensorFlow/PyTorch. Именно с помощью JAX тренируют гигантские нейронные сети: текстовые, computer vision, мультимодальные.

История JAX

Несколько инженеров сложили вместе библиотеку автоматической дифференциации autograd, компилятор XLA (Accelerated Linear Algebra for GPU/TPU), написали элегантный just-in-time питоновских функций, и получили функциональный NumPy, который легко параллелится на GPU/TPU и кластеры.

Если TensorFlow изначально был спроектирован для работы на одном сервере, а распредленные вычисления были добавлены позже, то параллелизация в JAX - это его центральная фишка. Она достигается в частности использованием функциональной парадигмы: хорошо компилируются только pure functions без побочних эффектов, пользователь сам отвечает за управление state, даже training loop принято писать самому. JAX очень удачно совпал с потребностями инженеров и исследователей. Они быстро написали кучу библиотек для нейронок, оптимизаторов, reinforcement learning. Написали даже конвертер JAX программ в TensorFlow graphs, чтобы использовать TensorFlow Serving (aka Servomatic).

JAX невероятно популярен внутри Гугла, но малоизвестен за его пределами. Исследователям это на руку - никто не принуждает делать фреймфорк доступным для всех, да и вице-президенты не терзают команду туманными целями и прочими синергиями.

У JAX есть отличная документация на Readthedocs. Я перепечатывал примеры оттуда в Google Colab, изменял их, пробовал их запускать на бесплатных Colab kernels with CPU/GPU/TPU.

Основные строительные блоки

NumPy interface

Некоторые курсы по машинному обучению показывали как можно реализовь тренировку нейронных сетей умножением векторов/матриц NumPy, как вычислять производные цепочеатк функций. JAX - это в первую очередь невероятно ускоренный NumPy (see JAX As Accelerated NumPy). Все операции jax.numpy оптимизированы для выполнения на GPU/TPU. К этому добавлены возможности автоматической векторизации и параллелизации вычислений (как в курсе ml-class.org можно было векторизовать вычисления в Octave, ускоряя их в десятки-сотни раз).

Just-in time compilation

Функции без побочных эффектов можно легко скомпилировать, обернув их в функцию jax.jit. Компиляция осуществляется методом трассировки - в качестве параметров передаются специальные объекты, которые запоминают все операции, которые с ними производятся. По результатам трассировки строится граф вычислений “входные параметры” - ??? - “выходные параметры”. Потом этот граф компилируется с использованием XLA (её когда-то написали для TensorFlow).

Автоматическая дифферециация

Производные больше считать не нужно. Оборачиваешь loss function в функцию grad и получаешь градиенты. Вообще очень многое в JAX решается композицией функций. Опыт функционального программирования (Haskell, Erlang, ваши варианты) будет очень к стати.

Flax - библиотека для нейронок

Flax - самая популярная библетека для моделирования нейронных сетей. Отличная документация, есть много примеров, в том числе реальных исследовательских проектов из Гугла. Еще со всем недавно с ней конкурировала библиотека Haiku, но в конце концов Flax стал более популярен и Haiku перевели в режим поддержки.

У Flax офигенная философия. Чего только стоит “Prefer duplicating code over a bad abstraction.” Не всем такая философия подходит, но мне очень резонирует.

Заключение

А TenforFlow? Он все равно оказался полезен - многие его полезные части теперь успешно используются в других проектах. Но как ML Framework он теряет популярность.

comments powered by Disqus