Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

api.py 3.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
  1. from fastapi import FastAPI, File, UploadFile, Form, HTTPException
  2. from pydantic import BaseModel
  3. from typing import Optional
  4. import uuid
  5. import io
  6. import cv2
  7. import numpy as np
  8. from helper import load_model, show_model_not_loaded_warning, model_path
  9. import logging
  10. import os
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)

  # FastAPI application object; routes below are registered on it.
  app = FastAPI()

  # Maps the integer class ids reported by the model to display labels.
  class_names = {
      0: "Fitoftora",
      1: "Monilia",
      2: "Sana",
      3: "Healthy",
  }
  class PredictionInfo(BaseModel):
      """Response schema describing the detections found in one image."""

      # Number of distinct class ids among the detected boxes.
      nb_classe: int
      # Total number of detected boxes.
      nb_box: int
      # Lowest confidence among the detected boxes.
      confidence_min: float
      # Labels (looked up in ``class_names``) of the detected classes.
      pred_classes: set
      # One 5-tuple per box: the four ``xyxy`` values followed by the box
      # confidence.
      bbx_coordinates: list[tuple[float, float, float, float, float]]
      # Identifiers generated for this prediction (random UUIDs).
      pred_img_id: uuid.UUID
      box_id: uuid.UUID
      diseases_id: uuid.UUID
      # Filesystem path of the saved image with the boxes drawn on it.
      det_image_path: str
  def predict_image(confidence: float, image_bytes: bytes, model):
      """Detect objects in an encoded image and summarise the detections.

      The bytes are decoded with OpenCV, converted from BGR to RGB and passed
      to ``model.predict``. The annotated frame is written to
      ``<cwd>/images_bbx/imagebbx_<pred_img_id>.jpg``.

      Args:
          confidence: Threshold forwarded to ``model.predict`` as ``conf``.
          image_bytes: Raw encoded image content (e.g. an uploaded file).
          model: Object whose ``predict(image, conf=...)`` returns a sequence;
              its first item must expose ``boxes`` and ``plot()``
              (presumably an Ultralytics YOLO model -- TODO confirm in
              ``helper.load_model``).

      Returns:
          PredictionInfo: counts, class labels, box coordinates, freshly
          generated UUIDs and the path of the annotated image. When nothing is
          detected, the counts are 0, the collections are empty and
          ``confidence_min`` is ``0.0``.

      Raises:
          HTTPException: 400 when the payload is empty or cannot be decoded as
              an image; 500 for any other failure (logged with its traceback).
      """
      try:
          # An empty buffer is handled explicitly so that it is reported as a
          # client error instead of surfacing as a decoder exception.
          image = None
          if image_bytes:
              encoded = np.frombuffer(image_bytes, np.uint8)
              image = cv2.imdecode(encoded, cv2.IMREAD_COLOR)
          if image is None:
              raise HTTPException(
                  status_code=400,
                  detail="L'image est vide. Assurez-vous de télécharger une image valide."
              )

          # NOTE(review): the frame is converted to RGB before inference;
          # confirm the model expects RGB arrays rather than OpenCV's BGR order.
          image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
          result = model.predict(image_rgb, conf=confidence)[0]
          boxes = result.boxes

          # Convert once to plain Python numbers so the response model only
          # ever receives built-in types.
          class_ids = [int(class_id) for class_id in boxes.cls.tolist()]
          confidences = boxes.conf.tolist()

          nb_classe = len(set(class_ids))
          nb_box = len(boxes)
          # ``default`` keeps an image without detections from raising.
          confidence_min = min(confidences, default=0.0)
          pred_classes = {class_names[class_id] for class_id in class_ids}

          pred_img_id = uuid.uuid4()
          box_id = uuid.uuid4()
          diseases_id = uuid.uuid4()

          # Box corners as reported by ``xyxy`` (no rescaling is applied here),
          # each followed by the confidence of that box.
          bbx_coordinates = [
              (x1, y1, x2, y2, score)
              for (x1, y1, x2, y2), score in zip(boxes.xyxy.tolist(), confidences)
          ]

          # Save the annotated image under the current working directory.
          save_dir = os.path.join(os.getcwd(), "images_bbx")
          os.makedirs(save_dir, exist_ok=True)
          image_with_boxes_path = os.path.join(save_dir, f"imagebbx_{pred_img_id}.jpg")
          # The channel axis of the rendered frame is reversed before it is
          # handed to OpenCV for writing.
          annotated = result.plot()[:, :, ::-1]
          if not cv2.imwrite(image_with_boxes_path, annotated):
              # imwrite signals failure through its return value.
              logger.warning(
                  "Impossible d'enregistrer l'image annotée : %s", image_with_boxes_path
              )

          return PredictionInfo(
              pred_img_id=pred_img_id,
              diseases_id=diseases_id,
              box_id=box_id,
              nb_classe=nb_classe,
              nb_box=nb_box,
              confidence_min=confidence_min,
              pred_classes=pred_classes,
              bbx_coordinates=bbx_coordinates,
              det_image_path=image_with_boxes_path
          )
      except HTTPException:
          # Let deliberate HTTP errors (e.g. the 400 above) reach the client
          # unchanged instead of being rewritten as a 500.
          raise
      except Exception as e:
          logger.exception("Une erreur est survenue pendant la prédiction : %s", e)
          raise HTTPException(status_code=500, detail="Une erreur interne est survenue.") from e
  @app.post("/predict_image/")
  async def predict_image_route(confidence: float = Form(...), image: UploadFile = File(...)):
      """Handle ``POST /predict_image/``: run detection on an uploaded image.

      Args:
          confidence: Required form field, forwarded to ``predict_image``.
          image: Required uploaded file whose content is read in full.

      Returns:
          The value returned by ``predict_image``.

      Raises:
          HTTPException: re-raised unchanged when one occurs; any other
              exception is logged and reported as a 500.
      """
      try:
          # The model is loaded before the upload is read, on every request.
          detector = load_model(model_path)
          payload = await image.read()
          return predict_image(confidence, payload, detector)
      except HTTPException:
          # Already an HTTP error: pass it through as-is.
          raise
      except Exception:
          logger.exception("Une erreur interne est survenue lors de la prédiction d'image.")
          raise HTTPException(status_code=500, detail="Une erreur interne est survenue.")
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...