-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
134 lines (108 loc) · 3.92 KB
/
main.py
File metadata and controls
134 lines (108 loc) · 3.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import os
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# --- Dataset locations -------------------------------------------------------
# Images are read from <base_dir>/train and <base_dir>/test; flow_from_directory
# treats each sub-directory of those folders as one class.
base_dir = 'images_processed'
train_dir, test_dir = (
    os.path.join(base_dir, split) for split in ('train', 'test')
)

# --- Generators --------------------------------------------------------------
# Random augmentation is applied to the training images only; the test images
# are just rescaled from [0, 255] to [0, 1].
_augmentation = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
train_datagen = ImageDataGenerator(rescale=1./255, **_augmentation)
test_datagen = ImageDataGenerator(rescale=1./255)

# Both splits are resized to 150x150, batched by 20, and labelled with a
# single 0/1 value per image (class_mode='binary').
_flow_options = dict(
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)
train_generator = train_datagen.flow_from_directory(train_dir, **_flow_options)
test_generator = test_datagen.flow_from_directory(test_dir, **_flow_options)

# Number of batches each generator yields per pass; used below as
# steps_per_epoch (training) and validation_steps (test set).
steps_per_epoch_train = len(train_generator)
steps_per_epoch_test = len(test_generator)
# Small convolutional binary classifier: four Conv2D/MaxPooling2D stages with
# 32, 64, 128 and 128 filters, followed by a 512-unit dense layer and a single
# sigmoid output unit. Input is a 150x150 image with 3 channels.
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)),
    MaxPooling2D(2, 2),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(2, 2),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1, activation='sigmoid'),
])

# Binary cross-entropy pairs with the single sigmoid output; accuracy is
# tracked as the reported metric.
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
# Fit for 20 epochs, running one full pass over each generator per epoch.
# NOTE(review): the test generator doubles as the validation set here, so the
# accuracy printed below is measured on data that was also monitored during
# training.
_fit_options = dict(
    steps_per_epoch=steps_per_epoch_train,
    epochs=20,
    validation_data=test_generator,
    validation_steps=steps_per_epoch_test,
)
history = model.fit(train_generator, **_fit_options)

# Persist the trained network to the current working directory.
model.save("cataract_detection_model.h5")

# Report loss/accuracy on the test generator.
test_loss, test_acc = model.evaluate(test_generator)
print(f'Test accuracy: {test_acc}')
from flask import Flask, request, jsonify
from werkzeug.utils import secure_filename
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.preprocessing import image
import numpy as np
import os
import json
import re
# Flask application object; request handlers below register on it via
# @app.route.
app = Flask(__name__)
# Absolute path of the directory containing this file. Not referenced by any
# other code visible in this file.
current_dir = os.path.dirname(os.path.abspath(__file__))
# Relative path, so it is resolved against the current working directory; this
# is the same file name the training code passes to model.save().
model_path = 'cataract_detection_model.h5'
loaded_model = load_model(model_path)
# Label lookup for the rounded model output: index 0 -> 'Normal',
# index 1 -> 'Cataract'.
# NOTE(review): flow_from_directory numbers classes by sub-directory name in
# alphanumeric order. Confirm that train_generator.class_indices really maps
# the normal class to 0 and the cataract class to 1; if the folders are named
# e.g. "cataract" and "normal", these two labels are swapped.
classes = ['Normal', 'Cataract']
def preprocess_image(file_path):
    """Load an image file and turn it into a single-image model input batch.

    The image at ``file_path`` is resized to 150x150, converted to an array,
    given a leading batch axis of length 1, and scaled by 1/255 (the same
    rescaling the training generators apply).

    Args:
        file_path: Path of the image file to load.

    Returns:
        The scaled image array with one extra leading axis.
    """
    resized = image.load_img(file_path, target_size=(150, 150))
    pixels = image.img_to_array(resized)
    # Prepend the batch axis, then scale pixel values into [0, 1].
    batch = pixels[np.newaxis, ...]
    return batch / 255.0
@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image with the loaded model.

    Expects a multipart POST whose ``file`` field holds the image.

    Returns:
        On success, JSON with:
          * ``"percentage"``: the model's raw output for the image as a float
            (the value is NOT multiplied by 100).
          * ``"class"``: the entry of ``classes`` selected by rounding that
            output to 0 or 1.
        On a missing or empty upload, JSON ``{"error": ...}`` with HTTP 400.
        On any other failure, JSON ``{"error": <message>}`` with HTTP 500.
    """
    # Local import: only this handler needs it.
    import tempfile

    try:
        # Reject malformed requests with a client-error status rather than 200.
        if 'file' not in request.files:
            return jsonify({"error": "Request Error"}), 400
        file = request.files['file']
        if file.filename == '':
            return jsonify({"error": "No selected file"}), 400

        # Only the (sanitised) extension of the client-supplied name is kept.
        # The file itself gets a unique generated name, so concurrent uploads
        # with the same name cannot overwrite each other and a name that
        # sanitises to "" cannot produce an invalid path.
        extension = os.path.splitext(secure_filename(file.filename))[1]
        os.makedirs("temp", exist_ok=True)
        handle, file_path = tempfile.mkstemp(suffix=extension, dir="temp")
        os.close(handle)
        try:
            file.save(file_path)
            input_data = preprocess_image(file_path)
        finally:
            # Always clean up, including when saving or decoding fails.
            os.remove(file_path)

        # Single-image batch -> single score in predictions[0][0].
        predictions = loaded_model.predict(input_data)
        percentage = float(predictions[0][0])
        predicted_class = int(np.round(predictions[0][0]))

        result = {
            "percentage": percentage,
            "class": classes[predicted_class]
        }
        return jsonify(result)
    except Exception as e:
        # Top-level boundary: report the failure to the caller as JSON.
        return jsonify({"error": str(e)}), 500
if __name__ == "__main__":
    # Debug mode enables the Werkzeug interactive debugger, which lets anyone
    # who can reach the port execute arbitrary Python. It is therefore opt-in
    # (FLASK_DEBUG=1) and, when enabled, the server binds to localhost only.
    debug_enabled = os.environ.get("FLASK_DEBUG") == "1"
    app.run(
        host="127.0.0.1" if debug_enabled else "0.0.0.0",
        port=8080,
        debug=debug_enabled,
        # The reloader re-executes this whole script in a child process,
        # which would repeat the training run above; keep it off.
        use_reloader=False,
    )