Colab에서 PyTorch 모델 TPU로 학습하기

Colab에서 PyTorch 모델 TPU로 학습하기

딥러닝 모델을 학습시키다 보면 항상 vram의 압박에 시달리게 된다. 특히 최근 막대한 크기의 모델들이 등장해 이런 압박은 더 심해지기도 한다.

한편, 일반 사용자용 그래픽 카드 중 최상위인 Nvidia 2080ti조차도 vram이 겨우 11GB밖에 되지 않아 거대한 모델을 Fine-tuning 하는 것조차 굉장히 작은 배치사이즈로 학습시켜야 한다.

Google Colab에서 제공하는 TPU는 tpu v3-8 모델로 총 128GB의 메모리를 가지고 있어, 상대적으로 큰 모델과 배치사이즈를 이용해 학습할 수 있다. (tpu v3 하나는 16GB의 HBM 메모리를 가지고 있고, tpu v3-8은 8개의 코어로 총 128GB의 메모리를 가진다.)

PyTorch에서는 Pytorch/XLA 프로젝트를 통해 PyTorch에서도 TPU를 통한 학습을 할 수 있도록 컴파일러를 제공하고 있고, colab에 해당 패키지를 설치하면 TPU를 곧바로 사용할 수 있다.

NOTE: 이번 글은 아래 공식 튜토리얼의 내용을 따라갑니다.

공식 Tutorial: PyTorch on Cloud TPUs: Single Core Training AlexNet on Fashion MNIST

(단 내용의 100%를 담는 대신, 기존 PyTorch와 동일한 부분은 제외함)

PyTorch/XLA 설치하기

1
2
3
VERSION = "20200220"
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version $VERSION

PyTorch/XLA 패키지는 Github에서 설치 스크립트를 받아 설치할 수 있다.

글 쓰는 날짜 2020/02/24에는 20200220이 최신버전이다.

colab에서 설치를 진행하면 torch-1.4.0torch-1.5.0a0 버전으로, torchvision-0.5.00.6.0a0 로 업데이트한다. 이와 함께 설치하는 torch-xla 패키지가 메인 패키지가 된다.

데이터셋 준비하기

데이터셋을 다루는 방법은 기존 PyTorch에서 사용하던 방법과 동일하다.

1
2
3
4
5
6
7
8
9
10
import os
import torch
import torchvision
import torchvision.datasets as datasets

raw_dataset = datasets.FashionMNIST(
os.path.join("/tmp/fashionmnist"),
train=True,
download=True
)

torchvision.datasets 를 통해 FashionMNIST를 받을 수 있다.

1
2
3
4
5
6
7
8
9
10
11
import torchvision.transforms as transforms

# See https://pytorch.org/docs/stable/torchvision/models.html for normalization
# Pre-trained TorchVision models expect RGB (3 x H x W) images
# H and W should be >= 224
# Loaded into [0, 1] and normalized as follows:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))
resize = transforms.Resize((224, 224))
my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])

torchvision.transforms를 통해 데이터셋을 이미지넷 기반으로 Normalize하고 & RGB & 244x244 사이즈로 리사이징 하는 처리를 해줄 수 있다.

1
2
3
4
5
6
7
8
9
10
11
train_dataset = datasets.FashionMNIST(
os.path.join("/tmp/fashionmnist"),
train=True,
download=True,
transform=my_transform)

test_dataset = datasets.FashionMNIST(
os.path.join("/tmp/fashionmnist"),
train=False,
download=True,
transform=my_transform)

Dataset을 load할 때 transform 인자로 전달해주면 위 처리가 모두 함께 진행된다.

1
2
train_sampler = torch.utils.data.RandomSampler(train_dataset)
test_sampler = torch.utils.data.RandomSampler(test_dataset)

이후 Sampler를 통해 순서를 적절히 섞어준다.

1
2
3
4
5
6
7
8
9
10
11
batch_size = 8

train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=batch_size,
sampler=train_sampler)

test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=batch_size,
sampler=test_sampler)

위와 같이 DataLoader를 통해 데이터셋을 batch_size로 잘라 Iterable한 객체로 바꿔준다.

보다 자세한 Dataset, Sampler, DataLoader에 대한 정보는 아래 링크를 참고해보자.

[링크] pytorch dataset 정리 | Hulk의 개인 공부용 블로그

[중요!] torch-xla 패키지로 Device 지정하기

PyTorch에서는 .to(device) 문법을 통해 텐서 변수들과 모델들을 GPU와 같은 device에 올릴 수 있다.

TPU에 올리기 위해서는 torch_xla 에서 제공하는 xm.xla_device() 를 통해 PyTorch에 호환되는 device 를 지정할 수 있다.

1
2
3
4
5
6
7
8
9
import torch_xla
import torch_xla.core.xla_model as xm

# Creates AlexNet for 10 classes
net = torchvision.models.alexnet(num_classes=10)

# Acquires the default Cloud TPU core and moves the model to it
device = xm.xla_device()
net = net.to(device)

TPU로 PyTorch 딥러닝 모델 학습시키기

DataLoader와 Transforms를 이용해 데이터를 증폭시키고 적절한 수준의 이미지로 변환시키는 과정은 기존의 PyTorch 코드와 완전히 동일하다.

또한, loss function와 optimizer도 torch.nntorch.optim에서 사용하는 것 그대로 사용할 수 있다.

그리고 나머지 Training 과정의 코드도 거의 99% 동일하지만 .to(device) 의 device가 TPU라는 점만이 차이가 있다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# Note: this will take 4-5 minutes to run.
num_epochs = 1
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())

# Ensures network is in train mode
net.train()

start_time = time.time()
for epoch in range(num_epochs):
for data, targets in iter(train_loader):
# Sends data and targets to device
data = data.to(device)
targets = targets.to(device)

