Merged
8 changes: 2 additions & 6 deletions .github/workflows/main.yaml
Original file line number Diff line number Diff line change
@@ -6,8 +6,8 @@ jobs:
main:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.11.6
cache: "pip"
@@ -18,7 +18,3 @@ jobs:
run: black . --check --diff --color
- name: "isort"
run: isort . --check --diff
- name: "mypy"
run: mypy
- name: "pytests"
run: pytest
3 changes: 2 additions & 1 deletion .gitignore
@@ -2,4 +2,5 @@
.mypy_cache/
.pytest_cache/
__pycache__/
.ipynb_checkpoints/
.ipynb_checkpoints/
*/.ipynb_checkpoints/
88 changes: 69 additions & 19 deletions README.md
@@ -1,27 +1,77 @@
# py_template
# GeRaCl: General Rapid text Classifier

Template repository for Python projects.
Use it to create a new repo, but feel free to adapt it to your use cases.
**GeRaCl** is an open‑source **framework** for building, training, and evaluating efficient zero‑shot text classifiers on top of any BERT‑like sentence-encoder. It is inspired by the [GLiNER](https://github.com/urchade/GLiNER/tree/main) framework.

## Structure
### ✨ Why GeRaCl?

There are several directories to organize your code:
- `src`: Main directory for your modules, e.g., models or dataset implementations, train loops, metrics.
- `scripts`: Directory to define scripts to interact with modules, e.g., run training or evaluation, run data preprocessing, collect statistic.
- `tests`: Directory for tests, this may include multiple unit tests for different parts of logic.
| Feature | What it means for you |
| ------------------------------ | ------------------------------------------------------------------------------------------------- |
| **Zero‑shot by design** | Classify with **arbitrary** label sets that you decide at run‑time — just pass a list of strings. |
| **One forward pass** | As fast as ordinary text classification; no pairwise loops like in NLI‑based approaches. |
| **Model‑agnostic** | Works with any Hugging Face sentence-encoder. |
| **155M reference checkpoint** | A lean [baseline](https://huggingface.co/deepvk/GeRaCl-USER2-base) (155M parameters) that beats much larger sentence‑encoders (300-500M parameters). |
| **All‑in‑one toolkit** | Training/eval scripts, HF Hub and WandB integration. |
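
The "one forward pass" row can be made concrete with a toy cost model (an illustration only; the function names are ours, not GeRaCl API): an NLI-based zero-shot classifier must run one entailment pass per (text, label) pair, while a GeRaCl-style model packs all candidate labels into the prompt and runs one pass per text.

```python
# Toy cost model for the "one forward pass" feature (illustration only).

def nli_style_passes(n_texts: int, n_labels: int) -> int:
    # NLI-based zero-shot: one entailment forward pass per (text, label) pair
    return n_texts * n_labels

def geracl_style_passes(n_texts: int, n_labels: int) -> int:
    # GeRaCl-style: labels are embedded in the prompt, so one pass per text
    return n_texts

print(nli_style_passes(100, 6))     # 600
print(geracl_style_passes(100, 6))  # 100
```

With 6 candidate labels, the NLI approach does 6x the encoder work; the gap grows linearly with the size of the label set.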

You can create new directories for your need.
For example, you can create a `Notebooks` folder for Jupyter notebooks, such as `EDA.ipynb`.

## Usage
### 🚀 Quick Start

First of all,
navigate to [`pyproject.toml`](./pyproject.toml) and set up `name` and `url` properties according to your project.
Clone and install directly from GitHub:

For correct work of the import system:
1. Use absolute import statements starting from `src`. For example, `from src.model import MySuperModel`
2. Execute scripts as modules, i.e. use `python -m scripts.<module_name>`. See details about `-m` flag [here](https://docs.python.org/3/using/cmdline.html#cmdoption-m).
```bash
git clone https://github.com/deepvk/zero-shot-classification
cd zero-shot-classification

To keep your code clean, use `black`, `isort`, and `mypy`
(install everything from [`requirements.dev.txt`](./requirements.dev.txt)).
[`pyproject.toml`](./pyproject.toml) already defines their parameters, but you can change them if you want.
pip install -r requirements.txt
```

Verify your installation:

```python
from geracl import GeraclHF, ZeroShotClassificationPipeline  # succeeds if the package is installed
```

### 🧑‍💻 Usage Examples

#### Single classification scenario

```python
from transformers import AutoTokenizer
from geracl import GeraclHF, ZeroShotClassificationPipeline

model = GeraclHF.from_pretrained('deepvk/GeRaCl-USER2-base').to('cuda').eval()
tokenizer = AutoTokenizer.from_pretrained('deepvk/GeRaCl-USER2-base')

pipe = ZeroShotClassificationPipeline(model, tokenizer, device="cuda")

text = "Утилизация катализаторов: как неплохо заработать"
labels = ["экономика", "происшествия", "политика", "культура", "наука", "спорт"]
result = pipe(text, labels, batch_size=1)[0]

print(labels[result])
```

#### Multiple classification scenarios

```python
from transformers import AutoTokenizer
from geracl import GeraclHF, ZeroShotClassificationPipeline

model = GeraclHF.from_pretrained('deepvk/GeRaCl-USER2-base').to('cuda').eval()
tokenizer = AutoTokenizer.from_pretrained('deepvk/GeRaCl-USER2-base')

pipe = ZeroShotClassificationPipeline(model, tokenizer, device="cuda")

texts = [
"Утилизация катализаторов: как неплохо заработать",
"Мне не понравился этот фильм."
]
labels = [
["экономика", "происшествия", "политика", "культура", "наука", "спорт"],
["нейтральный", "позитивный", "негативный"]
]
results = pipe(texts, labels, batch_size=2)

for sample_labels, result in zip(labels, results):
    print(sample_labels[result])
```
4 changes: 4 additions & 0 deletions geracl/__init__.py
@@ -0,0 +1,4 @@
from .model.config import GeraclConfigHF
from .model.geracl import Geracl
from .model.hf_wrapper import GeraclHF
from .pipeline import ZeroShotClassificationPipeline
43 changes: 43 additions & 0 deletions geracl/configs/custom.yaml
@@ -0,0 +1,43 @@
model:
embedder_name: "deepvk/USER2-base"
ffn_dim: 2048
ffn_classes_dropout: 0.4
ffn_text_dropout: 0.4
device: "cuda"
unfreeze_embedder: True
loss_args:
loss_type: "bce"
# init_params:
# alpha:
# gamma:
optimizer_args:
class_path: torch.optim.AdamW
init_params:
lr: 0.000005
weight_decay: 0.1
scheduler_args:
scheduler: "linear"
total_steps: 35310
warmup_steps: 1000

data_module:
batch_size: 32
val_batch_size: 32
tokenizer_name: "deepvk/USER2-base"
config: "real_world_extended_expanded"
model_max_length: 2000
num_workers: 5
include_scenarios: False
input_prompt: "classification: "

trainer:
accelerator: "gpu"
val_check_interval: 7062
max_epochs: 5
log_every_n_steps: 100
# gradient_clip_val: 2.0
#accumulate_grad_batches: 2
# overfit_batches: 50

other:
checkpoints_dir: "/data/checkpoints/release_user2_base_training"
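
The `optimizer_args.class_path` entry above suggests the optimizer class is resolved dynamically from its dotted path. A minimal sketch of how such resolution could work (the helper name is our assumption, not GeRaCl's actual loader):

```python
# Hypothetical sketch: resolve a dotted class path like "torch.optim.AdamW"
# into the class object, then instantiate it with init_params from the YAML.
import importlib

def resolve_class(class_path: str):
    """Split 'pkg.module.ClassName' into module + attribute and import it."""
    module_name, _, class_name = class_path.rpartition(".")
    return getattr(importlib.import_module(module_name), class_name)

# Works for any importable class; demonstrated here with a stdlib one:
cls = resolve_class("collections.OrderedDict")
print(cls.__name__)  # OrderedDict
```

In the config above, the resolved class would then be called as `cls(model.parameters(), **init_params)` or similar.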
28 changes: 28 additions & 0 deletions geracl/configs/default.yaml
@@ -0,0 +1,28 @@
model:
embedder_name: "deepvk/USER2-base"
ffn_dim: 2048
device: "cuda"
unfreeze_embedder: False
pooling_type: "mean"
loss_args:
loss_type: "bce"

data_module:
batch_size: 16
val_batch_size: 16
num_workers: 10
tokenizer_name: "deepvk/USER2-base"
config: "synthetic_positives_multiclass"
include_scenarios: False
input_prompt: "classification: "

trainer:
# max_steps: 200
accelerator: "gpu"
val_check_interval: 1000
gradient_clip_val: 0.0
log_every_n_steps: 100

other:
wandb_project: "universal_classifier"
checkpoints_dir: "/data/checkpoints"
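
The `pooling_type: "mean"` option above refers to mean pooling of token embeddings, weighted by the attention mask so padding positions are ignored. A minimal NumPy sketch of the idea (an illustration, not the exact GeRaCl implementation):

```python
# Mean pooling over token embeddings, masking out padding positions.
import numpy as np

def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # token_embeddings: (seq_len, hidden); attention_mask: (seq_len,) of 0/1
    mask = attention_mask[:, None].astype(float)      # (seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=0)    # sum of real-token vectors
    count = max(mask.sum(), 1e-9)                     # number of real tokens
    return summed / count

emb = np.array([[1.0, 2.0], [3.0, 4.0], [0.0, 0.0]])  # last row is padding
mask = np.array([1, 1, 0])
print(mean_pool(emb, mask))  # [2. 3.]
```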
2 changes: 1 addition & 1 deletion src/configs/sweeps.yaml → geracl/configs/sweeps.yaml
@@ -19,7 +19,7 @@ data_module:
trainer:
# max_steps: 200
accelerator: "gpu"
val_check_interval: 5810
val_check_interval: 50
gradient_clip_val: 0.0
log_every_n_steps: 50

File renamed without changes.
25 changes: 17 additions & 8 deletions src/data/data_utils.py → geracl/data/batch_creation.py
@@ -7,7 +7,9 @@ def make_classifier_prompt(
input_seq: ndarray,
special_token_ids: dict[str, int],
classes_list: list[ndarray],
positive_labels: list[list[int]] = None,
scenario: ndarray = np.array([], dtype=int),
starting_prompt: ndarray = np.array([], dtype=int),
positive_labels: list[int] | None = None,
) -> tuple[ndarray, ndarray]:
if positive_labels:
label_mask = [-2] * len(classes_list)
@@ -26,10 +28,11 @@
for i, (class_name, mask) in enumerate(zip(classes_list, label_mask))
]
)

result_prompt = np.concatenate(
[
[special_token_ids["bos_token"]],
starting_prompt,
scenario,
result_prompt,
[special_token_ids["sep_token"]],
input_seq,
@@ -40,9 +43,11 @@
extended_label_mask = np.concatenate(
[
np.array([-4]),
np.full(len(starting_prompt), -5, dtype=int),
np.full(len(scenario), -5, dtype=int),
extended_label_mask,
np.array([-4]),
np.full(len(input_seq), -3),
np.full(len(input_seq), -3, dtype=int),
np.array([-4]),
]
)
@@ -61,7 +66,6 @@ def prepare_batch(
max_len = max(len(res_prompt) for res_prompt in result_prompts)
if model_max_length is not None:
max_len = min(max_len, model_max_length)

input_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long)
attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long)
classes_mask = torch.full((batch_size, max_len), -4, dtype=torch.long)
@@ -85,9 +89,11 @@ def prepare_inference_batch(input_texts, classes, tokenizer):
return input_ids, attention_mask, classes_mask


def prepare_inference_batch(input_texts, classes, tokenizer):
def prepare_inference_batch(input_texts, classes, tokenizer, input_prompt=None):
tokenized_texts = tokenizer(input_texts, add_special_tokens=False).input_ids
tokenized_classes = [tokenizer(sample_classes, add_special_tokens=False).input_ids for sample_classes in classes]
if input_prompt:
tokenized_prompt = tokenizer(input_prompt, add_special_tokens=False).input_ids

result_prompts = []
label_masks = []
@@ -98,9 +104,12 @@ def prepare_inference_batch(input_texts, classes, tokenizer):
"sep_token": tokenizer.sep_token_id,
"eos_token": tokenizer.eos_token_id,
}

for tokenized_text, tokenized_sample_classes in tokenized_texts:
result_prompt, label_mask = make_classifier_prompt(tokenized_text, special_token_ids, tokenized_sample_classes)
if input_prompt is None:
tokenized_prompt = np.array([], dtype=int)
for tokenized_text, tokenized_sample_classes in zip(tokenized_texts, tokenized_classes):
result_prompt, label_mask = make_classifier_prompt(
tokenized_text, special_token_ids, tokenized_sample_classes, starting_prompt=tokenized_prompt
)

result_prompts.append(result_prompt)
label_masks.append(label_mask)
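
The prompt built by `make_classifier_prompt` interleaves special tokens, an optional starting prompt, the candidate classes, and the input text, with a parallel mask marking which positions belong to which region. A simplified pure-Python sketch of that layout (the mask codes here are illustrative and not necessarily GeRaCl's exact values):

```python
# Simplified sketch of the classifier-prompt layout and its parallel mask.
BOS, SEP, EOS = 0, 1, 2  # stand-in special token ids

def build_prompt(input_ids, class_ids, prompt_ids=()):
    tokens = [BOS, *prompt_ids]
    mask = [-4] + [-5] * len(prompt_ids)           # -4: special, -5: prompt
    for i, cls in enumerate(class_ids):
        tokens.extend(cls)
        mask.extend([i] * len(cls))                # class positions -> class index
    tokens += [SEP, *input_ids, EOS]
    mask += [-4] + [-3] * len(input_ids) + [-4]    # -3: input text
    return tokens, mask

tokens, mask = build_prompt(input_ids=[10, 11], class_ids=[[20], [21, 22]])
print(tokens)  # [0, 20, 21, 22, 1, 10, 11, 2]
print(mask)    # [-4, 0, 1, 1, -4, -3, -3, -4]
```

The mask lets the model pool one embedding per class span and one for the text span, which is what enables classifying against all labels in a single forward pass.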