
Source code for pytorchvideo.data.domsev

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import logging
import math
import random
import time
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch
from iopath.common.file_io import g_pathmgr
from PIL import Image
from pytorchvideo.data.dataset_manifest_utils import (
    ImageDataset,
    ImageFrameInfo,
    VideoClipInfo,
    VideoDataset,
    VideoDatasetType,
)
from pytorchvideo.data.utils import (
    DataclassFieldCaster,
    load_dataclass_dict_from_csv,
)
from pytorchvideo.data.video import Video


try:
    import cv2
except ImportError:
    _HAS_CV2 = False
else:
    _HAS_CV2 = True


USER_ENVIRONMENT_MAP = {
    0: "none",
    1: "indoor",
    2: "nature",
    3: "crowded_environment",
    4: "urban",
}

USER_ACTIVITY_MAP = {
    0: "none",
    1: "walking",
    2: "running",
    3: "standing",
    4: "biking",
    5: "driving",
    6: "playing",
    7: "cooking",
    8: "eating",
    9: "observing",
    10: "in_conversation",
    11: "browsing",
    12: "shopping",
}

USER_ATTENTION_MAP = {
    0: "none",
    1: "paying_attention",
    2: "interacting",
}


class LabelType(Enum):
    Environment = 1
    Activity = 2
    UserAttention = 3


LABEL_TYPE_2_MAP = {
    LabelType.Environment: USER_ENVIRONMENT_MAP,
    LabelType.Activity: USER_ACTIVITY_MAP,
    LabelType.UserAttention: USER_ATTENTION_MAP,
}
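
# Illustrative example (not part of the original module): the maps above resolve
# an integer label ID to a human-readable name for a given label type.
#
#   LABEL_TYPE_2_MAP[LabelType.Activity][4]      # -> "biking"
#   LABEL_TYPE_2_MAP[LabelType.Environment][2]   # -> "nature"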


@dataclass
class LabelData(DataclassFieldCaster):
    """
    Class representing a contiguous label for a video segment from the DoMSEV dataset.
    """

    video_id: str
    start_time: float  # Start time of the label, in seconds
    stop_time: float  # Stop time of the label, in seconds
    start_frame: int  # 0-indexed ID of the start frame (inclusive)
    stop_frame: int  # 0-indexed ID of the stop frame (inclusive)
    label_id: int
    label_name: str
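
# Illustrative example (not part of the original module): a LabelData entry for a
# hypothetical 5 fps video, covering seconds 10.0-12.0 (frames 50-60 inclusive).
#
#   LabelData(
#       video_id="video_001",   # hypothetical video ID
#       start_time=10.0,
#       stop_time=12.0,
#       start_frame=50,
#       stop_frame=60,
#       label_id=4,
#       label_name="biking",
#   )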


# Utility functions
def _seconds_to_frame_index(
    time_in_seconds: float, fps: int, zero_indexed: Optional[bool] = True
) -> int:
    """
    Converts a point in time (in seconds) within a video clip to its closest
    frame index (rounding down), based on a specified frame rate.

    Args:
        time_in_seconds (float): The point in time within the video.
        fps (int): The frame rate (frames per second) of the video.
        zero_indexed (Optional[bool]): Whether the returned frame should be
            zero-indexed (if True) or one-indexed (if False).

    Returns:
        (int) The index of the nearest frame (rounding down to the nearest integer).
    """
    frame_idx = math.floor(time_in_seconds * fps)
    if not zero_indexed:
        frame_idx += 1
    return frame_idx
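
# Illustrative usage (not part of the original module): at 30 fps, the time
# 2.5 s falls on frame 75 when zero-indexed, or frame 76 when one-indexed.
#
#   _seconds_to_frame_index(2.5, fps=30)                      # -> 75
#   _seconds_to_frame_index(2.5, fps=30, zero_indexed=False)  # -> 76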


def _get_overlap_for_time_range_pair(
    t1_start: float, t1_stop: float, t2_start: float, t2_stop: float
) -> Optional[Tuple[float, float]]:
    """
    Calculates the overlap between two time ranges, if one exists.

    Returns:
        (Optional[Tuple]) A tuple of <overlap_start_time, overlap_stop_time> if
        an overlap is found, or None otherwise.
    """
    # Check if there is an overlap
    if (t1_start <= t2_stop) and (t2_start <= t1_stop):
        # Calculate the overlap period
        overlap_start_time = max(t1_start, t2_start)
        overlap_stop_time = min(t1_stop, t2_stop)
        return (overlap_start_time, overlap_stop_time)
    else:
        return None
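
# Illustrative usage (not part of the original module): the ranges [1.0, 4.0]
# and [3.0, 6.0] overlap on [3.0, 4.0]; disjoint ranges return None.
#
#   _get_overlap_for_time_range_pair(1.0, 4.0, 3.0, 6.0)  # -> (3.0, 4.0)
#   _get_overlap_for_time_range_pair(1.0, 2.0, 3.0, 6.0)  # -> None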


class DomsevFrameDataset(torch.utils.data.Dataset):
    """
    Egocentric video classification frame-based dataset for
    `DoMSEV <https://www.verlab.dcc.ufmg.br/semantic-hyperlapse/cvpr2018-dataset/>`_

    This dataset handles the loading, decoding, and configurable sampling for
    the image frames.
    """

    def __init__(
        self,
        video_data_manifest_file_path: str,
        video_info_file_path: str,
        labels_file_path: str,
        transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
        multithreaded_io: bool = False,
    ) -> None:
        """
        Args:
            video_data_manifest_file_path (str):
                The path to a manifest file outlining the available video data
                for the associated videos. File must be a csv (w/header) with columns:
                ``{[f.name for f in dataclass_fields(EncodedVideoInfo)]}``

                To generate this file from a directory of video frames, see helper
                functions in module: ``pytorchvideo.data.domsev.utils``

            video_info_file_path (str):
                Path or URI to manifest with basic metadata of each video.
                File must be a csv (w/header) with columns:
                ``{[f.name for f in dataclass_fields(VideoInfo)]}``

            labels_file_path (str):
                Path or URI to manifest with temporal annotations for each video.
                File must be a csv (w/header) with columns:
                ``{[f.name for f in dataclass_fields(LabelData)]}``

            transform (Optional[Callable[[Dict[str, Any]], Any]]):
                This callable is evaluated on the frame output before the frame is
                returned. It can be used for user-defined preprocessing and
                augmentations of the frames. The frame output format is described
                in __getitem__().

            multithreaded_io (bool):
                Boolean to control whether io operations are performed across
                multiple threads.
        """
        assert video_info_file_path
        assert labels_file_path
        assert video_data_manifest_file_path

        ## Populate image frame and metadata data providers ##
        # Maps an image frame ID to an `ImageFrameInfo`
        frames_dict: Dict[str, ImageFrameInfo] = ImageDataset._load_images(
            video_data_manifest_file_path,
            video_info_file_path,
            multithreaded_io,
        )
        video_labels: Dict[str, List[LabelData]] = load_dataclass_dict_from_csv(
            labels_file_path, LabelData, "video_id", list_per_key=True
        )
        # Maps an image frame ID to the singular frame label
        self._labels_per_frame: Dict[
            str, int
        ] = DomsevFrameDataset._assign_labels_to_frames(frames_dict, video_labels)

        self._user_transform = transform
        self._transform = self._transform_frame

        # Shuffle the frames order for iteration
        self._frames = list(frames_dict.values())
        random.shuffle(self._frames)

    @staticmethod
    def _assign_labels_to_frames(
        frames_dict: Dict[str, ImageFrameInfo],
        video_labels: Dict[str, List[LabelData]],
    ):
        """
        Assigns a single label ID to each frame by finding the temporal label
        whose frame range contains that frame.

        Args:
            frames_dict: The mapping of <frame_id, ImageFrameInfo> for all the
                frames in the dataset.
            video_labels: The list of temporal labels for each video.
        """
        labels_per_frame: Dict[str, int] = {}
        for frame_id, image_info in frames_dict.items():
            # Find the temporal label whose frame range contains this frame,
            # and record its label ID for the frame
            labels_in_video = video_labels[image_info.video_id]
            for label in labels_in_video:
                if (image_info.frame_number >= label.start_frame) and (
                    image_info.frame_number <= label.stop_frame
                ):
                    labels_per_frame[frame_id] = label.label_id

        return labels_per_frame

    def __getitem__(self, index) -> Dict[str, Any]:
        """
        Samples an image frame associated to the given index.

        Args:
            index (int): index for the image frame

        Returns:
            An image frame with the following format if transform is None.

            .. code-block:: text

                {{
                    'frame_id': <str>,
                    'image': <image_tensor>,
                    'label': <label_tensor>,
                }}
        """
        frame = self._frames[index]
        label_in_frame = self._labels_per_frame[frame.frame_id]

        image_data = _load_image_from_path(frame.frame_file_path)

        frame_data = {
            "frame_id": frame.frame_id,
            "image": image_data,
            "label": label_in_frame,
        }

        if self._transform:
            frame_data = self._transform(frame_data)

        return frame_data

    def __len__(self) -> int:
        """
        Returns:
            The number of frames in the dataset.
        """
        return len(self._frames)

    def _transform_frame(self, frame: Dict[str, Any]) -> Dict[str, Any]:
        """
        Transforms a given image frame, according to some pre-defined transforms
        and an optional user transform function (self._user_transform).

        Args:
            frame (Dict[str, Any]): The frame that will be transformed.

        Returns:
            (Dict[str, Any]) The transformed frame.
        """
        for key in frame:
            if frame[key] is None:
                frame[key] = torch.tensor([])

        if self._user_transform:
            frame = self._user_transform(frame)

        return frame

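
# Minimal usage sketch (not part of the original module). The manifest paths
# below are hypothetical placeholders; the expected CSV schemas are described
# in DomsevFrameDataset.__init__.
#
#   frame_dataset = DomsevFrameDataset(
#       video_data_manifest_file_path="frame_manifest.csv",  # hypothetical path
#       video_info_file_path="video_info.csv",               # hypothetical path
#       labels_file_path="activity_labels.csv",              # hypothetical path
#   )
#   sample = frame_dataset[0]
#   # sample is a dict: {'frame_id': <str>, 'image': <PIL.Image>, 'label': <int>}
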
class DomsevVideoDataset(torch.utils.data.Dataset):
    """
    Egocentric classification video clip-based dataset for
    `DoMSEV <https://www.verlab.dcc.ufmg.br/semantic-hyperlapse/cvpr2018-dataset/>`_
    stored as an encoded video (with frame-level labels).

    This dataset handles the loading, decoding, and configurable clip
    sampling for the videos.
    """

    def __init__(
        self,
        video_data_manifest_file_path: str,
        video_info_file_path: str,
        labels_file_path: str,
        clip_sampler: Callable[
            [Dict[str, Video], Dict[str, List[LabelData]]], List[VideoClipInfo]
        ],
        dataset_type: VideoDatasetType = VideoDatasetType.Frame,
        frames_per_second: int = 1,
        transform: Optional[Callable[[Dict[str, Any]], Any]] = None,
        frame_filter: Optional[Callable[[List[int]], List[int]]] = None,
        multithreaded_io: bool = False,
    ) -> None:
        """
        Args:
            video_data_manifest_file_path (str):
                The path to a manifest file outlining the available video data
                for the associated videos. File must be a csv (w/header) with columns:
                ``{[f.name for f in dataclass_fields(EncodedVideoInfo)]}``

                To generate this file from a directory of video frames, see helper
                functions in module: ``pytorchvideo.data.domsev.utils``

            video_info_file_path (str):
                Path or URI to manifest with basic metadata of each video.
                File must be a csv (w/header) with columns:
                ``{[f.name for f in dataclass_fields(VideoInfo)]}``

            labels_file_path (str):
                Path or URI to manifest with annotations for each video.
                File must be a csv (w/header) with columns:
                ``{[f.name for f in dataclass_fields(LabelData)]}``

            clip_sampler (Callable[[Dict[str, Video], Dict[str, List[LabelData]]],
                List[VideoClipInfo]]):
                Defines how clips should be sampled from each video. See the clip
                sampling documentation for more information.

            dataset_type (VideoDatasetType): The data format in which dataset
                video data is stored (e.g. video frames, encoded video etc).

            frames_per_second (int): The FPS of the stored videos. (NOTE:
                this is variable and may be different than the original FPS
                reported on the DoMSEV dataset website -- it depends on the
                preprocessed subsampling and frame extraction).

            transform (Optional[Callable[[Dict[str, Any]], Any]]):
                This callable is evaluated on the clip output before the clip is
                returned. It can be used for user-defined preprocessing and
                augmentations to the clips. The clip output format is described
                in __getitem__().

            frame_filter (Optional[Callable[[List[int]], List[int]]]):
                This callable is evaluated on the set of available frame indices
                to be included in a sampled clip. This can be used to subselect
                frames within a clip to be loaded.

            multithreaded_io (bool):
                Boolean to control whether io operations are performed across
                multiple threads.
        """
        assert video_info_file_path
        assert labels_file_path
        assert video_data_manifest_file_path

        # Populate video and metadata data providers
        self._videos: Dict[str, Video] = VideoDataset._load_videos(
            video_data_manifest_file_path,
            video_info_file_path,
            multithreaded_io,
            dataset_type,
        )

        self._labels_per_video: Dict[
            str, List[LabelData]
        ] = load_dataclass_dict_from_csv(
            labels_file_path, LabelData, "video_id", list_per_key=True
        )

        # Sample datapoints
        self._clips: List[VideoClipInfo] = clip_sampler(
            self._videos, self._labels_per_video
        )

        self._frames_per_second = frames_per_second
        self._user_transform = transform
        self._transform = self._transform_clip
        self._frame_filter = frame_filter

    def __getitem__(self, index) -> Dict[str, Any]:
        """
        Samples a video clip associated to the given index.

        Args:
            index (int): index for the video clip.

        Returns:
            A video clip with the following format if transform is None.

            .. code-block:: text

                {{
                    'video_id': <str>,
                    'video': <video_tensor>,
                    'audio': <audio_tensor>,
                    'labels': <labels_tensor>,
                    'start_time': <float>,
                    'stop_time': <float>
                }}
        """
        clip = self._clips[index]

        # Filter labels by only the ones that appear within the clip boundaries,
        # and unpack the labels so there is one per frame in the clip
        labels_in_video = self._labels_per_video[clip.video_id]
        labels_in_clip = []
        for label_data in labels_in_video:
            overlap_period = _get_overlap_for_time_range_pair(
                clip.start_time,
                clip.stop_time,
                label_data.start_time,
                label_data.stop_time,
            )
            if overlap_period is not None:
                overlap_start_time, overlap_stop_time = overlap_period

                # Convert the overlapping period between clip and label to
                # 0-indexed start and stop frame indexes, so we can unpack 1
                # label per frame.
                overlap_start_frame = _seconds_to_frame_index(
                    overlap_start_time, self._frames_per_second
                )
                overlap_stop_frame = _seconds_to_frame_index(
                    overlap_stop_time, self._frames_per_second
                )

                # Append 1 label per frame
                for _ in range(overlap_start_frame, overlap_stop_frame):
                    labels_in_clip.append(label_data)

        # Convert the list of LabelData objects to a tensor of just the label IDs
        label_ids = [label.label_id for label in labels_in_clip]
        label_ids_tensor = torch.tensor(label_ids)

        clip_data = {
            "video_id": clip.video_id,
            **self._videos[clip.video_id].get_clip(clip.start_time, clip.stop_time),
            "labels": label_ids_tensor,
            "start_time": clip.start_time,
            "stop_time": clip.stop_time,
        }

        if self._transform:
            clip_data = self._transform(clip_data)

        return clip_data

    def __len__(self) -> int:
        """
        Returns:
            The number of video clips in the dataset.
        """
        return len(self._clips)

    def _transform_clip(self, clip: Dict[str, Any]) -> Dict[str, Any]:
        """
        Transforms a given video clip, according to some pre-defined transforms
        and an optional user transform function (self._user_transform).

        Args:
            clip (Dict[str, Any]): The clip that will be transformed.

        Returns:
            (Dict[str, Any]) The transformed clip.
        """
        for key in clip:
            if clip[key] is None:
                clip[key] = torch.tensor([])

        if self._user_transform:
            clip = self._user_transform(clip)

        return clip

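
# Minimal usage sketch (not part of the original module). The clip sampler below
# is a hypothetical example that takes a single 2-second clip from the start of
# every video, assuming VideoClipInfo(video_id, start_time, stop_time); the
# manifest paths are also placeholders.
#
#   def take_first_clip(videos, video_labels):
#       return [VideoClipInfo(video_id, 0.0, 2.0) for video_id in videos]
#
#   video_dataset = DomsevVideoDataset(
#       video_data_manifest_file_path="video_manifest.csv",  # hypothetical path
#       video_info_file_path="video_info.csv",               # hypothetical path
#       labels_file_path="activity_labels.csv",              # hypothetical path
#       clip_sampler=take_first_clip,
#       frames_per_second=5,
#   )
#   clip = video_dataset[0]
#   # clip is a dict with 'video_id', 'video', 'labels', 'start_time', 'stop_time'
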
def _load_image_from_path(image_path: str, num_retries: int = 10) -> Image:
    """
    Loads the given image path using PathManager and decodes it as an RGB image.

    Args:
        image_path (str): the path to the image.
        num_retries (int): number of times to retry image reading to handle
            transient errors.

    Returns:
        A PIL Image containing the decoded RGB image data. Pixel values are of
        type np.uint8 and in the range [0, 255].

        Raises an exception if unable to load the image.
    """
    if not _HAS_CV2:
        raise ImportError(
            "opencv2 is required to use FrameVideo. Please "
            "install with 'pip install opencv-python'"
        )

    img_arr = None
    for i in range(num_retries):
        with g_pathmgr.open(image_path, "rb") as f:
            img_str = np.frombuffer(f.read(), np.uint8)
            img_bgr = cv2.imdecode(img_str, flags=cv2.IMREAD_COLOR)
            # cv2.imdecode returns None when decoding fails; guard the conversion
            # so a failed read falls through to the retry branch below.
            img_rgb = (
                cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
                if img_bgr is not None
                else None
            )
        if img_rgb is not None:
            img_arr = img_rgb
            break
        else:
            logging.warning(f"Reading attempt {i}/{num_retries} failed.")
            time.sleep(1e-6)

    if img_arr is None:
        raise Exception("Failed to load image from {}".format(image_path))

    pil_image = Image.fromarray(img_arr)
    return pil_image