๋ก๋งจ์ค ์ค์บ ์ฌ์ง๊ณผ, ๊ทธ๋ฅ ์ฌ์ง์ ๊ตฌ๋ณํ ์ ์๋ ViT ๋ชจ๋ธ ์
๋๋ค.
๊ธฐ์กด์ CNN ๋ชจ๋ธ์ ๋นํด ํจ์ ์ฑ๋ฅ์ด ์ข์ต๋๋ค.
์ถํ ๋ฐ์ดํฐ๋ฅผ ์ถ๊ฐํด ์ฑ๋ฅ์ ๋์ฑ ๋๋ฆด๊ฒ ์
๋๋ค.
์ฌ์ฉ ์ฝ๋๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค.
import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
# Hugging Face์์ ๋ชจ๋ธ ๋ฐ ํน์ง ์ถ์ถ๊ธฐ ๋ถ๋ฌ์ค๊ธฐ
model = ViTForImageClassification.from_pretrained("gihakkk/vit_modle")
feature_extractor = ViTFeatureExtractor.from_pretrained("gihakkk/vit_modle")
# ์๋ก์ด ์ด๋ฏธ์ง ์์ธก ํจ์ ์ ์
def predict_image(image_path):
# ์ด๋ฏธ์ง๋ฅผ ๋ก๋ํ๊ณ RGB๋ก ๋ณํ
image = Image.open(image_path).convert("RGB")
# ์ด๋ฏธ์ง๋ฅผ ํน์ง ์ถ์ถ๊ธฐ๋ก ์ ์ฒ๋ฆฌํ์ฌ ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํ
inputs = feature_extractor(images=image, return_tensors="pt")
# ์์ธก ์ํ
with torch.no_grad():
outputs = model(**inputs).logits
predicted_class = torch.argmax(outputs, dim=-1).item()
return "๊ทธ๋ฅ ์ฌ์ง" if predicted_class == 1 else "๋ก๋งจ์ค ์ค์บ ์ฌ์ง"
# ์์ธก ์์
image_path = r'path\to\your\img.jpg'
result = predict_image(image_path)
print(result)
- Downloads last month
- 0