-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathinference.py
More file actions
111 lines (88 loc) · 3.93 KB
/
inference.py
File metadata and controls
111 lines (88 loc) · 3.93 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import json
import pandas as pd
import torch
import argparse
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from tqdm import tqdm
def load_data(jsonl_file):
    """Load papers from a JSONL file, keeping only usable records.

    Each line is parsed as JSON; lines that fail to parse are skipped with a
    warning that names the offending line. A record is kept when it has a
    'title' and a usable abstract: 'abstract' itself when non-null, otherwise
    'abstract_extra' (copied into 'abstract' so downstream code reads a single
    field).

    Args:
        jsonl_file: Path to the input JSONL file.

    Returns:
        list[dict]: Filtered records, each guaranteed to carry a non-null
        value under the 'abstract' key.
    """
    data = []
    with open(jsonl_file, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError:
                # Include the line number so bad input can be located.
                print(f"Warning: Skipping invalid JSON line {line_no}")

    filtered_data = []
    for item in data:
        if 'title' not in item:
            continue
        # Fall back to abstract_extra whenever abstract is null OR missing.
        # (The original required the 'abstract' key to exist even when
        # falling back, which silently dropped key-absent records.)
        if item.get('abstract') is not None:
            filtered_data.append(item)
        elif item.get('abstract_extra') is not None:
            item['abstract'] = item['abstract_extra']
            filtered_data.append(item)
    return filtered_data
def predict(model_path, data):
    """Classify each paper with a fine-tuned sequence-classification model.

    Loads the model and tokenizer from ``model_path``, runs one forward pass
    per paper (on GPU when available), and returns one result dict per input
    item with the predicted label name and its softmax confidence.

    Args:
        model_path: Directory of the trained HF model/tokenizer checkpoint.
        data: Iterable of paper dicts with 'title'/'abstract' (and optionally
            'paperId') keys.

    Returns:
        list[dict]: One entry per paper with paperId, title, abstract,
        predicted_label, and confidence.
    """
    # Model and tokenizer come from the same checkpoint directory.
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Prefer CUDA when present; inference only, so switch to eval mode.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    model.to(device)
    model.eval()

    # Human-readable label names are stored in the checkpoint config.
    id2label = model.config.id2label

    results = []
    for item in tqdm(data, desc="Processing"):
        # Title + abstract form the classification input text.
        text = f"Title: {item.get('title', '')}\nAbstract: {item.get('abstract', '')}"
        encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # No gradients needed for a pure forward pass.
        with torch.no_grad():
            logits = model(**encoded).logits

        probs = torch.nn.functional.softmax(logits, dim=1)
        label_id = torch.argmax(logits, dim=1).item()

        results.append({
            'paperId': item.get('paperId', ''),
            'title': item.get('title', ''),
            'abstract': item.get('abstract', ''),
            'predicted_label': id2label[label_id],
            'confidence': probs[0][label_id].item(),
        })
    return results
def main(model_path, input_jsonl, output_jsonl):
    """Run end-to-end inference: load papers, predict, and write JSONL output.

    Args:
        model_path: Directory containing the trained model/tokenizer.
        input_jsonl: JSONL file of papers to classify.
        output_jsonl: Destination JSONL file for one prediction per line.
    """
    # Load data
    data = load_data(input_jsonl)
    print(f"Loaded {len(data)} papers for inference")
    # Make predictions
    results = predict(model_path, data)
    # Save results
    with open(output_jsonl, 'w', encoding='utf-8') as f:
        for result in results:
            f.write(json.dumps(result) + '\n')
    print(f"Predictions saved to {output_jsonl}")
    # Print summary; guard against an empty run, which previously raised
    # ZeroDivisionError in the percentage computation below.
    predictions = [r['predicted_label'] for r in results]
    if not predictions:
        print("Summary: no papers were classified.")
        return
    bionlp_count = predictions.count('Y')
    non_bionlp_count = predictions.count('N')
    print(f"Summary:")
    print(f" BioNLP papers: {bionlp_count} ({bionlp_count/len(predictions)*100:.1f}%)")
    print(f" Non-BioNLP papers: {non_bionlp_count} ({non_bionlp_count/len(predictions)*100:.1f}%)")
if __name__ == "__main__":
    # CLI entry point: all three paths are required string arguments.
    cli = argparse.ArgumentParser(
        description="Run inference using a trained BioNLP classification model.")
    for flag, text in (
        ("--model_path", "Path to the trained model directory"),
        ("--input_jsonl", "The JSONL file containing papers to classify"),
        ("--output_jsonl", "The output JSONL file where predictions will be written"),
    ):
        cli.add_argument(flag, type=str, required=True, help=text)
    opts = cli.parse_args()
    main(opts.model_path, opts.input_jsonl, opts.output_jsonl)