Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#815 Feature/sg 747 Adding utils to load/save videos

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-747-support_predict_video
1 changed files with 99 additions and 0 deletions
  1. 99
    0
      src/super_gradients/training/utils/videos.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
  1. from typing import List, Optional, Tuple
  2. import cv2
  3. import numpy as np
  4. __all__ = ["load_video", "save_video"]
  5. def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], int]:
  6. """Open a video file and extract each frame into numpy array.
  7. :param file_path: Path to the video file.
  8. :param max_frames: Optional, maximum number of frames to extract.
  9. :return:
  10. - Frames representing the video, each in (H, W, C), RGB.
  11. - Frames per Second (FPS).
  12. """
  13. cap = _open_video(file_path)
  14. frames = _extract_frames(cap, max_frames)
  15. fps = cap.get(cv2.CAP_PROP_FPS)
  16. cap.release()
  17. return frames, fps
  18. def _open_video(file_path: str) -> cv2.VideoCapture:
  19. """Open a video file.
  20. :param file_path: Path to the video file
  21. :return: Opened video capture object
  22. """
  23. cap = cv2.VideoCapture(file_path)
  24. if not cap.isOpened():
  25. raise ValueError(f"Failed to open video file: {file_path}")
  26. return cap
  27. def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) -> List[np.ndarray]:
  28. """Extract frames from an opened video capture object.
  29. :param cap: Opened video capture object.
  30. :param max_frames: Optional maximum number of frames to extract.
  31. :return: Frames representing the video, each in (H, W, C), RGB.
  32. """
  33. frames = []
  34. while max_frames != len(frames):
  35. frame_read_success, frame = cap.read()
  36. if not frame_read_success:
  37. break
  38. frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
  39. return frames
  40. def save_video(output_path: str, frames: List[np.ndarray], fps: int) -> None:
  41. """Save a video locally.
  42. :param output_path: Where the video will be saved
  43. :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
  44. :param fps: Frames per second
  45. """
  46. video_height, video_width = _validate_frames(frames)
  47. video_writer = cv2.VideoWriter(
  48. output_path,
  49. cv2.VideoWriter_fourcc(*"mp4v"),
  50. fps,
  51. (video_width, video_height),
  52. )
  53. for frame in frames:
  54. video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
  55. video_writer.release()
  56. def _validate_frames(frames: List[np.ndarray]) -> Tuple[float, float]:
  57. """Validate the frames to make sure that every frame has the same size and includes the channel dimension. (i.e. (H, W, C))
  58. :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
  59. :return: (Height, Weight) of the video.
  60. """
  61. min_height = min(frame.shape[0] for frame in frames)
  62. max_height = max(frame.shape[0] for frame in frames)
  63. min_width = min(frame.shape[1] for frame in frames)
  64. max_width = max(frame.shape[1] for frame in frames)
  65. if (min_height, min_width) != (max_height, max_width):
  66. raise RuntimeError(
  67. f"Your video is made of frames that have (height, width) going from ({min_height}, {min_width}) to ({max_height}, {max_width}).\n"
  68. f"Please make sure that all the frames have the same shape."
  69. )
  70. if set(frame.ndim for frame in frames) != {3} or set(frame.shape[-1] for frame in frames) != {3}:
  71. raise RuntimeError("Your frames must include 3 channels.")
  72. return max_height, max_width
Discard