1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
|
- from typing import List, Optional, Tuple
- import cv2
- import numpy as np
- __all__ = ["load_video", "save_video"]
def load_video(file_path: str, max_frames: Optional[int] = None) -> Tuple[List[np.ndarray], float]:
    """Open a video file and extract each frame into a numpy array.

    :param file_path:   Path to the video file.
    :param max_frames:  Optional, maximum number of frames to extract. When None, frames are read until
                        the capture stops returning them.
    :return:
        - Frames representing the video, each in (H, W, C), RGB.
        - Frames per Second (FPS), exactly as reported by the capture's ``CAP_PROP_FPS`` property.
    :raises ValueError: If the video file cannot be opened (raised by ``_open_video``).
    """
    cap = _open_video(file_path)
    try:
        frames = _extract_frames(cap, max_frames)
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the capture even when reading or color conversion fails, so the handle is never leaked.
        cap.release()
    return frames, fps
def _open_video(file_path: str) -> cv2.VideoCapture:
    """Create a capture object for a video file and make sure it is usable.

    :param file_path: Path to the video file.
    :return: The opened video capture object.
    :raises ValueError: If the capture reports that it is not opened.
    """
    capture = cv2.VideoCapture(file_path)
    if capture.isOpened():
        return capture
    raise ValueError(f"Failed to open video file: {file_path}")
def _extract_frames(cap: cv2.VideoCapture, max_frames: Optional[int] = None) -> List[np.ndarray]:
    """Read frames one by one from an opened video capture object.

    Reading stops as soon as ``max_frames`` frames have been collected, or when the capture
    reports that a read did not succeed.

    :param cap:         Opened video capture object.
    :param max_frames:  Optional maximum number of frames to extract. With None, only a failed read stops the loop.
    :return:            Frames representing the video, each in (H, W, C), RGB.
    """
    collected = []
    while True:
        # With max_frames=None this comparison is never true, so there is no frame limit.
        if len(collected) == max_frames:
            break

        ok, bgr_frame = cap.read()
        if not ok:
            break

        # Frames are converted from BGR to RGB before being stored.
        collected.append(cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB))
    return collected
def save_video(output_path: str, frames: List[np.ndarray], fps: float) -> None:
    """Save a video locally, encoded with the "mp4v" codec.

    :param output_path: Where the video will be saved.
    :param frames:      Frames representing the video, each in (H, W, C), RGB.
                        Note that all the frames are expected to have the same shape.
    :param fps:         Frames per second. Integers are accepted as before; a float value
                        (such as the one returned by ``load_video``) is passed through unchanged.
    :raises ValueError:   If ``frames`` is empty (raised by ``_validate_frames``).
    :raises RuntimeError: If the frames do not share the same size or are not 3-channel images
                          (raised by ``_validate_frames``).
    """
    video_height, video_width = _validate_frames(frames)

    video_writer = cv2.VideoWriter(
        output_path,
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (video_width, video_height),  # The writer takes the frame size as (width, height).
    )
    try:
        for frame in frames:
            # Frames are stored as RGB in this module; convert to BGR before writing.
            video_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Always release the writer, even if a conversion or write fails midway.
        video_writer.release()
def _validate_frames(frames: List[np.ndarray]) -> Tuple[int, int]:
    """Validate the frames to make sure that every frame has the same size and includes the channel dimension. (i.e. (H, W, C))

    :param frames: Frames representing the video, each in (H, W, C), RGB. Note that all the frames are expected to have the same shape.
    :return: (Height, Width) of the video.
    :raises ValueError:   If ``frames`` is empty.
    :raises RuntimeError: If a frame is not a 3-dimensional array with 3 channels, or if the frames
                          do not all share the same (height, width).
    """
    # len() rather than truthiness, so that a stacked array of frames does not raise on the emptiness check.
    if len(frames) == 0:
        raise ValueError("Cannot validate an empty sequence of frames: at least one frame is required.")

    # Check dimensionality first: reading shape[1] on a frame with fewer than 2 dimensions would
    # otherwise fail with an unrelated IndexError before this error could be reported.
    if any(frame.ndim != 3 or frame.shape[-1] != 3 for frame in frames):
        raise RuntimeError("Your frames must include 3 channels.")

    heights = {frame.shape[0] for frame in frames}
    widths = {frame.shape[1] for frame in frames}
    min_height, max_height = min(heights), max(heights)
    min_width, max_width = min(widths), max(widths)

    if (min_height, min_width) != (max_height, max_width):
        raise RuntimeError(
            f"Your video is made of frames that have (height, width) going from ({min_height}, {min_width}) to ({max_height}, {max_width}).\n"
            f"Please make sure that all the frames have the same shape."
        )

    return max_height, max_width
|