Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

ui.py 4.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
  1. import io
  2. import pathlib
  3. import textwrap
  4. from enum import Enum
  5. import requests
  6. import streamlit as st
  7. import streamlit.components.v1 as components
  8. from models import PredictionStats
  9. from requests_toolbelt.multipart.encoder import MultipartEncoder
  10. from PIL import Image
  11. # Source: https://github.com/robmarkcole/streamlit-image-juxtapose.git
  12. def juxtapose(img1: str, img2: str, height: int = 1000): # data
  13. """Create a new timeline component.
  14. Parameters
  15. ----------
  16. height: int or None
  17. Height of the timeline in px
  18. Returns
  19. -------
  20. static_component: Boolean
  21. Returns a static component with a timeline
  22. """
  23. # load css + js
  24. cdn_path = "https://cdn.knightlab.com/libs/juxtapose/latest"
  25. css_block = f'<link rel="stylesheet" href="{cdn_path}/css/juxtapose.css">'
  26. js_block = f'<script src="{cdn_path}/js/juxtapose.min.js"></script>'
  27. # write html block
  28. htmlcode = (
  29. css_block
  30. + """
  31. """
  32. + js_block
  33. + """
  34. <div id="foo" style="width: 95%; height: """
  35. + str(height)
  36. + '''px; margin: 1px;"></div>
  37. <script>
  38. slider = new juxtapose.JXSlider('#foo',
  39. [
  40. {
  41. src: "'''
  42. + img1
  43. + '''",
  44. label: 'source',
  45. },
  46. {
  47. src: "'''
  48. + img2
  49. + """",
  50. label: 'prediction',
  51. }
  52. ],
  53. {
  54. animate: true,
  55. showLabels: true,
  56. showCredits: true,
  57. startingPosition: "50%",
  58. makeResponsive: true
  59. });
  60. </script>
  61. """
  62. )
  63. static_component = components.html(
  64. htmlcode,
  65. height=height,
  66. )
  67. return static_component
# Streamlit serves files placed under its package "static" directory, so
# images saved there are reachable by bare filename from embedded HTML
# (required by the juxtapose slider).
STREAMLIT_STATIC_PATH = (
    pathlib.Path(st.__path__[0]) / "static"
)  # e.g. venv/lib/python3.9/site-packages/streamlit/static

# FastAPI segmentation endpoint; "backend" is presumably the docker-compose
# service hostname — verify against the deployment config.
backend = "http://backend:8000/segmentation"
  73. # TODO: refactor to central location
  74. class ModelTypes(Enum):
  75. """allowed model types"""
  76. PYTORCH = "pytorch"
  77. ONNX = "onnx"
  78. def process(image: bytes, server_url: str):
  79. m = MultipartEncoder(fields={"file": ("filename", image, "image/jpeg")})
  80. r = requests.post(
  81. server_url, data=m, headers={"Content-Type": m.content_type}, timeout=8000
  82. )
  83. return {
  84. "mask": r.content,
  85. "stats": PredictionStats.parse_obj(r.headers),
  86. }
# --- construct UI layout ---
st.title("DeadTree image segmentation")
st.write(
    """Obtain semantic segmentation maps of the image in input via our UNet implemented in PyTorch.
Visit this URL at port 8000 for REST API."""
)  # description and instructions

# human-readable labels for the model-format selector
inf_types = {
    ModelTypes.PYTORCH: "PyTorch (native)",
    ModelTypes.ONNX: "ONNX",
}

# NOTE(review): st.beta_columns / use_column_width are deprecated and removed
# in newer Streamlit (replaced by st.columns / use_container_width) — this
# file must be pinned to an old Streamlit version; confirm before upgrading.
col1, col2 = st.beta_columns(2)
itype = col1.selectbox(
    "Inference type", list(inf_types.keys()), format_func=inf_types.get
)
vtype = col2.radio("Display", ("Side-by-side", "Slider"), index=1)

input_image = st.file_uploader("Insert Image")  # image upload widget

if st.button("Get Segmentation Map"):
    if input_image:
        # send the upload to the backend, selecting the model format via query arg
        result = process(input_image, f"{backend}?model_type={itype.value}")
        rgb_image = Image.open(input_image).convert("RGB")
        mask_image = Image.open(io.BytesIO(result["mask"])).convert("RGB")
        if vtype == "Side-by-side":
            # original and predicted mask rendered in two columns
            col1, col2 = st.beta_columns(2)
            col1.header("Source")
            col1.image(rgb_image, use_column_width=True)
            col2.header("Prediction")
            col2.image(mask_image, use_column_width=True)
        else:
            # the juxtapose slider loads images over HTTP, so both must be
            # written into Streamlit's static dir first
            IMG1 = "source.png"
            IMG2 = "prediction.png"
            rgb_image.save(STREAMLIT_STATIC_PATH / IMG1)
            mask_image.save(STREAMLIT_STATIC_PATH / IMG2)
            juxtapose(IMG1, IMG2, height=600)
        # inference statistics parsed from the response headers
        stats = result["stats"]
        st.markdown(
            textwrap.dedent(
                f"""\
                ### Stats 📊
                Model: **{stats.model_name}**
                Format: **{stats.model_type}**
                Percentage of dead trees detected: **{stats.fraction*100:.2f}%**
                Inference duration: **{stats.elapsed:.1f}sec**
                """  # noqa
            )
        )
    else:
        # handle case with no image
        st.write("Insert an image!")
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...