server.py 4.7 KB

import io
from enum import Enum
from typing import Any, Dict, Optional

import numpy as np
import torch
from fastapi import FastAPI, File
from PIL import Image
from pydantic import BaseModel
from starlette.responses import HTMLResponse, Response

from deadtrees.data.deadtreedata import val_transform
from deadtrees.deployment.inference import ONNXInference, PyTorchInference
from deadtrees.deployment.models import PredictionStats, predictionstats_to_str
from deadtrees.utils.timer import record_execution_time

MODEL = "bestmodel"

# TODO: make this an endpoint
pytorch_model = PyTorchInference(f"checkpoints/{MODEL}.ckpt")
onnx_model = ONNXInference(f"checkpoints/{MODEL}.onnx")

app = FastAPI(
    title="DeadTrees image segmentation",
    description="""Obtain semantic segmentation maps of an input image via our UNet implemented in PyTorch.
Visit this URL at port 8501 for the streamlit interface.""",
    version="0.1.0",
)


@app.get("/", response_class=HTMLResponse, include_in_schema=False)
async def root():
    """Landing page: short intro, then redirect to the OpenAPI docs."""
    return """\
<!doctype html>
<html lang="en">
  <head>
    <!-- Required meta tags -->
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">

    <!-- Bootstrap CSS -->
    <link rel="stylesheet" href="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/css/bootstrap.min.css" integrity="sha384-ggOyR0iXCbMQv3Xipma34MD+dH/1fQ784/j6cY/iJTQUOhcWr7x9JvoRxT2MZw1T" crossorigin="anonymous">

    <title>DeadTrees Inference API</title>
    <meta http-equiv="refresh" content="7; URL=./docs" />
  </head>
  <body>
    <div class="d-flex vh-100">
      <div class="d-flex w-100 justify-content-center align-self-center">
        <div class="jumbotron">
          <h1 class="display-4">🌲☠️🌲🌲🌲 DeadTrees Inference API 🌲🌲☠️☠️🌲</h1>
          <p class="lead">REST API for semantic segmentation of dead trees from ortho photos</p>
          <hr class="my-4">
          <p>
            There is also an <a href="./" onmouseover="javascript:event.target.port=8502">interactive streamlit frontend</a>.
            You will be redirected to the <a href="./docs"><b>OpenAPI documentation page</b></a> in 7 seconds.
          </p>
        </div>
      </div>
    </div>

    <!-- Optional JavaScript -->
    <!-- jQuery first, then Popper.js, then Bootstrap JS -->
    <script src="https://code.jquery.com/jquery-3.3.1.slim.min.js" integrity="sha384-q8i/X+965DzO0rT7abK41JStQIAqVgRVzpbzo5smXKp4YfRvH+8abtTE1Pi6jizo" crossorigin="anonymous"></script>
    <script src="https://cdnjs.cloudflare.com/ajax/libs/popper.js/1.14.7/umd/popper.min.js" integrity="sha384-UO2eT0CpHqdSJQ6hJty5KVphtPhzWj9WO1clHTMGa3JDZwrnQq4sF86dIHNDz0W1" crossorigin="anonymous"></script>
    <script src="https://stackpath.bootstrapcdn.com/bootstrap/4.3.1/js/bootstrap.min.js" integrity="sha384-JjSmVgyd0p3pXB1rRibZUAYoIIy6OrQ6VrjIEaFf/nJGzIxFDsf4x0xIM+B07jRM" crossorigin="anonymous"></script>
  </body>
</html>
"""


class ModelTypes(Enum):
    """Allowed model types."""

    PYTORCH = "pytorch"
    ONNX = "onnx"


def split_image_into_tiles(image: Image.Image):
    # TODO: complete this (what about batches?)
    batch = val_transform(image=np.array(image))["image"]
    return batch


@app.post("/segmentation")
def get_segmentation_map(
    file: bytes = File(...), model_type: Optional[ModelTypes] = None
):
    """Get a segmentation map from an uploaded image file."""

    model_type = model_type or ModelTypes.PYTORCH

    # decode the upload and apply the validation transform used during training
    image = Image.open(io.BytesIO(file)).convert("RGB")
    input_tensor = val_transform(image=np.array(image))["image"]

    # run the prediction and measure execution time
    with record_execution_time() as elapsed:
        if model_type == ModelTypes.PYTORCH:
            out = pytorch_model.run(input_tensor)
        elif model_type == ModelTypes.ONNX:
            out = onnx_model.run(input_tensor.detach().cpu().numpy())
        else:
            raise ValueError("only pytorch and onnx models allowed")

    if isinstance(out, torch.Tensor):
        out = out.detach().cpu().numpy()

    # TODO: compose batch if required
    # convert the binary mask to an 8-bit grayscale image
    image = Image.fromarray(np.uint8(out * 255), "L")

    # fraction of pixels classified as dead trees
    dead_tree_fraction = float(out.sum() / out.size)

    stats = PredictionStats(
        fraction=dead_tree_fraction,
        model_name=MODEL,
        model_type=model_type.value,
        elapsed=elapsed(),
    )

    # return the PNG mask, with prediction stats passed along as response headers
    bytes_io = io.BytesIO()
    image.save(bytes_io, format="PNG")

    return Response(
        bytes_io.getvalue(),
        headers=predictionstats_to_str(stats),
        media_type="image/png",
    )
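
For reference, a call against the /segmentation endpoint could look like the sketch below. It assumes the server runs locally on uvicorn's default port 8000 and that ortho_tile.png is an image on disk; both are assumptions, not part of the file above. The exact stats header names depend on predictionstats_to_str, so the sketch just prints all response headers.

import requests

# hypothetical client call; host, port, and file names are assumptions
with open("ortho_tile.png", "rb") as f:
    response = requests.post(
        "http://localhost:8000/segmentation",
        params={"model_type": "onnx"},  # optional; the endpoint defaults to "pytorch"
        files={"file": f},
    )
response.raise_for_status()

# the endpoint returns the segmentation mask as a PNG body
with open("segmentation.png", "wb") as out:
    out.write(response.content)

# prediction stats (dead-tree fraction, model name/type, elapsed time) arrive as headers
print(dict(response.headers))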
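
The batching TODOs in split_image_into_tiles and get_segmentation_map are still open. A minimal sketch of one possible approach is shown below: cut a large ortho photo into fixed-size tiles and stack the transformed tiles into a batch tensor. The helper name split_image_into_batch, the tile size, the skipping of ragged edge tiles, and the assumption that val_transform returns a CHW torch tensor are all guesses, not part of the repository.

import numpy as np
import torch
from PIL import Image

from deadtrees.data.deadtreedata import val_transform


def split_image_into_batch(image: Image.Image, tile_size: int = 256) -> torch.Tensor:
    """Hypothetical helper: tile a large image and stack the tiles into a batch."""
    arr = np.array(image.convert("RGB"))
    tiles = []
    for y in range(0, arr.shape[0], tile_size):
        for x in range(0, arr.shape[1], tile_size):
            tile = arr[y : y + tile_size, x : x + tile_size]
            # this sketch skips ragged edge tiles; real code would pad them instead
            if tile.shape[0] == tile_size and tile.shape[1] == tile_size:
                tiles.append(val_transform(image=tile)["image"])
    return torch.stack(tiles)  # shape: (n_tiles, C, tile_size, tile_size)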