VisTabNet is a Vision Transformer-based tabular data classifier that leverages the strengths of transformer architectures for tabular classification tasks.
It is a proof of concept accompanying the publication *VisTabNet: Adapting Vision Transformers for Tabular Data*.
- Vision Transformer architecture adapted for tabular data
- Simple and intuitive API similar to scikit-learn
- GPU acceleration support
- Automatic handling of numerical features
- Built-in evaluation metrics
- Compatible with pandas DataFrames and numpy arrays
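Since both pandas DataFrames and NumPy arrays are accepted, data prepared for any scikit-learn-style workflow can be passed along directly. A minimal sketch (the feature names are hypothetical, and only pandas/NumPy are used here):

```python
import numpy as np
import pandas as pd

# A small tabular dataset as a DataFrame (hypothetical columns).
df = pd.DataFrame({
    "age": [25, 32, 47, 51],
    "income": [40_000, 52_000, 61_000, 58_000],
})

# Either form can be handed to the classifier; .to_numpy() yields
# the equivalent float array explicitly if you prefer arrays.
X = df.to_numpy(dtype=np.float32)
print(X.shape)  # (4, 2)
```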
You can install VisTabNet using pip:

```shell
pip install vistabnet
```

Here's a simple example to get you started:

```python
from vistabnet import VisTabNetClassifier
import numpy as np
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

# Prepare your data
X_train, y_train, X_test, y_test = ...  # Load your data here
# Note: y should be label encoded, not one-hot encoded

# Initialize the model
model = VisTabNetClassifier(
    input_features=X_train.shape[1],
    classes=len(np.unique(y_train)),
    device="cuda",  # Use "cpu" if no GPU is available
)

# Train the model
model.fit(
    X_train,
    y_train,
    eval_X=X_test,
    eval_y=y_test,
)

# Make predictions
y_pred = model.predict(X_test)

# Evaluate the model
accuracy = balanced_accuracy_score(y_test, y_pred)
print(f"Balanced accuracy: {accuracy}")
```

You can customize the VisTabNet model by adjusting various parameters:
```python
model = VisTabNetClassifier(
    input_features=X_train.shape[1],
    classes=len(np.unique(y_train)),
    hidden_dim=256,
    num_layers=6,
    num_heads=8,
    device="cuda",
)
```

Requirements:

- Python ≥ 3.9
- PyTorch ≥ 2.0
- torchvision ≥ 0.15.0
- tqdm ≥ 4.65.0
- focal-loss-torch ≥ 0.1.2
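As the quick-start example notes, `y` must be label encoded (integers `0..n_classes-1`), not one-hot encoded. A minimal sketch using scikit-learn's `LabelEncoder` (scikit-learn is already used by the examples above; the string labels here are illustrative):

```python
import numpy as np
from sklearn.preprocessing import LabelEncoder

# String class labels as they might appear in a raw dataset.
y_raw = np.array(["cat", "dog", "cat", "bird"])

# Map labels to consecutive integers; classes are sorted
# alphabetically, so bird=0, cat=1, dog=2.
encoder = LabelEncoder()
y = encoder.fit_transform(y_raw)
print(y)  # [1 2 1 0]
```

`encoder.inverse_transform` can later map predicted integers back to the original labels.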
Contributions are welcome! Please feel free to submit a Pull Request.
This project is licensed under the MIT License - see the LICENSE file for details.
If you use VisTabNet in your research, please cite:
```bibtex
@misc{wydmański2024vistabnetadaptingvisiontransformers,
      title={VisTabNet: Adapting Vision Transformers for Tabular Data},
      author={Witold Wydmański and Ulvi Movsum-zada and Jacek Tabor and Marek Śmieja},
      year={2024},
      eprint={2501.00057},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2501.00057},
}
```

For questions and support, please open an issue in the GitHub repository.