-
Notifications
You must be signed in to change notification settings - Fork 22
Expand file tree
/
Copy pathpytorch_score.py
More file actions
61 lines (47 loc) · 1.59 KB
/
pytorch_score.py
File metadata and controls
61 lines (47 loc) · 1.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from torchvision import transforms
import json
import base64
from io import BytesIO
from PIL import Image
import os
import pickle
from azureml.core.model import Model
def preprocess_image(image_file):
    """Load an image and turn it into a normalized single-image batch tensor.

    Pipeline: resize (shorter side to 256 px), center-crop to 224x224,
    convert to a float tensor, then normalize each channel with the
    mean/std constants below.

    Args:
        image_file: A path or binary file-like object accepted by
            ``PIL.Image.open`` (``run`` passes a ``BytesIO``).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``.
    """
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        # Per-channel mean/std; presumably the statistics the model was
        # trained with (they are the usual ImageNet values) -- confirm.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    with Image.open(image_file) as image:
        # Normalize above expects exactly 3 channels; grayscale ("L"),
        # palette ("P") or RGBA uploads would otherwise fail or be
        # mis-normalized. convert() returns a fully loaded copy, so the
        # source image can be closed afterwards.
        rgb_image = image.convert("RGB")
    # ToTensor already produces a tensor; no extra torch.tensor() copy
    # (which only duplicated the data and raised a UserWarning).
    tensor = data_transforms(rgb_image).float()
    # Add the leading batch dimension the model expects.
    return tensor.unsqueeze(0)
def base64ToImg(base64ImgString):
    """Decode base64 text into an in-memory binary stream.

    Args:
        base64ImgString: The base64-encoded payload as a ``str``.

    Returns:
        A ``BytesIO`` holding the decoded bytes, positioned at offset 0.
    """
    raw_bytes = base64.b64decode(base64ImgString.encode('utf-8'))
    stream = BytesIO(raw_bytes)
    return stream
def init():
    """Load the registered model and class names into module globals.

    Presumably invoked once by the Azure ML scoring host before ``run``
    (confirm against the deployment config).

    Sets:
        model: object loaded from ``model.pt``, switched to eval mode.
        classes: object unpickled from ``class_names.pkl``; ``run``
            indexes it with the predicted class index.
    """
    global model, classes
    model_path = Model.get_model_path('model')
    # Map every stored tensor onto the CPU (same as map_location='cpu') so a
    # model saved on GPU still loads on a CPU-only host.
    # NOTE(review): this unpickles a whole module object. Newer PyTorch
    # releases default torch.load to weights_only=True, which rejects such
    # files; pass weights_only=False when upgrading, and only for trusted
    # artifacts.
    model = torch.load(os.path.join(model_path, 'model.pt'),
                       map_location=lambda storage, loc: storage)
    model.eval()
    # SECURITY: pickle can execute arbitrary code -- only load trusted model
    # registry artifacts. The `with` block closes the file even if
    # unpickling raises.
    with open(os.path.join(model_path, 'class_names.pkl'), 'rb') as pkl_file:
        classes = pickle.load(pkl_file)
def run(input_data):
    """Classify one base64-encoded image received as JSON.

    Args:
        input_data: JSON text whose ``"data"`` field holds the base64 image.

    Returns:
        JSON text ``{"label": ..., "probability": "..."}`` for the
        highest-scoring class. The probability is the softmax score rendered
        with ``str``.

    Raises:
        KeyError: if the payload has no ``"data"`` field.
        json.JSONDecodeError: if ``input_data`` is not valid JSON.
    """
    img = base64ToImg(json.loads(input_data)['data'])
    img = preprocess_image(img)
    # Inference only: skip autograd bookkeeping, and run the model ONCE
    # (the previous version ran a second, redundant forward pass).
    with torch.no_grad():
        output = model(img)
    pred_probs = torch.softmax(output, dim=1).numpy()[0]
    # preprocess_image yields a batch of one, so take row 0's argmax as a
    # plain int rather than indexing lists/arrays with a tensor.
    index = int(torch.argmax(output, dim=1)[0])
    result = json.dumps({"label": classes[index], "probability": str(pred_probs[index])})
    return result