-
-
Notifications
You must be signed in to change notification settings - Fork 48.8k
Add Vision Transformer demo in computer_vision module #13351
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
5fa4b88
62af08b
b5b3192
fc2c15a
e5d52d7
fa237c3
451b7ec
1ba4fe1
8ad8cc8
fb29757
6b90016
f30d09f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
""" | ||
Vision Transformer (ViT) Module | ||
================================ | ||
|
||
Classify images using a pretrained Vision Transformer (ViT) | ||
from Hugging Face Transformers. | ||
""" | ||
|
||
from io import BytesIO | ||
from typing import Optional | ||
|
||
import requests | ||
import torch | ||
from PIL import Image, UnidentifiedImageError | ||
from transformers import ViTForImageClassification, ViTImageProcessor | ||
|
||
|
||
_VIT_CHECKPOINT = "google/vit-base-patch16-224"

# Maps checkpoint name -> (processor, model) so repeated calls reuse the
# already-loaded objects instead of re-reading the checkpoint every time.
_VIT_CACHE: dict = {}


def _load_vit(checkpoint: str = _VIT_CHECKPOINT) -> tuple:
    """Return the ``(processor, model)`` pair for *checkpoint*, loading it once."""
    if checkpoint not in _VIT_CACHE:
        _VIT_CACHE[checkpoint] = (
            ViTImageProcessor.from_pretrained(checkpoint),
            ViTForImageClassification.from_pretrained(checkpoint),
        )
    return _VIT_CACHE[checkpoint]


def classify_image(image: Image.Image) -> str:
    """
    Classify a PIL image using a pretrained ViT.

    The processor and model are loaded from the
    ``google/vit-base-patch16-224`` checkpoint on first use and reused on
    later calls.

    Args:
        image (Image.Image): Image to classify. PIL images that are not in
            RGB mode are converted to RGB before preprocessing.

    Returns:
        str: The ``model.config.id2label`` entry for the highest-scoring class.
    """
    # NOTE(review): no test file or doctest covers this function yet.
    processor, model = _load_vit()

    # Grayscale / RGBA / palette images do not have the three channels the
    # processor is given here by RGB inputs, so normalise the mode up front.
    if isinstance(image, Image.Image) and image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt")

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]
|
||
|
||
def demo(url: Optional[str] = None) -> None:
    """
    Download an image and print the label predicted by ``classify_image``.

    If the download fails or the payload cannot be decoded as an image, a
    failure message is printed and the function returns without classifying.

    Args:
        url (Optional[str]): URL of the image. If None, uses the default
            example image.
    """
    # NOTE(review): no test file or doctest covers this function yet.
    if url is None:
        # default example image
        url = "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Image.open only reads the header; force a full decode here so that
        # truncated or corrupt data is reported below instead of surfacing
        # later during classification.
        image.load()
    except (requests.RequestException, UnidentifiedImageError, OSError) as e:
        print(f"Failed to load image from {url}. Error: {e}")
        return

    label = classify_image(image)
    print(f"Predicted label: {label}")
|
||
|
||
# Script entry point: run the demo with its default image when executed directly.
if __name__ == "__main__":
    demo()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As there is no test file in this pull request nor any test function or class in the file `computer_vision/vision_transformer.py`, please provide a doctest for the function `classify_image`.