Skip to content
Closed
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions computer_vision/vision_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Vision Transformer (ViT) Module
================================

Classify images using a pretrained Vision Transformer (ViT)
from Hugging Face Transformers.
"""

from io import BytesIO
from typing import Optional

import requests
import torch
from PIL import Image, UnidentifiedImageError
from transformers import ViTForImageClassification, ViTImageProcessor


def classify_image(image: Image.Image) -> str:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As there is no test file in this pull request nor any test function or class in the file computer_vision/vision_transformer.py, please provide doctest for the function classify_image

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As there is no test file in this pull request nor any test function or class in the file computer_vision/vision_transformer.py, please provide doctest for the function classify_image

"""Classify a PIL image using pretrained ViT."""
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()
return model.config.id2label[predicted_class_idx]


def demo(url: Optional[str] = None) -> None:

Check failure on line 33 in computer_vision/vision_transformer.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (UP045)

computer_vision/vision_transformer.py:33:15: UP045 Use `X | None` for type annotations

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As there is no test file in this pull request nor any test function or class in the file computer_vision/vision_transformer.py, please provide doctest for the function demo

"""
Run a demo using a sample image or provided URL.

Args:
url (Optional[str]): URL of the image. If None, uses default cat image.
"""
if url is None:
url = (
"https://images.unsplash.com/photo-1592194996308-7b43878e84a6"
) # default example image

try:
response = requests.get(url, timeout=10)
response.raise_for_status()
image = Image.open(BytesIO(response.content))
except (requests.RequestException, UnidentifiedImageError) as e:
print(f"Failed to load image from {url}. Error: {e}")
return

label = classify_image(image)
print(f"Predicted label: {label}")


if __name__ == "__main__":
demo()
Loading