-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
35 lines (25 loc) · 909 Bytes
/
utils.py
File metadata and controls
35 lines (25 loc) · 909 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import contextlib
import os
import zipfile

import tensorflow as tf

from config import *
def safe_remove(file_name):
    """Delete *file_name* if it exists; do nothing if it is already absent.

    Only a missing file is ignored. Any other ``OSError`` (for example
    permission denied, or *file_name* naming a directory) propagates,
    because silently swallowing it would hide a delete that did not happen.

    Args:
        file_name: Path of the file to delete.
    """
    with contextlib.suppress(FileNotFoundError):
        os.remove(file_name)
def get_extension(local_file: str) -> str:
    """Return the lower-cased extension of *local_file*, without the dot.

    Only the final path component is inspected, so dots in directory names
    are ignored. A file name containing no dot yields ``""`` rather than
    raising ``IndexError``.

    Examples:
        ``"model.ZIP"`` -> ``"zip"``; ``"archive.tar.gz"`` -> ``"gz"``;
        ``"models.v2/model"`` -> ``""``.

    Args:
        local_file: File name or path to inspect.

    Returns:
        The text after the last dot of the base name, lower-cased, or ``""``
        when the base name has no dot.
    """
    file_name = os.path.basename(local_file)
    _stem, dot, extension = file_name.rpartition(".")
    # rpartition yields ("", "", file_name) when there is no dot at all.
    return extension.lower() if dot else ""
def load_model(model_path=None, signature_key="serving_default"):
    """Load a TensorFlow SavedModel into memory for inference.

    Args:
        model_path: Location of the SavedModel to load. When ``None`` (the
            default) the ``TF_MODEL_PATH`` setting from ``config`` is used,
            matching the original zero-argument behaviour.
        signature_key: Name of the signature returned as the model.
            Defaults to ``"serving_default"``.

    Returns:
        A ``(model, model_type)`` tuple: ``model`` is the callable stored
        under *signature_key* in the loaded object's ``signatures`` mapping,
        and ``model_type`` is the string ``"tensorflow"``.

    Raises:
        KeyError: If the loaded model has no signature named
            *signature_key*. Without this check ``None`` would be returned
            and the error would only surface when the model is called.
    """
    if model_path is None:
        model_path = TF_MODEL_PATH
    tf_trackable = tf.saved_model.load(model_path)
    signatures = tf_trackable.signatures
    model = signatures.get(signature_key)
    if model is None:
        raise KeyError(
            f"SavedModel at {model_path!r} has no {signature_key!r} signature; "
            f"available signatures: {sorted(signatures)}"
        )
    model_type = "tensorflow"
    return model, model_type
def unzip(local_file):
    """Extract *local_file* into ``MODEL_FOLDER`` when it is a ``.zip`` archive.

    Any file whose extension is not ``zip`` is left exactly where it is and
    nothing is printed.

    Args:
        local_file: Path of the file that may be a zip archive.
    """
    # Guard clause: only zip archives are extracted.
    if get_extension(local_file) != "zip":
        return
    with zipfile.ZipFile(local_file, "r") as archive:
        archive.extractall(MODEL_FOLDER)
    print(">>> Successfully unzipped model.")