
[Data Scraping] Error : Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

by 하람 Haram 2023. 2. 10.

Error message

 

Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

 

-> The data was moved to the device as follows:

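import cv2
import numpy as np
import torch

# Use the GPU when available, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

# `transform` is defined elsewhere in the project (not shown in this post)
image = np.array(cv2.imread('./sample_data/Black_porgy.JPG'))
image = transform(image=image)['image']
image = image.unsqueeze(0)  # add a batch dimension
image = image.to(device, dtype=torch.float32)  # the input now lives on `device`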

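The data was moved to device (that is, CUDA), but the model was never moved to CUDA, so its weights stayed on the CPU. That mismatch is what caused the error.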
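One way to confirm this kind of mismatch is to compare the device of the input tensor with the device of the model's parameters. The snippet below is only a minimal sketch: `devices_match` is a hypothetical helper (not part of PyTorch), and a tiny `nn.Conv2d` stands in for the real EfficientNet so the example runs without the checkpoint.

```python
import torch
import torch.nn as nn


def devices_match(model: nn.Module, x: torch.Tensor) -> bool:
    """Return True if x is on the same device as the model's first parameter."""
    # Hypothetical helper for illustration; assumes the model has at least one parameter.
    param_device = next(model.parameters()).device
    return param_device == x.device


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in model: a freshly constructed module lives on the CPU by default.
toy_model = nn.Conv2d(3, 8, kernel_size=3)
x = torch.randn(1, 3, 32, 32).to(device, dtype=torch.float32)

print("model device:", next(toy_model.parameters()).device)
print("input device:", x.device)
print("match:", devices_match(toy_model, x))
```

On a machine with a GPU, the model reports `cpu` while the input reports a CUDA device, which is the situation the error message describes. On a CPU-only machine both report `cpu` and the error does not appear.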

 

Original code

import timm
import torch
import torch.nn as nn

class TempModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.efficientnet = timm.create_model('efficientnet_b4', pretrained=True, num_classes=num_classes, drop_rate=0.5, act_layer=nn.ReLU)

    def forward(self, x):
        x = self.efficientnet(x)
        return x

file_path = './fish_EfficientNetB4_best_epoch23_0.7847.pth'

model = TempModel(num_classes=12)
model.load_state_dict(torch.load(file_path))
# Note: the model is never moved to `device` here, so its weights stay on the CPU
model.eval()

So let's move the model to the device as well.
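A minimal sketch of that fix follows. It makes some assumptions: `ToyModel` is a hypothetical stand-in for `TempModel`, and a random tensor replaces the real image, so the example runs without timm, the checkpoint file, or the sample data. The line that matters is `model.to(device)`. The commented-out line shows one way the checkpoint could be loaded with `map_location`.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ToyModel(nn.Module):
    # Hypothetical stand-in for TempModel so the sketch runs without timm.
    def __init__(self, num_classes):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, num_classes)

    def forward(self, x):
        x = self.pool(self.conv(x)).flatten(1)
        return self.fc(x)


model = ToyModel(num_classes=12)
# With the real checkpoint, the weights could be loaded at this point, e.g.:
# model.load_state_dict(torch.load(file_path, map_location=device))
model = model.to(device)  # the step that was missing in the original code
model.eval()

# Random tensor standing in for the preprocessed image batch.
image = torch.randn(1, 3, 64, 64).to(device, dtype=torch.float32)

with torch.no_grad():
    logits = model(image)

print(logits.shape)
```

`nn.Module.to` moves the parameters in place and returns the module itself, so calling `model.to(device)` without reassigning works too. Tensors behave differently: `image.to(device)` returns a new tensor, so it must be assigned back.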

 

 
