이 글을 읽기에 앞서 다음 글을 읽어보는 것을 추천한다
https://roki9914.tistory.com/7
이번 글에서는 모듈에 hook과 apply를 사용하는 법을 알아보자
hook은 이미 만들어진(패키지화된) 코드에 다른 프로그래머가 자신의 custom code를 실행시킬 수 있게 해주는 기능이다.
hook을 추가하는 함수는 여러가지가 있지만 이번에는 reigster_hook(), register_forward_hook(), register_forward_pre_hook(), register_full_backward_hook()에 대해서만 알아보자.
Register
1. register_hook()
register_hook() 함수를 통해 tensor에 hook을 설치할 수 있다.
tensor = torch.rand(1, requires_grad=True)
def tensor_hook(grad):
pass
tensor.register_hook(tensor_hook)
tensor._backward_hooks
>>> OrderedDict([(0, <function __main__.tensor_hook(grad)>)])
tensor에 hook을 설치한 후, _backward_hooks를 통해 hook을 출력하면 OrderedDict이 나오는 것을 볼 수 있다.
2. register_forward_hook(), register_forward_pre_hook()
register_forward_hook() 함수는 모듈의 forward 함수가 결과를 출력할때 실행될 함수를 hook에 추가해준다.
반면, register_forward_pre_hook()은 결과를 출력하지 않더라도 forward 함수가 실행할 차례가 되면 실행되는 함수를 hook에 추가해준다
이전 게시물에 있던 코드를 다시 가져와보자
class Add(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
output = torch.add(x1, x2)
return output
여기에 다음의 두 hook을 정의하고 register_forward_hook()과 register_forward_pre_hook()를 이용해 각각 hook에 추가해주자
result=[] #hook의 작동 결과를 모아보기 위해
def pre_hook(module, input):
answer.extend([input[0], input[1]])
def hook(module, input, output):
answer.append(output)
add.register_forward_hook(hook)
add.register_forward_pre_hook(pre_hook)
그리고 다음과 같이 Add 모듈을 실행해주면
add=Add()
x1 = torch.rand(1)
x2 = torch.rand(1)
output = add(x1, x2)
print(result, "\n--------------------\n",[x1, x2, output])
>>>
[tensor([0.3345]), tensor([0.9280]), tensor([1.2625])]
--------------------
[tensor([0.3345]), tensor([0.9280]), tensor([1.2625])]
함수를 실행하는 중에 나오는 값들을 hook들이 result 리스트로 잘 전달한 것을 볼 수 있다.
3. register_full_backward_hook()
register_full_backward_hook()은 모델이 backpropagation을 할 때 실행되는 hook을 추가해준다.
먼저 다음과 같은 모듈을 생성해보자
class Model(nn.Module):
def __init__(self):
super().__init__()
self.W = Parameter(torch.Tensor([5]))
def forward(self, x1, x2):
output = x1 * x2
output = output * self.W
return output
model = Model()
그리고 register_full_backward_hook()에 넣을 함수를 만들고 hook에 추가하자
result = [] #register_full_backward_hook의 결과를 확인
def module_hook(module, grad_input, grad_output):
answer.extend([grad_input[0], grad_input[1], grad_output[0]])
model.register_full_backward_hook(module_hook)
그리고 다음과 같이 Model 모듈을 실행해주면
x1 = torch.rand(1, requires_grad=True)
x2 = torch.rand(1, requires_grad=True)
output = model(x1, x2)
output.retain_grad()
output.backward()
print(result, "\n----------------------------\n", [x1.grad, x2.grad, output.grad])
>>>
[tensor([3.8946]), tensor([0.2615]), tensor([1.])]
----------------------------
[tensor([3.8946]), tensor([0.2615]), tensor([1.])]
backpropagation 과정중 생성된 gradient들이 제대로 전달된 것을 볼 수 있다.
Apply
pandas에서 dataframe 전체에 영향을 주기 위해 apply를 쓰는 것처럼 모듈 전체에 영향을 주기 위해 apply를 쓸 수 있다.
먼저 저번 포스팅에서 예시로 들었던 Model - Layer - Function 구조를 가지는 모듈을 가져와봤다.
class Function_A(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x + self.W
class Function_B(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x - self.W
class Function_C(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x * self.W
class Function_D(nn.Module):
def __init__(self, name):
super().__init__()
self.name = name
self.W = Parameter(torch.rand(1))
def forward(self, x):
return x / self.W
# Layer
class Layer_AB(nn.Module):
def __init__(self):
super().__init__()
self.a = Function_A('plus')
self.b = Function_B('substract')
def forward(self, x):
x = self.a(x)
x = self.b(x)
return x
class Layer_CD(nn.Module):
def __init__(self):
super().__init__()
self.c = Function_C('multiply')
self.d = Function_D('divide')
def forward(self, x):
x = self.c(x)
x = self.d(x)
return x
# Model
class Model(nn.Module):
def __init__(self):
super().__init__()
self.ab = Layer_AB()
self.cd = Layer_CD()
def forward(self, x):
x = self.ab(x)
x = self.cd(x)
return x
model = Model()
이제 모듈 전체에 적용할 함수를 정의하고 apply를 이용해 적용시켜보자
def print_module(module):
print(module)
print("-" * 30)
returned_module = model.apply(print_module)
그러면 다음의 결과값을 출력한다.
Function_A()
------------------------------
Function_B()
------------------------------
Layer_AB(
(a): Function_A()
(b): Function_B()
)
------------------------------
Function_C()
------------------------------
Function_D()
------------------------------
Layer_CD(
(c): Function_C()
(d): Function_D()
)
------------------------------
Model(
(ab): Layer_AB(
(a): Function_A()
(b): Function_B()
)
(cd): Layer_CD(
(c): Function_C()
(d): Function_D()
)
)
------------------------------
이때, 위의 모델의 구조가 출력되는 순서를 보면 Postorder Traversal방식으로 모듈에 apply 가 전파되는 것을 볼 수 있다.
postorder traversal이란? - https://en.wikipedia.org/wiki/Tree_traversal#Post-order,_LRN
Tree traversal - Wikipedia
From Wikipedia, the free encyclopedia Class of algorithms "Tree search" redirects here. Not to be confused with Search tree. In computer science, tree traversal (also known as tree search and walking the tree) is a form of graph traversal and refers to the
en.wikipedia.org
마지막으로 repr()랑 apply()를 같이 쓴 예시를 보도록 하자
위의 Model 모듈에서 각 Function의 이름인 plus, subtract, multiply, divide를 출력하고 싶다면, 각 function에 extra_repr()를 정의해줘야 한다.
그럴땐 다음과 같이 코드를 짜면 된다.
def function_repr(self):
return f'name={self.name}'
def add_repr(module):
module_name = module.__class__.__name__
if "Function" in module_name:
module.extra_repr = partial(function_repr, module)
returned_module = model.apply(add_repr)
model_repr=repr(model)
이렇게 되면 각 function에 extra_repr를 추가할 수 있게 되고, model_repr를 출력하면 다음과 같은 결과를 얻는다
print(model_repr)
>>>
Model(
(ab): Layer_AB(
(a): Function_A(name=plus)
(b): Function_B(name=substract)
)
(cd): Layer_CD(
(c): Function_C(name=multiply)
(d): Function_D(name=divide)
)
)
'AI 기초이론' 카테고리의 다른 글
Pytorch Dataloader에 대해 알아보자 - 2 (1) | 2023.03.18 |
---|---|
Pytorch Dataloader에 대해 알아보자 - 1 (0) | 2023.03.18 |
Pytorch로 간단한 모듈 만들기 - 1 (0) | 2023.03.17 |
Pytorch의 기본적 기능 (0) | 2023.03.15 |
CNN, RNN이 무엇인지 간단하게 알아보자 (0) | 2023.03.10 |
댓글