-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference.py
More file actions
49 lines (41 loc) · 1.51 KB
/
inference.py
File metadata and controls
49 lines (41 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import lightning as L
from cyclopts import App
from lightning.pytorch import seed_everything
from src.configs import ModelPredConfig
from src.data_module import AmongUsDatamodule
from src.models.fcos_pretrained import ModelFcosPretrained
from src.utils import create_output
# Cyclopts application object: commands are registered on it with
# ``@app.command`` and it is invoked from the ``__main__`` guard.
app = App(name="Define Config for inferencing:")
@app.command
def run_inference(cfg: ModelPredConfig = ModelPredConfig()):
    """
    Run prediction with a checkpointed model and hand the results to create_output.

    Calls ``seed_everything``, builds the datamodule, loads the model from the
    configured checkpoint, runs ``Trainer.predict`` and passes the collected
    image paths and predictions to ``create_output``.

    Args:
        cfg: prediction configuration. The fields read here are
            ``inference_cfg`` (``seed`` and ``checkpoint``),
            ``datamodule_cfg`` (passed to the datamodule; ``pred_output`` is
            forwarded to ``create_output``), ``creation_cfg`` and
            ``transform_cfg`` (passed to the datamodule; ``width`` and
            ``height`` are forwarded to ``create_output`` as a tuple).
    """
    inference_cfg = cfg.inference_cfg
    seed_everything(inference_cfg.seed)

    # initialize Datamodule
    data_module = AmongUsDatamodule(
        cfg.datamodule_cfg, cfg.creation_cfg, cfg.transform_cfg
    )

    # NOTE(review): the accelerator is hard-coded to "gpu"; presumably this
    # fails on hosts without a GPU -- confirm whether "auto" is acceptable.
    trainer = L.Trainer(
        accelerator="gpu",
        enable_progress_bar=True,
    )

    # NOTE(review): weights_only=False presumably allows full unpickling of
    # the checkpoint -- only load checkpoints from a trusted source.
    model = ModelFcosPretrained.load_from_checkpoint(
        inference_cfg.checkpoint, weights_only=False
    )

    # Trainer.predict is typed as Optional; treat a None result as "no
    # predictions" instead of failing with a TypeError while iterating.
    output = trainer.predict(model=model, datamodule=data_module) or []

    # Each entry is indexed positionally: [0] -> image path(s), [1] -> preds.
    images_paths = [item[0] for item in output]
    preds = [item[1] for item in output]

    create_output(
        images_paths,
        preds,
        (cfg.transform_cfg.width, cfg.transform_cfg.height),
        cfg.datamodule_cfg.pred_output,
    )
# Script entry point: run the cyclopts app only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    app()