728x90
Error 내용
Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
-> data는 다음과 같이
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
image = np.array(cv2.imread('./sample_data/Black_porgy.JPG'))
image = transform(image=image)['image']
image = image.unsqueeze(0)
image = image.to(device, dtype=torch.float32)
device 즉, cuda에 올렸지만 model은 cuda에 올리지 않아서 발생한 문제
기존 코드
class TempModel(nn.Module):
def __init__(self, num_classes):
super().__init__()
self.efficientnet = timm.create_model('efficientnet_b4', pretrained = True, num_classes = num_classes, drop_rate=0.5, act_layer = nn.ReLU)
def forward(self, x):
x = self.efficientnet(x)
return x
file_path = './fish_EfficientNetB4_best_epoch23_0.7847.pth'
model = TempModel(num_classes =12)
model.load_state_dict(torch.load(file_path))
model.eval()
이걸 device에 올려주도록 하자
728x90
'Python > Application' 카테고리의 다른 글
[Python] Streamlit 활용 전 정리 (0) | 2024.06.05 |
---|---|
[Data Scraping] 구글에서 고화질 이미지 스크래핑(크롤링) - selenium (0) | 2023.01.27 |
[Scraping] Web Crowling 이라기 보단 Scraping (beautifulsoup4) (0) | 2023.01.24 |
[Scraping] HTML, XPath, Requests,정규식 찍먹해보기 (+User-agent) (1) | 2023.01.19 |