# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import logging
import math
import random
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple
import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image
from pytorchvideo.data.dataset_manifest_utils import (
ImageDataset,
ImageFrameInfo,
VideoClipInfo,
VideoDataset,
VideoDatasetType,
)
from pytorchvideo.data.utils import DataclassFieldCaster, load_dataclass_dict_from_csv
from pytorchvideo.data.video import Video
try:
import cv2
except ImportError:
_HAS_CV2 = False
else:
_HAS_CV2 = True
USER_ENVIRONMENT_MAP = {
0: "none",
1: "indoor",
2: "nature",
3: "crowded_environment",
4: "urban",
}
USER_ACTIVITY_MAP = {
0: "none",
1: "walking",
2: "running",
3: "standing",
4: "biking",
5: "driving",
6: "playing",
7: "cooking",
8: "eating",
9: "observing",
10: "in_conversation",
11: "browsing",
12: "shopping",
}
USER_ATTENTION_MAP = {
0: "none",
1: "paying_attention",
2: "interacting",
}
class LabelType(Enum):
Environment = 1
Activity = 2
UserAttention = 3
LABEL_TYPE_2_MAP = {
LabelType.Environment: USER_ENVIRONMENT_MAP,
LabelType.Activity: USER_ACTIVITY_MAP,
LabelType.UserAttention: USER_ATTENTION_MAP,
}
@dataclass
class LabelData(DataclassFieldCaster):
"""
Class representing a contiguous label for a video segment from the DoMSEV dataset.
"""
video_id: str
start_time: float # Start time of the label, in seconds
stop_time: float # Stop time of the label, in seconds
start_frame: int # 0-indexed ID of the start frame (inclusive)
stop_frame: int # 0-index ID of the stop frame (inclusive)
label_id: int
label_name: str
# Utility functions
def _seconds_to_frame_index(
time_in_seconds: float, fps: int, zero_indexed: Optional[bool] = True
) -> int:
"""
Converts a point in time (in seconds) within a video clip to its closest
frame indexed (rounding down), based on a specified frame rate.
Args:
time_in_seconds (float): The point in time within the video.
fps (int): The frame rate (frames per second) of the video.
zero_indexed (Optional[bool]): Whether the returned frame should be
zero-indexed (if True) or one-indexed (if False).
Returns:
(int) The index of the nearest frame (rounding down to the nearest integer).
"""
frame_idx = math.floor(time_in_seconds * fps)
if not zero_indexed:
frame_idx += 1
return frame_idx
def _get_overlap_for_time_range_pair(
t1_start: float, t1_stop: float, t2_start: float, t2_stop: float
) -> Optional[Tuple[float, float]]:
"""
Calculates the overlap between two time ranges, if one exists.
Returns:
(Optional[Tuple]) A tuple of <overlap_start_time, overlap_stop_time> if
an overlap is found, or None otherwise.
"""
# Check if there is an overlap
if (t1_start <= t2_stop) and (t2_start <= t1_stop):
# Calculate the overlap period
overlap_start_time = max(t1_start, t2_start)
overlap_stop_time = min(t1_stop, t2_stop)
return (overlap_start_time, overlap_stop_time)
else:
return None
[docs]class DomsevFrameDataset(torch.utils.data.Dataset):
"""
Egocentric video classification frame-based dataset for
`DoMSEV <https://www.verlab.dcc.ufmg.br/semantic-hyperlapse/cvpr2018-dataset/>`_
This dataset handles the loading, decoding, and configurable sampling for
the image frames.
"""
[docs] def __init__(
self,
video_data_manifest_file_path: str,
video_info_file_path: str,
labels_file_path: str,
transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
multithreaded_io: bool = False,
) -> None:
"""
Args:
video_data_manifest_file_path (str):
The path to a json file outlining the available video data for the
associated videos. File must be a csv (w/header) with columns:
``{[f.name for f in dataclass_fields(EncodedVideoInfo)]}``
To generate this file from a directory of video frames, see helper
functions in module: ``pytorchvideo.data.domsev.utils``
video_info_file_path (str):
Path or URI to manifest with basic metadata of each video.
File must be a csv (w/header) with columns:
``{[f.name for f in dataclass_fields(VideoInfo)]}``
labels_file_path (str):
Path or URI to manifest with temporal annotations for each video.
File must be a csv (w/header) with columns:
``{[f.name for f in dataclass_fields(LabelData)]}``
dataset_type (VideoDatasetType): The data format in which dataset
video data is stored (e.g. video frames, encoded video etc).
transform (Optional[Callable[[Dict[str, Any]], Any]]):
This callable is evaluated on the clip output before the clip is returned.
It can be used for user-defined preprocessing and augmentations to the clips.
The clip output format is described in __next__().
multithreaded_io (bool):
Boolean to control whether io operations are performed across multiple
threads.
"""
assert video_info_file_path
assert labels_file_path
assert video_data_manifest_file_path
## Populate image frame and metadata data providers ##
# Maps a image frame ID to an `ImageFrameInfo`
frames_dict: Dict[str, ImageFrameInfo] = ImageDataset._load_images(
video_data_manifest_file_path,
video_info_file_path,
multithreaded_io,
)
video_labels: Dict[str, List[LabelData]] = load_dataclass_dict_from_csv(
labels_file_path, LabelData, "video_id", list_per_key=True
)
# Maps an image frame ID to the singular frame label
self._labels_per_frame: Dict[
str, int
] = DomsevFrameDataset._assign_labels_to_frames(frames_dict, video_labels)
self._user_transform = transform
self._transform = self._transform_frame
# Shuffle the frames order for iteration
self._frames = list(frames_dict.values())
random.shuffle(self._frames)
@staticmethod
def _assign_labels_to_frames(
frames_dict: Dict[str, ImageFrameInfo],
video_labels: Dict[str, List[LabelData]],
):
"""
Args:
frames_dict: The mapping of <frame_id, ImageFrameInfo> for all the frames
in the dataset.
video_labels: The list of temporal labels for each video
Also unpacks one label per frame.
Also converts them to class IDs and then a tensor.
"""
labels_per_frame: Dict[str, int] = {}
for frame_id, image_info in frames_dict.items():
# Filter labels by only the ones that appear within the clip boundaries,
# and unpack the labels so there is one per frame in the clip
labels_in_video = video_labels[image_info.video_id]
for label in labels_in_video:
if (image_info.frame_number >= label.start_frame) and (
image_info.frame_number <= label.stop_frame
):
labels_per_frame[frame_id] = label.label_id
return labels_per_frame
[docs] def __getitem__(self, index) -> Dict[str, Any]:
"""
Samples an image frame associated to the given index.
Args:
index (int): index for the image frame
Returns:
An image frame with the following format if transform is None.
.. code-block:: text
{{
'frame_id': <str>,
'image': <image_tensor>,
'label': <label_tensor>,
}}
"""
frame = self._frames[index]
label_in_frame = self._labels_per_frame[frame.frame_id]
image_data = _load_image_from_path(frame.frame_file_path)
frame_data = {
"frame_id": frame.frame_id,
"image": image_data,
"label": label_in_frame,
}
if self._transform:
frame_data = self._transform(frame_data)
return frame_data
[docs] def __len__(self) -> int:
"""
Returns:
The number of frames in the dataset.
"""
return len(self._frames)
def _transform_frame(self, frame: Dict[str, Any]) -> Dict[str, Any]:
"""
Transforms a given image frame, according to some pre-defined transforms
and an optional user transform function (self._user_transform).
Args:
clip (Dict[str, Any]): The clip that will be transformed.
Returns:
(Dict[str, Any]) The transformed clip.
"""
for key in frame:
if frame[key] is None:
frame[key] = torch.tensor([])
if self._user_transform:
frame = self._user_transform(frame)
return frame
[docs]class DomsevVideoDataset(torch.utils.data.Dataset):
"""
Egocentric classification video clip-based dataset for
`DoMSEV <https://www.verlab.dcc.ufmg.br/semantic-hyperlapse/cvpr2018-dataset/>`_
stored as an encoded video (with frame-level labels).
This dataset handles the loading, decoding, and configurable clip
sampling for the videos.
"""
[docs] def __init__(
self,
video_data_manifest_file_path: str,
video_info_file_path: str,
labels_file_path: str,
clip_sampler: Callable[
[Dict[str, Video], Dict[str, List[LabelData]]], List[VideoClipInfo]
],
dataset_type: VideoDatasetType = VideoDatasetType.Frame,
frames_per_second: int = 1,
transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
frame_filter: Optional[Callable[[List[int]], List[int]]] = None,
multithreaded_io: bool = False,
) -> None:
"""
Args:
video_data_manifest_file_path (str):
The path to a json file outlining the available video data for the
associated videos. File must be a csv (w/header) with columns:
``{[f.name for f in dataclass_fields(EncodedVideoInfo)]}``
To generate this file from a directory of video frames, see helper
functions in module: ``pytorchvideo.data.domsev.utils``
video_info_file_path (str):
Path or URI to manifest with basic metadata of each video.
File must be a csv (w/header) with columns:
``{[f.name for f in dataclass_fields(VideoInfo)]}``
labels_file_path (str):
Path or URI to manifest with annotations for each video.
File must be a csv (w/header) with columns:
``{[f.name for f in dataclass_fields(LabelData)]}``
clip_sampler (Callable[[Dict[str, Video], Dict[str, List[LabelData]]],
List[VideoClipInfo]]):
Defines how clips should be sampled from each video. See the clip
sampling documentation for more information.
dataset_type (VideoDatasetType): The data format in which dataset
video data is stored (e.g. video frames, encoded video etc).
frames_per_second (int): The FPS of the stored videos. (NOTE:
this is variable and may be different than the original FPS
reported on the DoMSEV dataset website -- it depends on the
preprocessed subsampling and frame extraction).
transform (Optional[Callable[[Dict[str, Any]], Any]]):
This callable is evaluated on the clip output before the clip is returned.
It can be used for user-defined preprocessing and augmentations to the clips.
The clip output format is described in __next__().
frame_filter (Optional[Callable[[List[int]], List[int]]]):
This callable is evaluated on the set of available frame indices to be
included in a sampled clip. This can be used to subselect frames within
a clip to be loaded.
multithreaded_io (bool):
Boolean to control whether io operations are performed across multiple
threads.
"""
assert video_info_file_path
assert labels_file_path
assert video_data_manifest_file_path
# Populate video and metadata data providers
self._videos: Dict[str, Video] = VideoDataset._load_videos(
video_data_manifest_file_path,
video_info_file_path,
multithreaded_io,
dataset_type,
)
self._labels_per_video: Dict[
str, List[LabelData]
] = load_dataclass_dict_from_csv(
labels_file_path, LabelData, "video_id", list_per_key=True
)
# Sample datapoints
self._clips: List[VideoClipInfo] = clip_sampler(
self._videos, self._labels_per_video
)
self._frames_per_second = frames_per_second
self._user_transform = transform
self._transform = self._transform_clip
self._frame_filter = frame_filter
[docs] def __getitem__(self, index) -> Dict[str, Any]:
"""
Samples a video clip associated to the given index.
Args:
index (int): index for the video clip.
Returns:
A video clip with the following format if transform is None.
.. code-block:: text
{{
'video_id': <str>,
'video': <video_tensor>,
'audio': <audio_tensor>,
'labels': <labels_tensor>,
'start_time': <float>,
'stop_time': <float>
}}
"""
clip = self._clips[index]
# Filter labels by only the ones that appear within the clip boundaries,
# and unpack the labels so there is one per frame in the clip
labels_in_video = self._labels_per_video[clip.video_id]
labels_in_clip = []
for label_data in labels_in_video:
overlap_period = _get_overlap_for_time_range_pair(
clip.start_time,
clip.stop_time,
label_data.start_time,
label_data.stop_time,
)
if overlap_period is not None:
overlap_start_time, overlap_stop_time = overlap_period
# Convert the overlapping period between clip and label to
# 0-indexed start and stop frame indexes, so we can unpack 1
# label per frame.
overlap_start_frame = _seconds_to_frame_index(
overlap_start_time, self._frames_per_second
)
overlap_stop_frame = _seconds_to_frame_index(
overlap_stop_time, self._frames_per_second
)
# Append 1 label per frame
for _ in range(overlap_start_frame, overlap_stop_frame):
labels_in_clip.append(label_data)
# Convert the list of LabelData objects to a tensor of just the label IDs
label_ids = [labels_in_clip[i].label_id for i in range(len(labels_in_clip))]
label_ids_tensor = torch.tensor(label_ids)
clip_data = {
"video_id": clip.video_id,
**self._videos[clip.video_id].get_clip(clip.start_time, clip.stop_time),
"labels": label_ids_tensor,
"start_time": clip.start_time,
"stop_time": clip.stop_time,
}
if self._transform:
clip_data = self._transform(clip_data)
return clip_data
[docs] def __len__(self) -> int:
"""
Returns:
The number of video clips in the dataset.
"""
return len(self._clips)
def _transform_clip(self, clip: Dict[str, Any]) -> Dict[str, Any]:
"""
Transforms a given video clip, according to some pre-defined transforms
and an optional user transform function (self._user_transform).
Args:
clip (Dict[str, Any]): The clip that will be transformed.
Returns:
(Dict[str, Any]) The transformed clip.
"""
for key in clip:
if clip[key] is None:
clip[key] = torch.tensor([])
if self._user_transform:
clip = self._user_transform(clip)
return clip
def _load_image_from_path(image_path: str, num_retries: int = 10) -> Image:
"""
Loads the given image path using PathManager and decodes it as an RGB image.
Args:
image_path (str): the path to the image.
num_retries (int): number of times to retry image reading to handle transient error.
Returns:
A PIL Image of the image RGB data with shape:
(channel, height, width). The frames are of type np.uint8 and
in the range [0 - 255]. Raises an exception if unable to load images.
"""
if not _HAS_CV2:
raise ImportError(
"opencv2 is required to use FrameVideo. Please "
"install with 'pip install opencv-python'"
)
img_arr = None
for i in range(num_retries):
with g_pathmgr.open(image_path, "rb") as f:
img_str = np.frombuffer(f.read(), np.uint8)
img_bgr = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
if img_rgb is not None:
img_arr = img_rgb
break
else:
logging.warning(f"Reading attempt {i}/{num_retries} failed.")
time.sleep(1e-6)
if img_arr is None:
raise Exception("Failed to load image from {}".format(image_path))
pil_image = Image.fromarray(img_arr)
return pil_image