# Acquires the network's best guesses at each class
results = net(data)

# Computes loss
loss = loss_fn(results, targets)

# Updates model
optimizer.zero_grad()
loss.backward()
xm.optimizer_step(optimizer, barrier=True) # 이부분이 TPU 쓸때 필요한 코드!!

elapsed_time = time.time() - start_time
print ("Spent ", elapsed_time, " seconds training for ", num_epochs, " epoch(s) on a single core.")

위 코드 중 25번째 줄의 xm.optimizer_step(optimizer, barrier=True) 부분이 TPU를 사용하기 위한 코드이다. (GPU 코드에 겨우 한줄 추가!)

TPU 코어 FULL 활용하기

앞서 다룬 과정에서는 TPU 8core 중 1core만을 사용한다. 한편, TPU 1개에 들어있는 8개의 코어 전체를 사용하면 보다 빠른 학습이 가능하다.

아래 코드는 PyTorch/XLA의 공식 튜토리얼을 참고합니다.

공식 Tutorial: PyTorch on Cloud TPUs: MultiCore Training AlexNet on Fashion MNIST

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp

# "Map function": acquires a corresponding Cloud TPU core, creates a tensor on it,
# and prints its core
def simple_map_fn(index, flags):
# Sets a common random seed - both for initialization and ensuring graph is the same
torch.manual_seed(1234)

# Acquires the (unique) Cloud TPU core corresponding to this process's index
device = xm.xla_device()

# Creates a tensor on this process's device
t = torch.randn((2, 2), device=device)

print("Process", index ,"is using", xm.xla_real_devices([str(device)])[0])

# Spawns eight of the map functions, one for each of the eight cores on
# the Cloud TPU
flags = {}
# Note: Colab only supports start_method='fork'
xmp.spawn(simple_map_fn, args=(flags,), nprocs=8, start_method='fork')

TPU 여러 코어를 사용하기 위해서는 torch_xla.distributed.xla_multiprocessing 을 통해 프로세스를 N개 띄우는 방식으로 진행한다.

단, 앞서 진행한 코드는 model과 data 로드 부분이 모두 각자의 코드로 나와 Colab 인스턴스의 CPU/Ram에서 진행되어 코드 내의 변수를 .to(device) 를 사용해 GPU나 TPU로 보낼 수 있지만, TPU의 여러 코어를 사용할때는 하나의 함수 내에 model과 data 모두를 넣고 진행해야 한다.

아래와 같이 map_fn 을 만들어서 PyTorch 학습/평가를 위한 부분을 넣어준다.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
import torch_xla.distributed.parallel_loader as pl
import time

def map_fn(index, flags):
torch.manual_seed(flags['seed'])
device = xm.xla_device()
dataset_path = os.path.join("/tmp/fashionmnist", str(xm.get_ordinal()))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
to_rgb = transforms.Lambda(lambda image: image.convert('RGB'))
resize = transforms.Resize((224, 224))
my_transform = transforms.Compose([resize, to_rgb, transforms.ToTensor(), normalize])

train_dataset = datasets.FashionMNIST(
dataset_path,
train=True,
download=True,
transform=my_transform)

test_dataset = datasets.FashionMNIST(
dataset_path,
train=False,
download=True,
transform=my_transform)

train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True)

test_sampler = torch.utils.data.distributed.DistributedSampler(
test_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=False)

train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=flags['batch_size'],
sampler=train_sampler,
num_workers=flags['num_workers'],
drop_last=True)

test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=flags['batch_size'],
sampler=test_sampler,
shuffle=False,
num_workers=flags['num_workers'],
drop_last=True)

net = torchvision.models.alexnet(num_classes=10).to(device).train()

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters())

train_start = time.time()
for epoch in range(flags['num_epochs']):
para_train_loader = pl.ParallelLoader(train_loader, [device]).per_device_loader(device)
for batch_num, batch in enumerate(para_train_loader):
data, targets = batch

output = net(data)

loss = loss_fn(output, targets)

optimizer.zero_grad()
loss.backward()

xm.optimizer_step(optimizer) # ParallelLoader 쓸때는 barrier=True 필요 없음

elapsed_train_time = time.time() - train_start
print("Process", index, "finished training. Train time was:", elapsed_train_time)

## Evaluation
# Sets net to eval and no grad context
net.eval()
eval_start = time.time()
with torch.no_grad():
num_correct = 0
total_guesses = 0

para_train_loader = pl.ParallelLoader(test_loader, [device]).per_device_loader(device)
for batch_num, batch in enumerate(para_train_loader):
data, targets = batch

output = net(data)
best_guesses = torch.argmax(output, 1)

num_correct += torch.eq(targets, best_guesses).sum().item()
total_guesses += flags['batch_size']

elapsed_eval_time = time.time() - eval_start
print("Process", index, "finished evaluation. Evaluation time was:", elapsed_eval_time)
print("Process", index, "guessed", num_correct, "of", total_guesses, "correctly for", num_correct/total_guesses * 100, "% accuracy.")

이전 코드와 다른 부분은 74번째 줄에서 더이상 barrier=True 가 필요하지 않다는 것이다.

1
2
3
4
5
6
7
# Configures training (and evaluation) parameters
flags['batch_size'] = 32
flags['num_workers'] = 8
flags['num_epochs'] = 1
flags['seed'] = 1234

xmp.spawn(map_fn, args=(flags,), nprocs=8, start_method='fork')

위에서 만든 함수를 xmp.spawn 함수를 통해 배치사이즈, 워커(코어수=8개), epochs, seed값을 제공해주면 실제 TPU로 해당 코드가 컴파일 되어 전달된 뒤 학습이 진행된다.

References

Your browser is out-of-date!

Update your browser to view this website correctly. Update my browser now

×