지난 시간, 수식을 통한 오차역전파법에 대하여 이해해 보았습니다.
이번 시간에서는 계산 그래프를 통한 역전파에 대하여 알아보도록 하겠습니다!
오차역전파법을 위한 계산 그래프
일전에 수식으로 풀어본 오차역전파법은 수학을 오랫동안 놓았거나 수식으로만 생각하면 본질을 놓칠 우려가 있습니다. 이번에 우리가 해볼 내용은 계산 그래프를 이용해 오차역전파법을 이해하는 것인데요, 수식으로 오차역전파법을 이해하는 것보다는 약간은 부정확할 수 있으나 최종적으로는 수식으로 알아본 오차역전파법을 이해할 수 있고, 실제 코드 구현까지 해보도록 하겠습니다. 계산 그래프로 설명한다는 아이디어는 안드레 카패스의 블로그 또 그와 페이페이 리 교수가 진행한 스탠퍼드 대학교 딥러닝 수업 CS321n을 참고했습니다.
계산 그래프
계산 그래프(computational graph)는 계산 과정을 그래프로 그려낸 것입니다. 그래프는 우리가 잘 아는 그래프 자료 구조 형태로 되어 있으며, 처음에 쉽게 접근하기 위해 계산 그래프를 통한 간단한 문제를 풀어보도록 하겠습니다. 먼저 익숙해지자!라는 이야기입니다. 예를 들어 다음과 같은 예시가 있다고 하죠, "A라는 사람이 1개 100원인 사과를 2개 샀습니다. 이때 지불 금액을 구하세요, 단 소비세 10%가 부과됩니다."라는 예시를 계산그래프로 표현하면 다음과 같아집니다.
처음에 사과의 100원이 'x 2' 노드로 흘러 200원이 된 다음 소비세 계산을 위해 'x 1.1' 노드를 거쳐 최종적으로는 220원이 됩니다. 위 그래프에 따르면 최종 답은 220원이 된다는 사실을 알 수 있네요 위의 그림에서는 계산 노드를 각각 'x 2', 'x 1.1'로 표현했지만 '2'와 '1.1'을 각각 사과의 개수와 소비세에 대한 변수가 되기 때문에 따로 빼서 다음과 같이 표기할 수 있습니다.
그럼 다음 문제를 풀어 보도록 하겠습니다.
"A가 사과를 2개, 귤을 3개 샀습니다. 사과는 1개에 100원, 귤은 1개 150원입니다. 소비세가 10% 부과될 때 A가 지불해야 할 금액은?" 위 문제도 계산그래프로 풀어볼 수 있습니다. 이때의 계산 그래프는 다음과 같겠네요!
위 문제에서는 새로운 노드인 덧셈 노드가 추가되었습니다. 덧셈 노드가 추가되어 사과의 가격과 귤의 가격을 합치는 모습이 보이고 있습니다. 왼쪽에서 오른쪽으로 순차적으로 계산을 끝내고 제일 마지막에 1.1을 곱하면 우리가 원하는 값인 715원이 나오고 끝나게 됩니다. 계산 그래프를 이용한 문제풀이는 다음과 같이 해석할 수 있습니다.
- 계산 그래프를 구성
- 그래프에서 계산을 왼쪽에서 오른쪽으로 진행
이처럼 '계산을 왼쪽에서 오른쪽으로 진행'하는 단계를 순전파(forward propagation)라고 합니다. 순전파는 계산 그래프의 출발점부터 종착점으로의 전파단계를 그려줍니다. 역전파(backword propagation)는 무엇일까요? 바로 '오른쪽에서 왼쪽으로 전파되는 단계를 의미합니다!
국소적 계산
계산 그래프의 특징은 '국소적 계산'을 전파함으로써 최종 결과를 얻는다는 점에 있습니다. 여기서 '국소적'이란, "자신과 직접 관계된 작은 범위"를 의미하는데, 뭔가 떠오르지 않으시나요? 수학으로 따지면 바로 편미분을 의미한다는 것입니다. 즉, 국소적 계산은 전체에서 어떤 일이 벌어지든 상관없이 자신과 관계된 정보만을 토대로 결과를 낼 수 있다는 이야기입니다. 구체적인 예를 들어 보겠습니다. 여러분이 마트에서 사과 2개를 포함한 여러 가지의 물품들을 구매하는 상황을 구해보겠습니다. 그렇다면 사과에 대한 국소적 계산을 진행한다고 이해할 수 있는데요, 그래프로 확인해 보겠습니다.
위 그림에서 여러 식품을 구매하여( 복잡한 계산을 하여) 4,000원이라는 금액이 나왔고, 여기에 사과 가격인 200원을 더해 총 4,200원이 나왔습니다. 이는 '사과에 대한 국소적 계산'이기 때문에, 4,000원이 어떻게 나왔는지는 전혀 신경 쓸게 없다는 이야기가 됩니다. 그냥 단순히 복잡한 계산의 결과물인 4,000원과 사과의 가격인 200원을 더해 4,200을 알아내면 된다는 것이죠. 중요한 점은 계산 그래프는 이처럼 국소적 계산에 집중한다는 것입니다. 전체 계산 자체가 아무리 복잡해도 각 단계에서 하는 일은 해당 노드의 '국소적 계산'일뿐입니다. 국소적 계산은 단순하지만 그 결과를 전달함으로써 전체를 구성하는 복잡한 계산을 해낼 수 있습니다. 마치 자동차 조립을 하는 것과 비슷한데요, 각각의 부품을 복잡하게 만들어 내고, 최종적으로 합쳐 차를 완성하는 단계라고 볼 수 있습니다.
계산 그래프를 사용하는 이유
계산 그래프의 이점은 무엇일까요? 바로 국소적 계산입니다. 전체가 아무리 복잡해도 각 노드에서는 단순한 계산에 집중하여 문제를 단순화시킬 수 있기 때문이지요, 또한 계산 그래프는 중간 계산 결과를 모두 보관할 수 있습니다. 에지에 저장되어 있는 숫자들이 그것을 의미하고 있지요, 하지만 이것 때문에 계산 그래프를 사용하진 않습니다! 계산 그래프를 사용하는 가장 큰 이유는 역전파를 통해 '미분'을 효율적으로 계산할 수 있기 때문입니다.
계산 그래프의 역전파 첫 번째 문제에 대한 계산 그래프는 사과 2개를 사서 소비세를 포함한 최종 금액을 구하는 것이었습니다. 여기서 새로운 문제를 제시해 보겠습니다. "사과 가격이 오르면 최종 금액에 어떠한 영향을 미칠 것인가?"가 문제입니다. 즉 이는 사과 가격에 대한 지불 금액의 미분을 구하는 문제에 해당됩니다. 사과 값을 x로, 지불 금액을 L이라 했을 때
로 표현이 가능하다는 것이죠, 즉 이 미분값은 사과 값이 '아주 조금' 올랐을 때 지불 금액이 얼마나 증가하느냐를 표시한 것입니다. 즉, '사과 가격에 대한 지불 금액의 미분' 같은 값은 계산 그래프에서 역전파를 하면 구할 수 있게 됩니다. 다음 그림에서는 계산 그래프 상의 역전파에 의해 미분을 구할 수가 있습니다. 아직 역전파가 어떻게 이뤄지는지에 대해서는 이야기하지 않았습니다!
위 그림에서 굵은 화살표로 역전파를 표현해 보았습니다. 이 전파는 각각 노드에 대한 국소적 미분을 전달합니다. 즉, 들어오고 있는 사과의 개수나 소비세에 대한 국소적으로 미분을 진행하였기 때문에, 소비세와 사과의 개수 같은 변수에 대한 미분만 진행했다는 이야기입니다. 그리고 그 미분값은 화살표 방향으로 적어내고 있습니다. 이 예에서 역전 파는 오른쪽에서 왼쪽으로 '1 -> 1.1 -> 2.2' 순으로 미분값을 전달하고 있습니다. 이 결과로부터 알 수 있는 사실은 '사과 가격에 대한 지불금액이 미분'값은 2.2라는 것을 알 수 있게 됩니다. 즉, 사과 가격이 1원 오르면 최종 가격은 2.2원 오른다는 것이죠. 여기에서는 사과 가격에 대한 미분만 구했지만, '소비세에 대한 지불 금액의 미분'이나 '사과 개수에 대한 지불 금액의 미분'도 같은 순서로 구해낼 수가 있습니다. 그리고 그때는 중간까지 구한 미분 결과를 공유할 수 있어서 다수의 미분을 효율적으로 계산할 수 있습니다. 이처럼 계산 그래프의 이점은 순전파와 역전파를 활용해서 각 변수의 미분을 효율적으로 구할 수 있다는 것입니다.
연쇄법칙과 계산 그래프
연쇄법칙 계산을 계산 그래프로 나타낼 수 있습니다. 2 제곱 계산을 '**2' 노드로 나타내면 다음과 같습니다.
오른쪽에서 왼쪽으로 신호가 전파되는 모습을 볼 수 있습니다. 역전파에서의 계산 절차는 노드로 들어온 입력 신호에 그 노드의 국소적 미분인 편미분을 곱한 후 다음 노드로 전달합니다. 예를 들어 **2 노드에서의 역전파를 보면 입력은 ∂𝑧∂𝑧이며, 이에 대한 국소적 미분인 ∂𝑧∂𝑡를 곱해 다음 노드로 넘깁니다. 맨 왼쪽의 역전파를 보면 x에 대한 z의 미분이 연쇄법칙에 따라서
가 된다는 사실을 알아낼 수 있고, 이를 계산하면
가 된다는 사실을 알아낼 수 있습니다.
지금까지 아주아주 긴 오차역전파법을 위한 계산 그래프를 위한 이해를 수식으로 알아보았습니다! 다음 세션을 통해 최종적으로 코드 구현을 해보겠습니다.
'Programming > Deep Learning' 카테고리의 다른 글
[Python/DeepLearning] #10.5. 역전파) 활성화 함수 계층 구현 (0) | 2024.03.01 |
---|---|
[Python/DeepLearning] #10.4. 역전파) 덧셈 노드와 곱셈 노드 (0) | 2024.02.22 |
[Python/DeepLearning] #10.2. 역전파) 수식을 통한 오차역전파법 이해 (0) | 2024.02.08 |
[Python/DeepLearning] #10.1. 역전파) 합성함수의 미분과 연쇄법칙 (1) | 2024.02.06 |
[Python/DeepLearning] #9.3. MNIST 신경망 구현하기 (2) | 2024.02.02 |