1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
|
import asyncio
import io
import logging
import os
import threading
import uuid
from typing import Optional

import cv2
import numpy as np
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from pydantic import BaseModel

from helper import load_model, model_path, show_model_not_loaded_warning
app = FastAPI()
logger = logging.getLogger(__name__)

# Integer class id (as produced by the model's boxes.cls) -> display name.
# NOTE(review): "Sana" and "Healthy" both mean "healthy"; confirm this matches
# the class list the model was trained with.
class_names: dict[int, str] = {
    0: "Fitoftora",
    1: "Monilia",
    2: "Sana",
    3: "Healthy",
}
class PredictionInfo(BaseModel):
    """Result of one detection run, returned by the ``/predict_image/`` route.

    Field names are part of the JSON contract seen by API clients, so they are
    kept unchanged.
    """

    # Number of distinct class ids among the detected boxes.
    nb_classe: int
    # Total number of detected boxes.
    nb_box: int
    # Lowest confidence score among the detected boxes.
    confidence_min: float
    # Deduplicated display names of the detected classes.
    pred_classes: set[str]
    # One (x1, y1, x2, y2, confidence) tuple per detected box.
    bbx_coordinates: list[tuple[float, float, float, float, float]]
    # Identifier of this prediction; also embedded in the saved image file name.
    pred_img_id: uuid.UUID
    # Generated per prediction with uuid4; TODO(review): no link to any stored
    # record is visible in this module.
    box_id: uuid.UUID
    diseases_id: uuid.UUID
    # Filesystem path of the annotated image written to disk.
    det_image_path: str
def _decode_image(image_bytes: bytes) -> np.ndarray:
    """Decode raw upload bytes into an OpenCV (BGR) image or raise a 400."""
    image = None
    # cv2.imdecode asserts on an empty buffer, so only decode non-empty input.
    if image_bytes:
        image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        raise HTTPException(
            status_code=400,
            detail="L'image est vide. Assurez-vous de télécharger une image valide."
        )
    return image


def _save_annotated_image(result, pred_img_id: uuid.UUID) -> str:
    """Write the annotated detection image under ./images_bbx and return its path."""
    save_dir = os.path.join(os.getcwd(), "images_bbx")
    os.makedirs(save_dir, exist_ok=True)
    image_path = os.path.join(save_dir, f"imagebbx_{pred_img_id}.jpg")
    # Per the Ultralytics docs, Results.plot() returns a BGR numpy array for a
    # BGR input, which is the channel order cv2.imwrite expects.
    if not cv2.imwrite(image_path, result.plot()):
        logger.warning("Impossible d'enregistrer l'image annotée : %s", image_path)
    return image_path


def predict_image(confidence: float, image_bytes: bytes, model) -> PredictionInfo:
    """Run object detection on one uploaded image and summarise the result.

    Args:
        confidence: Minimum confidence threshold forwarded to ``model.predict``.
        image_bytes: Raw bytes of the uploaded image file.
        model: Detector with an Ultralytics-style API (``predict`` returning
            results with ``boxes.cls``/``boxes.conf``/``boxes.xyxy`` and ``plot()``).

    Returns:
        A PredictionInfo describing the detections. An annotated copy of the
        image is written to ``./images_bbx`` and its path is included.

    Raises:
        HTTPException: 400 if the bytes are empty or not a decodable image;
            500 for any other failure (logged with its traceback).
    """
    try:
        image = _decode_image(image_bytes)
        # Ultralytics expects numpy inputs in OpenCV's BGR order, so the decoded
        # image is passed as-is; converting to RGB would swap the model's channels.
        result = model.predict(image, conf=confidence)[0]
        boxes = result.boxes

        cls_ids = [int(c) for c in boxes.cls.tolist()]
        confs = boxes.conf.tolist()
        # Absolute pixel coordinates (x1, y1, x2, y2) plus the box score, as
        # plain floats so the response model can serialise them.
        bbx_coordinates = [
            (*xyxy, conf) for xyxy, conf in zip(boxes.xyxy.tolist(), confs)
        ]

        pred_img_id = uuid.uuid4()
        image_with_boxes_path = _save_annotated_image(result, pred_img_id)
        return PredictionInfo(
            pred_img_id=pred_img_id,
            diseases_id=uuid.uuid4(),
            box_id=uuid.uuid4(),
            nb_classe=len(set(cls_ids)),
            nb_box=len(boxes),
            # An image with no detections is valid; min() of [] would raise.
            confidence_min=min(confs, default=0.0),
            # Fall back to the raw id for classes missing from class_names.
            pred_classes={class_names.get(i, str(i)) for i in cls_ids},
            bbx_coordinates=bbx_coordinates,
            det_image_path=image_with_boxes_path,
        )
    except HTTPException:
        # Client errors (e.g. the 400 above) must not be masked as a 500.
        raise
    except Exception as e:
        logger.exception("Une erreur est survenue pendant la prédiction : %s", e)
        raise HTTPException(status_code=500, detail="Une erreur interne est survenue.") from e
# Serialises model loading + inference across worker threads. This keeps the
# original one-request-at-a-time inference: Ultralytics documents that a single
# model instance is not safe to share between threads. NOTE(review): whether
# load_model returns a shared/cached instance is not visible here.
_inference_lock = threading.Lock()


def _run_inference(confidence: float, image_bytes: bytes) -> PredictionInfo:
    """Load the model and run predict_image under the inference lock."""
    with _inference_lock:
        model = load_model(model_path)
        return predict_image(confidence, image_bytes, model)


@app.post("/predict_image/")
async def predict_image_route(confidence: float = Form(...), image: UploadFile = File(...)):
    """Handle an image upload and return its PredictionInfo.

    Args:
        confidence: Detection confidence threshold (form field).
        image: Uploaded image file.

    Raises:
        HTTPException: propagated unchanged from predict_image (e.g. 400 for an
            invalid image); any other error becomes a logged 500.
    """
    try:
        image_bytes = await image.read()
        # Model loading and inference are blocking calls; run them in a worker
        # thread so the event loop keeps serving other requests meanwhile.
        return await asyncio.to_thread(_run_inference, confidence, image_bytes)
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Une erreur interne est survenue lors de la prédiction d'image.")
        raise HTTPException(status_code=500, detail="Une erreur interne est survenue.") from e
|