В этом руководстве мы рассмотрим, как решать дифференциальные уравнения и создавать модели нейронных дифференциальных уравнений с помощью библиотеки Diffrax.
Настройка среды и установка необходимых библиотек
Мы начнём с настройки чистой вычислительной среды и установки необходимых библиотек для научных вычислений, таких как JAX, Diffrax, Equinox и Optax.
«`python
import os, sys, subprocess, importlib, pathlib
SENTINEL = «/tmp/diffraxcolabready_v3″
def _run(cmd):
subprocess.check_call(cmd)
def needinstall():
try:
import numpy
import jax
import diffrax
import equinox
import optax
import matplotlib
return False
except Exception:
return True
if not os.path.exists(SENTINEL) or needinstall():
_run([sys.executable, «-m», «pip», «uninstall», «-y», «numpy», «jax», «jaxlib», «diffrax», «equinox», «optax»])
_run([sys.executable, «-m», «pip», «install», «-q», «—upgrade», «pip»])
_run([
sys.executable, «-m», «pip», «install», «-q»,
«numpy==1.26.4»,
«jax[cpu]==0.4.38»,
«jaxlib==0.4.38»,
«diffrax»,
«equinox»,
«optax»,
«matplotlib»
])
pathlib.Path(SENTINEL).write_text(«ready»)
print(«Packages installed cleanly. Runtime will restart now. After reconnect, run this same cell again.»)
os._exit(0)
«`
Решение обыкновенных дифференциальных уравнений
Мы продемонстрируем, как решать обыкновенные дифференциальные уравнения с помощью адаптивных решателей и выполнять плотную интерполяцию для запроса решений в произвольные моменты времени.
«`python
import time
import math
import numpy as np
import jax
import jax.numpy as jnp
import jax.random as jr
import diffrax
import equinox as eqx
import optax
import matplotlib.pyplot as plt
def logistic(t, y, args):
r, k = args
return r y (1 — y / k)
t0, t1 = 0.0, 10.0
ts = jnp.linspace(t0, t1, 300)
y0 = jnp.array(0.4)
args = (2.0, 5.0)
sol_logistic = diffrax.diffeqsolve(
diffrax.ODETerm(logistic),
diffrax.Tsit5(),
t0=t0,
t1=t1,
dt0=0.05,
y0=y0,
args=args,
saveat=diffrax.SaveAt(ts=ts, dense=True),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
query_ts = jnp.array([0.7, 2.35, 4.8, 9.2])
queryys = jax.vmap(sollogistic.evaluate)(query_ts)
print(«\n=== Пример 1: Логистический рост ===»)
print(«Сохранённое решение shape:», sol_logistic.ys.shape)
print(«Интерполированные значения:»)
for t, y in zip(queryts, queryys):
print(f»t={float(t):.3f} -> y={float(y):.6f}»)
«`
Моделирование системы Лотки-Вольтерра
Мы моделируем систему Лотки-Вольтерра для изучения динамики взаимодействующих популяций во времени.
«`python
def lotka_volterra(t, y, args):
alpha, beta, delta, gamma = args
prey, predator = y
dprey = alpha prey — beta prey * predator
dpred = delta prey predator — gamma * predator
return jnp.array([dprey, dpred])
lv_y0 = jnp.array([10.0, 2.0])
lv_args = (1.5, 1.0, 0.75, 1.0)
lv_ts = jnp.linspace(0.0, 15.0, 500)
sol_lv = diffrax.diffeqsolve(
diffrax.ODETerm(lotka_volterra),
diffrax.Dopri5(),
t0=0.0,
t1=15.0,
dt0=0.02,
y0=lv_y0,
args=lv_args,
saveat=diffrax.SaveAt(ts=lv_ts),
stepsize_controller=diffrax.PIDController(rtol=1e-6, atol=1e-8),
max_steps=100000,
)
print(«\n=== Пример 2: Лотка-Вольтерра ===»)
print(«Shape:», sol_lv.ys.shape)
«`
Визуализация результатов
Мы визуализируем результаты моделирования и процесса обучения, чтобы понять поведение моделируемых систем.
«`python
plt.figure(figsize=(8, 4))
plt.plot(ts, sol_logistic.ys, label=»solution»)
plt.scatter(np.array(queryts), np.array(queryys), s=30, label=»dense interpolation»)
plt.title(«Adaptive ODE + Dense Interpolation»)
plt.xlabel(«t»)
plt.ylabel(«y»)
plt.legend()
plt.tight_layout()
plt.show()
«`
Заключение
Мы реализовали полный рабочий процесс для научных вычислений и машинного обучения с использованием Diffrax и экосистемы JAX. Мы решили детерминированные и стохастические дифференциальные уравнения, выполнили пакетные симуляции и обучили модель нейронных дифференциальных уравнений, которая изучает основные закономерности системы на основе данных. На протяжении всего процесса мы использовали JAX для компиляции и автоматической дифференциации, чтобы достичь эффективных вычислений и масштабируемого экспериментирования.
1. Какие библиотеки и инструменты используются для решения дифференциальных уравнений в статье?
В статье используются следующие библиотеки и инструменты: JAX, Diffrax, Equinox, Optax, NumPy, Matplotlib.
2. Какие типы дифференциальных уравнений рассматриваются в статье?
В статье рассматривается решение обыкновенных дифференциальных уравнений (ОДУ), в частности, демонстрируется решение логистического уравнения и системы Лотки-Вольтерра.
3. Какие методы используются для решения дифференциальных уравнений в статье?
В статье используются адаптивные решатели для решения ОДУ. Например, в примере с логистическим уравнением используется решатель Tsit5, а в примере с системой Лотки-Вольтерра — Dopri5.
4. Какие параметры используются при решении дифференциальных уравнений в статье?
При решении дифференциальных уравнений в статье используются следующие параметры: начальные условия (y0), аргументы функции (args), временные интервалы (t0, t1), шаг интегрирования (dt0), параметры контроллеров шага (rtol, atol), максимальное количество шагов (max_steps).
5. Какие преимущества даёт использование JAX для научных вычислений и машинного обучения?
Использование JAX для научных вычислений и машинного обучения позволяет достичь эффективных вычислений и масштабируемого экспериментирования благодаря компиляции и автоматической дифференциации.