ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • pytorch 사용 예제
    AI 2021. 6. 17. 17:51
    #라이브러리 불러오기 (torch.nn : 딥러닝 네트워크 구현 및 학습을 간단하게 수행할 수 있도록 다양한 함수 제공)
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    
    #gpu가 있는 경우 pytorch 연산을 gpu로, 그렇지 않은 경우 pytorch 연산을 cpu로 수행
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    #딥러닝 모델
    class Model(nn.Module) :
    	def __init__(self) :	# 네트워크의 변수 정의
    		super(Model, self).__init__()
    		self.fc1 = nn.Linear(3, 4)	#nn.Linear(입력수,출력수)
    		self.fc2 = nn.Linear(4, 4)
    		self.fc3 = nn.Linear(4, 2)
    		
    	def forward(self, x) :	# 네트워크 구조 결정 및 연산 수행
    		x = F.relu(self.fc1(x))	#선형 연산 후 비선형 함수 통과
    		x = F.relu(self.fc2(x))
    		x = self.fc3(x)	# 선형 연산만 수행
    		return x
            
    #네트워크 선언
    model = Model().to(device)	#Model클래스 호출 후 디바이스 할당하고 이를 model로 선언
    	
    #Pytorch를 이용한 연산 수행
    x = torch.rand(2,3).to(device)	#(2,3)의 랜덤 데이터를 torch로 생성 후 연산장치 할당
    				# torch는 pytorch 연산 수행을 위한 데이터 자료형
    output = model(x)
    
    #결과 출력
    print(output)
    
    #Numpy 변환
    output_np = output.cpu().detach().numpy()	# 연산장치를 cpu로 변환
    						# detch : gradient 전파가 안되도록 함
                                                # numpy : tensor를 numpy로 변환
    print(output_np)
    

    numpy  변환을 위해서는 detach를 꼭 해줘야 함

    'AI' 카테고리의 다른 글

    RNN에 기반한 딥러닝 기법들의 예시  (0) 2021.06.22
    CNN에 기반한 딥러닝 기법들의 예시  (0) 2021.06.22
Designed by Tistory.