Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

app.py 3.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
  1. import streamlit as st
  2. from PIL import Image, ImageOps
  3. import cv2
  4. import tensorflow as tf
  5. from tensorflow.keras.models import load_model
  6. from potholeClassifier.utils.common import read_yaml
  7. import numpy as np
  8. from pathlib import Path
  9. class Classifier:
  10. def __init__(self, config_file_path: Path):
  11. """
  12. Constructor method for the Classifier class.
  13. Args:
  14. config_file_path (Path): Path to the configuration file.
  15. """
  16. self.config_file_path = config_file_path
  17. self.config = read_yaml(self.config_file_path)
  18. self.class_names = ["No Pothole", "Has Pothole"]
  19. self.model = None
  20. def load_best_model(self) -> None:
  21. """
  22. Method to load the best trained model from the specified path.
  23. """
  24. if self.model is None:
  25. self.model = load_model(self.config.training.trained_model_path)
  26. def import_and_predict(self, image_data: Image) -> np.ndarray:
  27. """
  28. Method to preprocess the image data and make predictions using the loaded model.
  29. Args:
  30. image_data (Image): Input image data.
  31. Returns:
  32. np.array: Predicted class probabilities.
  33. """
  34. if self.model is None:
  35. raise ValueError("Model not loaded. Please load the model first.")
  36. size = (224, 224)
  37. image = ImageOps.fit(image_data, size)
  38. image = np.asarray(image)
  39. img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  40. img_reshape = img[np.newaxis, ...]
  41. prediction = self.model.predict(img_reshape)
  42. return prediction
  43. def classify_image(self, image: Image) -> str:
  44. """
  45. Method to classify the input image.
  46. Args:
  47. image (Image): Input image data.
  48. Returns:
  49. str: Classification result text.
  50. """
  51. predictions = self.import_and_predict(image)
  52. score = tf.nn.softmax(predictions[0])
  53. result_text = "This image most likely belongs to the <b>{}</b> class.".format(
  54. self.class_names[np.argmax(score)])
  55. return result_text
  56. class StreamlitApp:
  57. def __init__(self, classifier: Classifier):
  58. """
  59. Constructor method for the StreamlitApp class.
  60. Args:
  61. classifier (Classifier): Instance of the Classifier class.
  62. """
  63. self.classifier = classifier
  64. def load_model(self) -> None:
  65. """
  66. Method to run the Streamlit application.
  67. """
  68. with st.spinner('Model is being loaded..'):
  69. self.classifier.load_best_model()
  70. def display_intro(self) -> None:
  71. """Display the intro text"""
  72. st.markdown("## Pothole Image Classification", unsafe_allow_html=True)
  73. st.write("""
  74. Potholes are fatal and can cause severe damage to vehicles as well as can cause deadly accidents. In South Asian countries, pavement
  75. distresses are the primary cause due to poor subgrade conditions, lack of subsurface drainage, and excessive rainfalls. This prediction service classifies images to find whether they have potholes or not.
  76. """)
  77. def run(self) -> None:
  78. """
  79. Run the app
  80. """
  81. self.load_model()
  82. self.display_intro()
  83. file = st.file_uploader(
  84. "Please upload the image file", type=[
  85. "jpg", "png"])
  86. if file is None:
  87. st.text("File has not been uploaded yet.")
  88. else:
  89. image = Image.open(file)
  90. st.image(image, use_column_width=True)
  91. result_text = self.classifier.classify_image(image)
  92. st.markdown(result_text, unsafe_allow_html=True)
  93. if __name__ == "__main__":
  94. CONFIG_FILE_PATH = Path("config/config.yaml")
  95. classifier = Classifier(CONFIG_FILE_PATH)
  96. app = StreamlitApp(classifier)
  97. app.run()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...