65 changes: 65 additions & 0 deletions computer_vision/vision_transformer.py
@@ -0,0 +1,65 @@
"""
Vision Transformer (ViT) Module
================================

Classify images using a pretrained Vision Transformer (ViT)
from Hugging Face Transformers.

Can be used as a demo or imported in other scripts.

Source:
https://huggingface.co/docs/transformers/model_doc/vit
"""

try:
    import requests
    import torch
    from io import BytesIO
    from PIL import Image
    from transformers import ViTForImageClassification, ViTImageProcessor

Check failure on line 19 in computer_vision/vision_transformer.py (GitHub Actions / ruff): Ruff (I001) computer_vision/vision_transformer.py:15:5: Import block is un-sorted or un-formatted
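A minimal sketch of an import order that should satisfy I001, with the standard-library group first and the third-party packages second (the exact grouping depends on the repository's isort/Ruff configuration):

try:
    from io import BytesIO

    import requests
    import torch
    from PIL import Image
    from transformers import ViTForImageClassification, ViTImageProcessor
except ImportError as e:
    raise ImportError(
        "This module requires 'torch', 'transformers', 'PIL', and 'requests'. "
        "Install them with: pip install torch transformers pillow requests"
    ) from e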
except ImportError as e:
    raise ImportError(
        "This module requires 'torch', 'transformers', 'PIL', and 'requests'. "
        "Install them with: pip install torch transformers pillow requests"
    ) from e


def classify_image(image: Image.Image) -> str:

Review comment: As there is no test file in this pull request nor any test function or class in the file computer_vision/vision_transformer.py, please provide doctest for the function classify_image
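A doctest for classify_image is awkward to pin down because the predicted label depends on the downloaded model weights; a minimal sketch, assuming the test only asserts the return type on a small solid-color image:

def classify_image(image: Image.Image) -> str:
    """Classify a PIL image using pretrained ViT.

    >>> img = Image.new("RGB", (224, 224), color="white")
    >>> isinstance(classify_image(img), str)  # downloads weights on first run
    True
    """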

"""Classify a PIL image using pretrained ViT."""
processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits

predicted_class_idx = logits.argmax(-1).item()
return model.config.id2label[predicted_class_idx]


def demo(url: str = None) -> None:

Check failure on line 42 in computer_vision/vision_transformer.py (GitHub Actions / ruff): Ruff (RUF013) computer_vision/vision_transformer.py:42:15: PEP 484 prohibits implicit `Optional`
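A fix for RUF013 is to make the Optional explicit in the signature; a sketch assuming the repository targets Python 3.10+ (otherwise `Optional[str]` from `typing` would be used):

def demo(url: str | None = None) -> None:
    ...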

Review comment: As there is no test file in this pull request nor any test function or class in the file computer_vision/vision_transformer.py, please provide doctest for the function demo
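Since demo downloads an image over the network and prints rather than returns, a doctest for it would likely need to be skipped in CI; a minimal sketch using the standard `# doctest: +SKIP` directive:

def demo(url: str | None = None) -> None:
    """
    Run a demo using a sample image or provided URL.

    >>> demo()  # doctest: +SKIP
    """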

"""
Run a demo using a sample image or provided URL.

Check failure on line 45 in computer_vision/vision_transformer.py (GitHub Actions / ruff): Ruff (W293) computer_vision/vision_transformer.py:45:1: Blank line contains whitespace
    Args:
        url (str): URL of the image. If None, uses a default cat image.
    """
    if url is None:
        url = "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"  # default example image

Check failure on line 50 in computer_vision/vision_transformer.py (GitHub Actions / ruff): Ruff (E501) computer_vision/vision_transformer.py:50:89: Line too long (101 > 88)
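One way to satisfy E501 here is to move the inline comment onto its own line, which brings the assignment under the 88-character limit; a sketch:

    if url is None:
        # Default example image
        url = "https://images.unsplash.com/photo-1592194996308-7b43878e84a6"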

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    except Exception as e:

Check failure on line 56 in computer_vision/vision_transformer.py (GitHub Actions / ruff): Ruff (BLE001) computer_vision/vision_transformer.py:56:12: Do not catch blind exception: `Exception`
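A fix for BLE001 is to catch the specific exceptions this block can actually raise rather than a blind `Exception`; a sketch, assuming `requests.RequestException` covers the download errors and `OSError` covers image decoding failures (Pillow's `UnidentifiedImageError` is an `OSError` subclass):

    try:
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    except (requests.RequestException, OSError) as e:
        print(f"Failed to load image from {url}. Error: {e}")
        return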
print(f"Failed to load image from {url}. Error: {e}")
return

label = classify_image(image)
print(f"Predicted label: {label}")


if __name__ == "__main__":
    demo()