Deep Learning/Pytorch

Pytorch 코드 컨셉

판교데싸 2022. 11. 23. 00:19

파이 토치 기본 스타일

- 모듈 클래스로 신경망 생성

- 데이터셋 클래스로 데이터를 불러와 학습 

 

* 모듈 클래스로 신경망 만들기

__init__() -> 신경망 구성요소 정의
forward() -> 신경망 동작 정의

1. 파이 토치가 제공하는 모듈을 불러와 __init__()함수 안에 정의

Class Net(nn.module):
	def __init__ (self):
    
    
    # 신경망 구성요소 정의

 


2. forward() 함수에 신경망의 동작을 정의(__init__()함수에서 정의한 모듈을 연결하거나 필요한 연산등을 정의)

def forward(self, input):
	
    #신경망의 동작 정의
    
    return output

 

* 데이터셋 클래스로 데이터를 불러와 학습하기

1. __init__함수는 학습에 사용할 데이터를 불러온다
2. __len__함수는 데이터 개수를 반환
3. __getitem__함수는 우리가 지정한 i번째 입력 데이터와 정답을 반환

Class Dataset():
	def __init__(self):
    
    
    # 필요한 데이터 불러오기
    
    
    def __len__(self):
    
    # 데이터의 개수 반환
    
    
    	return len(data)
        
        
    def __getitem__(self, i):
    	
        #i번째 입력 데이터와 i번째 정답을 반환
        
        
        return data[i], label[i]

 

*  모듈 클래스와 데이터셋 클래스를 이용한 딥러닝 학습의 뼈대

1. 파이토치는 학습에 사용할 입력데이터와 정답을 불러오는 데이터 로더를 제공

2. 데이터로더는 데이터셋 클래스를 입력으로 받아 학습에 필요한 양 만큼의 데이터를 불러오는 역할을 수행

3. 데이터로더로부터 데이터와 정답을 불러와 신경망의 예측값을 계산 (여기서의 신경망은 파이토치 모듈)

3-1  (1)예측값을 계산 했다면 (2)손실함수를 이용해 신경망의 오차를 계산하고 (3)파이토치의 backward() 메서드를 활용하여 오차를 역전파한 다음, (4)step()메서드를 활용해 신경망 가중치를 수정

for data, label in DataLoader():
	# 1 모델의 예측값 계산
    prediction = model(data)
    
    # 2. 손실 함수를 이용해 오차 계싼
    
    loss = LossFunction(prediction , label)
    
    # 3. 오차 역전파
    loss.backward()
    
    # 4. 신경망 가중치 수정
    optimizer.step()
반응형

'Deep Learning > Pytorch' 카테고리의 다른 글

Pytorch 의 nn.module 상속의미  (0) 2022.11.19
반응형