Google JAX 끄적이기
JAX
요즘 Deep Learning 오픈소스를 보면 JAX를 사용한 많은 프로젝트가 보인다. 실제 DeepMind도 자신들의 프로젝트의 대부분을 “JAX”와 결합을 시켰다고 하는데.. 왜이렇게 인기가 많을까?
JAX는 Google이 만든 Python과 Numpy만을 결합한 머신러닝 라이브러리다. 간단하게 요약하면
JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.
- Python과 Numpy로 개발되었다. (깃허브를 보면 90% 이상의 코드가 파이썬이다.)
- Numpy를 GPU에서 연산시킬 수 있게 하여 기존 Numpy 성능을 뛰어넘는다.
- 자동 미분 계산
- jit(just in time) 컴파일 기법과 XLA 컴파일러를 사용하여 컴파일을 할 수 있다. (XLA는 모든 기기에 적용될 예정이라고 합니다.)
라이브러리 API
grad
,jit
: instances of such transformations.vmap
: automatic vectorizationpmap
: single-program multiple-data (SPMD) parallel programming of multiple accelerators, with more to come.
확실히 Python과 Numpy만으로(+ GPU) 개발되었다는 점이 놀랍다. 이는 tensor array와 같은 것을 고려하지 않고 numpy array만을 고려해서 코드를 짤 수 있다는 것인데 이 부분이 가장 큰 장점이지 않을까 생각된다.
그리고 jit 컴파일 데코레이터 함수를 적용하면 부분 컴파일이 가능해 깔끔하게 작성된 코드에서는 큰 장점을 가진다고 한다.
컴파일의 개념이 Pytorch같은 라이브러리에도 적용될 가능성이 있지 않을까 생각된다.
튜토리얼
- 참조 : QuickStart
Import
1
2
3
4
5
6
7
8
9
import numpy as np
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import device_put
# Prevent GPU/TPU warning.
import jax; jax.config.update('jax_platform_name', 'cpu')
Multiplying Matrices
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
key = random.PRNGKey(0) # seed
x = random.normal(key, (10,)) # random
print(x)
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU
# 866 ms ± 22.8 ms per loop
# block_until_ready() : 비동기 실행을 위한 함수
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
# 839 ms ± 11.9 ms per loop
# 매번 GPU에 올리니 느릴 수 있다.
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
# 793 ms ± 10.4 ms per loop
# 미리 올려놓는 방법 device_put 조금 더 빠를 수 있다.
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit np.dot(x, x.T)
# 434 ms ± 8.64 ms per loop
# -> 이상하네.. 왜이렇게 빠르지 jnp가 더 빨라야 하는것 아닌가?
마지막 줄이 이상하니.. 역시 stackoverflow.. 이와 관련된 질문이 있습니다.
랜덤 변수 만들 때 x = random.normal(key, (10,))
이 부분에 dtype=jnp.float64
를 추가해야 정확하게 나오는 것 같습니다.
jit()
를 사용해서 속도 업!
1
2
3
4
5
6
7
8
9
10
def selu(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
# 5.66 ms ± 141 µs per loop
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
# 1.15 ms ± 4.06 µs per loop
grad()
를 사용해서 미분!
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
# [0.25 0.19661197 0.10499357]
def first_finite_differences(f, x):
eps = 1e-3
return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
# [0.24998187 0.1964569 0.10502338]
vmap()
을 사용해서 자동 벡터화!
- promote matrix-vector products into matrix-matrix products using
vmap()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))
def apply_matrix(v):
return jnp.dot(mat, v)
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
# 7.68 ms ± 201 µs per loop
# 단순한 방식
@jit
def batched_apply_matrix(v_batched):
return jnp.dot(v_batched, mat.T)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
# 68.8 µs ± 369 ns per loop
# dot으로 자동으로 처리
@jit
def vmap_batched_apply_matrix(v_batched):
return vmap(apply_matrix)(v_batched)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
# 105 µs ± 438 ns per loop
# 만약 dot 말고 복잡한 연산을 가지는 배치 연산을 자동으로 처리
This post is licensed under CC BY 4.0 by the author.