diff --git a/data/dataset.py b/data/dataset.py
index 90d4992..5f5a623 100644
--- a/data/dataset.py
+++ b/data/dataset.py
@@ -25,6 +25,7 @@ import tensorflow as tf
 import tensorflow_datasets as tfds
 
+
 DatasetRegistry = registry.Registry()
 
 
@@ -199,6 +200,11 @@ def num_eval_examples(self):
 class TFRecordDataset(Dataset):
   """A dataset created from tfrecord files."""
 
+  def __init__(self, config: ml_collections.ConfigDict):
+    """Constructs the dataset."""
+    super().__init__(config)
+    self.dataset_cls = tf.data.TFRecordDataset
+
   def load_dataset(self, input_context, training):
     """Load tf.data.Dataset from TFRecord files."""
     if training or self.config.eval_split == 'train':
@@ -207,7 +213,8 @@ def load_dataset(self, input_context, training):
       file_pattern = self.config.val_file_pattern
     dataset = tf.data.Dataset.list_files(file_pattern, shuffle=training)
     dataset = dataset.interleave(
-        tf.data.TFRecordDataset, cycle_length=32, deterministic=not training)
+        self.dataset_cls, cycle_length=32, deterministic=not training,
+        num_parallel_calls=tf.data.experimental.AUTOTUNE)
     return dataset
 
   @abc.abstractmethod
@@ -254,3 +261,5 @@ def num_train_examples(self):
   def num_eval_examples(self):
     return self.config.eval_num_examples if not self.task_config.get(
         'unbatch', False) else None
+
+
diff --git a/metrics/segmentation_and_tracking_quality.py b/metrics/segmentation_and_tracking_quality.py
new file mode 100644
index 0000000..56a8b05
--- /dev/null
+++ b/metrics/segmentation_and_tracking_quality.py
@@ -0,0 +1,359 @@
+# coding=utf-8
+# Copyright 2022 The Pix2Seq Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Implementation of the Segmentation and Tracking Quality (STQ) metric.
+
+This is a copy of
+https://github.com/google-research/deeplab2/blob/main/evaluation/depth_aware_segmentation_and_tracking_quality.py
+"""
+
+import collections
+from typing import Any, Dict, MutableMapping, Optional, Sequence, Text, Union
+
+import warnings
+
+import numpy as np
+import tensorflow as tf
+
+
+def _check_weights(unique_weight_list: Sequence[float]):
+  if not set(unique_weight_list).issubset({0.5, 1.0}):
+    warnings.warn(
+        'Potential performance degradation as the code is not optimized'
+        ' when weights has too many different elements.'
+    )
+
+
+def _update_dict_stats(
+    stat_dict: MutableMapping[int, tf.Tensor],
+    id_array: tf.Tensor,
+    weights: Optional[tf.Tensor] = None,
+):
+  """Updates a given dict with corresponding counts."""
+  if weights is None:
+    unique_weight_list = [1.0]
+  else:
+    unique_weight_list, _ = tf.unique(weights)
+    unique_weight_list = unique_weight_list.numpy().tolist()
+  _check_weights(unique_weight_list)
+  # Iterate through the unique weight values and accumulate the weighted
+  # counts. Example usage: lower the weights in the region covered by multiple
+  # cameras in panoramic video panoptic segmentation (PVPS).
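+  # Each pass below counts the ids that fall in the region carrying one
+  # particular weight (the whole map when `weights` is None) with
+  # tf.unique_with_counts, and adds count * weight to the running totals in
+  # `stat_dict`. The .numpy() calls mean this helper only works in eager mode.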
+  for weight in unique_weight_list:
+    if weights is None:
+      ids, _, counts = tf.unique_with_counts(id_array)
+    else:
+      ids, _, counts = tf.unique_with_counts(
+          tf.boolean_mask(id_array, tf.equal(weight, weights)))
+    for idx, count in zip(ids.numpy(), tf.cast(counts, tf.float32)):
+      if idx in stat_dict:
+        stat_dict[idx] += count * weight
+      else:
+        stat_dict[idx] = count * weight
+
+
+class STQuality(object):
+  """Metric class for the Segmentation and Tracking Quality (STQ).
+
+  The metric computes the geometric mean of two terms.
+  - Association Quality: This term measures the quality of the track ID
+    assignment for `thing` classes. It is formulated as a weighted IoU
+    measure.
+  - Segmentation Quality: This term measures the semantic segmentation quality.
+    The standard class IoU measure is used for this.
+
+  Example usage:
+
+  stq_obj = segmentation_and_tracking_quality.STQuality(num_classes, things_list,
+    ignore_label, max_instances_per_category, offset)
+  stq_obj.update_state(y_true_1, y_pred_1)
+  stq_obj.update_state(y_true_2, y_pred_2)
+  ...
+  result = stq_obj.result().numpy()
+  """
+
+  def __init__(self,
+               num_classes: int,
+               things_list: Sequence[int],
+               ignore_label: int,
+               max_instances_per_category: int,
+               offset: int,
+               name='stq'
+               ):
+    """Initialization of the STQ metric.
+
+    Args:
+      num_classes: Number of classes in the dataset as an integer.
+      things_list: A sequence of class ids that belong to `things`.
+      ignore_label: The class id to be ignored in evaluation as an integer or
+        integer tensor.
+      max_instances_per_category: The maximum number of instances for each
+        class as an integer or integer tensor.
+      offset: The maximum number of unique labels as an integer or integer
+        tensor.
+      name: An optional name. (default: 'stq')
+    """
+    self._name = name
+    self._num_classes = num_classes
+    self._ignore_label = ignore_label
+    self._things_list = things_list
+    self._max_instances_per_category = max_instances_per_category
+
+    if ignore_label >= num_classes:
+      self._confusion_matrix_size = num_classes + 1
+      self._include_indices = np.arange(self._num_classes)
+    else:
+      self._confusion_matrix_size = num_classes
+      self._include_indices = np.array(
+          [i for i in range(num_classes) if i != self._ignore_label])
+
+    self._iou_confusion_matrix_per_sequence = collections.OrderedDict()
+    self._predictions = collections.OrderedDict()
+    self._ground_truth = collections.OrderedDict()
+    self._intersections = collections.OrderedDict()
+    self._sequence_length = collections.OrderedDict()
+    self._offset = offset
+    lower_bound = num_classes * max_instances_per_category
+    if offset < lower_bound:
+      raise ValueError('The provided offset %d is too small. No guarantees '
+                       'about the correctness of the results can be made. '
+                       'Please choose an offset that is higher than num_classes'
+                       ' * max_instances_per_category = %d' %
+                       (offset, lower_bound))
+
+  def update_state(self,
+                   y_true: tf.Tensor,
+                   y_pred: tf.Tensor,
+                   sequence_id: Union[int, str] = 0,
+                   weights: Optional[tf.Tensor] = None):
+    """Accumulates the segmentation and tracking quality statistics.
+
+    Args:
+      y_true: The ground-truth panoptic label map for a particular video frame
+        (defined as semantic_map * max_instances_per_category + instance_map).
+      y_pred: The predicted panoptic label map for a particular video frame
+        (defined as semantic_map * max_instances_per_category + instance_map).
+      sequence_id: The optional ID of the sequence the frames belong to.
When no + sequence is given, all frames are considered to belong to the same + sequence (default: 0). + weights: The weights for each pixel with the same shape of `y_true`. + """ + y_true = tf.cast(y_true, dtype=tf.int64) + y_pred = tf.cast(y_pred, dtype=tf.int64) + if weights is not None: + weights = tf.reshape(weights, y_true.shape) + semantic_label = y_true // self._max_instances_per_category + semantic_prediction = y_pred // self._max_instances_per_category + # Check if the ignore value is outside the range [0, num_classes]. If yes, + # map `_ignore_label` to `_num_classes`, so it can be used to create the + # confusion matrix. + if self._ignore_label > self._num_classes: + semantic_label = tf.where( + tf.not_equal(semantic_label, self._ignore_label), semantic_label, + self._num_classes) + semantic_prediction = tf.where( + tf.not_equal(semantic_prediction, self._ignore_label), + semantic_prediction, self._num_classes) + if sequence_id in self._iou_confusion_matrix_per_sequence: + self._iou_confusion_matrix_per_sequence[sequence_id] += ( + tf.math.confusion_matrix( + tf.reshape(semantic_label, [-1]), + tf.reshape(semantic_prediction, [-1]), + self._confusion_matrix_size, + dtype=tf.float64, + weights=tf.reshape(weights, [-1]) + if weights is not None else None)) + self._sequence_length[sequence_id] += 1 + else: + self._iou_confusion_matrix_per_sequence[sequence_id] = ( + tf.math.confusion_matrix( + tf.reshape(semantic_label, [-1]), + tf.reshape(semantic_prediction, [-1]), + self._confusion_matrix_size, + dtype=tf.float64, + weights=tf.reshape(weights, [-1]) + if weights is not None else None)) + self._predictions[sequence_id] = {} + self._ground_truth[sequence_id] = {} + self._intersections[sequence_id] = {} + self._sequence_length[sequence_id] = 1 + + instance_label = y_true % self._max_instances_per_category + + label_mask = tf.zeros_like(semantic_label, dtype=tf.bool) + prediction_mask = tf.zeros_like(semantic_prediction, dtype=tf.bool) + for things_class_id in self._things_list: + label_mask = tf.logical_or(label_mask, + tf.equal(semantic_label, things_class_id)) + prediction_mask = tf.logical_or( + prediction_mask, tf.equal(semantic_prediction, things_class_id)) + + # Select the `crowd` region of the current class. This region is encoded + # instance id `0`. + is_crowd = tf.logical_and(tf.equal(instance_label, 0), label_mask) + # Select the non-crowd region of the corresponding class as the `crowd` + # region is ignored for the tracking term. + label_mask = tf.logical_and(label_mask, tf.logical_not(is_crowd)) + # Do not punish id assignment for regions that are annotated as `crowd` in + # the ground-truth. + prediction_mask = tf.logical_and(prediction_mask, tf.logical_not(is_crowd)) + + seq_preds = self._predictions[sequence_id] + seq_gts = self._ground_truth[sequence_id] + seq_intersects = self._intersections[sequence_id] + + # Compute and update areas of ground-truth, predictions and intersections. 
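+    # Tube areas are keyed by the full panoptic id, and a ground-truth /
+    # prediction pair is keyed by `gt_id * offset + pred_id`, which is why
+    # `offset` must be at least num_classes * max_instances_per_category.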
+ _update_dict_stats( + seq_preds, y_pred[prediction_mask], + weights[prediction_mask] if weights is not None else None) + _update_dict_stats(seq_gts, y_true[label_mask], + weights[label_mask] if weights is not None else None) + + non_crowd_intersection = tf.logical_and(label_mask, prediction_mask) + intersection_ids = ( + y_true[non_crowd_intersection] * self._offset + + y_pred[non_crowd_intersection]) + _update_dict_stats( + seq_intersects, intersection_ids, + weights[non_crowd_intersection] if weights is not None else None) + + def merge_state(self, metrics: Sequence['STQuality']): + """Merges the results of multiple STQuality metrics. + + This can be used to distribute metric computation for multiple sequences on + multiple instances, by computing metrics on each sequence separately, and + then merging the metrics with this function. + + Note that only metrics with unique sequences are supported. Passing in + metrics with common instances is not supported. + + Args: + metrics: A sequence of STQuality objects with unique sequences. + + Raises: + ValueError: If a sequence is re-used between different metrics, or is + already in this metric. + """ + # pylint: disable=protected-access + for metric in metrics: + for sequence in metric._ground_truth.keys(): + if sequence in self._ground_truth: + raise ValueError('Tried to merge metrics with duplicate sequences.') + self._ground_truth[sequence] = metric._ground_truth[sequence] + self._predictions[sequence] = metric._predictions[sequence] + self._intersections[sequence] = metric._intersections[sequence] + self._iou_confusion_matrix_per_sequence[sequence] = ( + metric._iou_confusion_matrix_per_sequence[sequence]) + self._sequence_length[sequence] = metric._sequence_length[sequence] + # pylint: enable=protected-access + + def result(self) -> Dict[Text, Any]: + """Computes the segmentation and tracking quality. + + Returns: + A dictionary containing: + - 'STQ': The total STQ score. + - 'AQ': The total association quality (AQ) score. + - 'IoU': The total mean IoU. + - 'STQ_per_seq': A list of the STQ score per sequence. + - 'AQ_per_seq': A list of the AQ score per sequence. + - 'IoU_per_seq': A list of mean IoU per sequence. + - 'Id_per_seq': A list of sequence Ids to map list index to sequence. + - 'Length_per_seq': A list of the length of each sequence. + """ + # Compute association quality (AQ) + num_tubes_per_seq = [0] * len(self._ground_truth) + aq_per_seq = [0] * len(self._ground_truth) + iou_per_seq = [0] * len(self._ground_truth) + id_per_seq = [''] * len(self._ground_truth) + + for index, sequence_id in enumerate(self._ground_truth): + outer_sum = 0.0 + predictions = self._predictions[sequence_id] + ground_truth = self._ground_truth[sequence_id] + intersections = self._intersections[sequence_id] + num_tubes_per_seq[index] = len(ground_truth) + id_per_seq[index] = sequence_id + + for gt_id, gt_size in ground_truth.items(): + inner_sum = 0.0 + for pr_id, pr_size in predictions.items(): + tpa_key = self._offset * gt_id + pr_id + if tpa_key in intersections: + tpa = intersections[tpa_key].numpy() + fpa = pr_size.numpy() - tpa + fna = gt_size.numpy() - tpa + inner_sum += tpa * (tpa / (tpa + fpa + fna)) + + outer_sum += 1.0 / gt_size.numpy() * inner_sum + aq_per_seq[index] = outer_sum + + aq_mean = np.sum(aq_per_seq) / np.maximum(np.sum(num_tubes_per_seq), 1e-15) + aq_per_seq = aq_per_seq / np.maximum(num_tubes_per_seq, 1e-15) + + # Compute IoU scores. + # The rows correspond to ground-truth and the columns to predictions. 
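+    # For a class c, IoU_c = cm[c, c] / (row_sum_c + col_sum_c - cm[c, c]),
+    # i.e. true positives over the union of ground-truth and predicted pixels;
+    # the mean IoU is taken over classes with a non-zero union.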
+ # Remove fp from confusion matrix for the void/ignore class. + total_confusion = np.zeros( + (self._confusion_matrix_size, self._confusion_matrix_size), + dtype=np.float64) + for index, confusion in enumerate( + self._iou_confusion_matrix_per_sequence.values()): + confusion = confusion.numpy() + removal_matrix = np.zeros_like(confusion) + removal_matrix[self._include_indices, :] = 1.0 + confusion *= removal_matrix + total_confusion += confusion + + # `intersections` corresponds to true positives. + intersections = confusion.diagonal() + fps = confusion.sum(axis=0) - intersections + fns = confusion.sum(axis=1) - intersections + unions = intersections + fps + fns + + num_classes = np.count_nonzero(unions) + ious = (intersections.astype(np.double) / + np.maximum(unions, 1e-15).astype(np.double)) + iou_per_seq[index] = np.sum(ious) / num_classes + + # `intersections` corresponds to true positives. + intersections = total_confusion.diagonal() + fps = total_confusion.sum(axis=0) - intersections + fns = total_confusion.sum(axis=1) - intersections + unions = intersections + fps + fns + + num_classes = np.count_nonzero(unions) + ious = (intersections.astype(np.double) / + np.maximum(unions, 1e-15).astype(np.double)) + iou_mean = np.sum(ious) / num_classes + + st_quality = np.sqrt(aq_mean * iou_mean) + st_quality_per_seq = np.sqrt(aq_per_seq * iou_per_seq) + return {'STQ': st_quality, + 'AQ': aq_mean, + 'IoU': float(iou_mean), + 'STQ_per_seq': st_quality_per_seq, + 'AQ_per_seq': aq_per_seq, + 'IoU_per_seq': iou_per_seq, + 'ID_per_seq': id_per_seq, + 'Length_per_seq': list(self._sequence_length.values()), + } + + def reset_states(self): + """Resets all states that accumulated data.""" + self._iou_confusion_matrix_per_sequence = collections.OrderedDict() + self._predictions = collections.OrderedDict() + self._ground_truth = collections.OrderedDict() + self._intersections = collections.OrderedDict() + self._sequence_length = collections.OrderedDict() diff --git a/metrics/vps_metrics.py b/metrics/vps_metrics.py new file mode 100644 index 0000000..5ab75d7 --- /dev/null +++ b/metrics/vps_metrics.py @@ -0,0 +1,369 @@ +# coding=utf-8 +# Copyright 2022 The Pix2Seq Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Video panoptic segmentation metrics.""" + +import collections +import io +import os +import tempfile +from typing import Sequence + +from absl import logging +import numpy as np +import pandas as pd +import PIL +import utils +from metrics import metric_registry +from metrics import segmentation_and_tracking_quality as stq +import seaborn as sns +from skimage import segmentation +import tensorflow as tf + +_SEMANTIC_PALETTE = [ + 0, 0, 0, + 128, 0, 0, + 0, 128, 0, + 128, 128, 0, + 0, 0, 128, + 128, 0, 128, + 0, 128, 128, + 128, 128, 128, + 64, 0, 0, + 191, 0, 0, + 64, 128, 0, + 191, 128, 0, + 64, 0, 128, + 191, 0, 128, + 64, 128, 128, + 191, 128, 128, + 31, 119, 180, + 255, 127, 14, + 44, 160, 44, + 214, 39, 40, + 148, 103, 189, + 140, 86, 75, + 227, 119, 194, + 127, 127, 127] + +# For instance map palette, get 32 colors from the seaborn palette, repeat +# 8 times to get a 256-color palette. Assign (0, 0, 0) as the first color. +_INSTANCE_PALETTE = [0, 0, 0] + (list( + np.asarray( + np.reshape(np.asarray(sns.color_palette('Spectral', 32)) * 255, [-1]), + np.uint8)) * 8)[:-3] + +_PANOPTIC_METRIC_OFFSET = 256 * 256 * 256 + + +def _write_to_png_file(img, filepath): + with io.BytesIO() as out: + img.save(out, format='PNG') + with tf.io.gfile.GFile(filepath, 'wb') as f: + f.write(out.getvalue()) + + +def _semantic_instance_maps_from_rgb(filename, rgb_instance_label_divisor=256): + with tf.io.gfile.GFile(filename, 'rb') as f: + panoptic_map = np.array(PIL.Image.open(f)).astype(np.int32) + semantic_map = panoptic_map[:, :, 0] + instance_map = ( + panoptic_map[:, :, 1] * rgb_instance_label_divisor + + panoptic_map[:, :, 2]) + return semantic_map, instance_map + + +def _panoptic_map_from_semantic_instance_maps(semantic_map, instance_map, + panoptic_label_divisor): + return semantic_map * panoptic_label_divisor + instance_map + + +def _panoptic_map_from_rgb(filename, panoptic_label_divisor, + rgb_instance_label_divisor=256): + """Loads a rgb format panoptic map from file and encode to single channel.""" + semantic_map, instance_map = _semantic_instance_maps_from_rgb( + filename, rgb_instance_label_divisor) + panoptic_map = _panoptic_map_from_semantic_instance_maps( + semantic_map, instance_map, panoptic_label_divisor) + return panoptic_map + + +def _panoptic_map_to_rgb(semantic_map, instance_map, filename, + rgb_instance_label_divisor=256): + """Converts a panoptic map to rgb format and write to file.""" + instance_map_1 = instance_map // rgb_instance_label_divisor + instance_map_2 = instance_map % rgb_instance_label_divisor + panoptic_map = np.stack( + [semantic_map, instance_map_1, instance_map_2], -1).astype(np.uint8) + panoptic_map = PIL.Image.fromarray(panoptic_map) + _write_to_png_file(panoptic_map, filename) + + +def _visualize_id_map(id_map, palette): + boundaries = segmentation.find_boundaries(id_map, mode='thin') + id_map[boundaries] = 0 + vis = PIL.Image.fromarray(id_map, mode='L') + vis.putpalette(palette) + return vis + + +def _get_new_id(exclusion_list, max_id): + for i in range(max_id): + if i not in exclusion_list: + return i + return 0 + + +def _in_list_of_lists(x, list_of_lists): + for l in list_of_lists: + if x in l: + return True + return False + + +class STQEvaluation(object): + """Evaluation class for the Segmentation and Tracking Quality (STQ).""" + + def __init__(self, + annotation_dir: str, + num_classes: int, + class_has_instances_list: Sequence[int], + ignore_label: int, + 
panoptic_label_divisor: int, + max_instances: int, + num_cond_frames: int): + self.stq = stq.STQuality( + num_classes=num_classes, + things_list=class_has_instances_list, + ignore_label=ignore_label, + max_instances_per_category=panoptic_label_divisor, + offset=_PANOPTIC_METRIC_OFFSET) + self.panoptic_label_divisor = panoptic_label_divisor + self.annotation_dir = annotation_dir + video_names = tf.io.gfile.listdir(annotation_dir) + self.video_name_to_id_map = { + video_names[i]: i for i in range(len(video_names)) + } + self.max_instances = max_instances + self.num_cond_frames = num_cond_frames + + def evaluate(self, result_dir: str, postprocess_ins: bool = True): + """Evaluates STQ for a result directory. + + Args: + result_dir: str, directory that contains predictions. + postprocess_ins: bool, whether to postprocess instance ids so that when + new instances appear, they get assigned ids that have not been used in + previous frames. + + Returns: + a dict of metric name to value. + """ + # Loop through all videos in result dir. + for video_name in tf.io.gfile.listdir(result_dir): + video_id = self.video_name_to_id_map[video_name] + + if postprocess_ins: + # Keep a map of original instance ids to processed instance ids. + id_map = np.zeros([self.max_instances], np.int32) + # 'used_ids' is the list of all used ids. 'recent_ids' are ids that + # appeared in the conditional frames. Only when an id appears in + # previous frames other than the conditional frames, we need to map it + # to a new id. Therefore we need to keep a running list of ids that + # appear in the conditional frames. + used_ids = [0] + recent_ids = collections.deque([[0]] * self.num_cond_frames, + self.num_cond_frames) + + for img in tf.io.gfile.listdir(os.path.join(result_dir, video_name)): + pred_file = os.path.join(result_dir, video_name, img) + gt_file = os.path.join(self.annotation_dir, video_name, img) + gt = _panoptic_map_from_rgb(gt_file, self.panoptic_label_divisor) + + if postprocess_ins: + semantic_map, instance_map = _semantic_instance_maps_from_rgb( + pred_file, self.panoptic_label_divisor) + + # Update the id map. + ids = list(np.unique(instance_map)) + for i in ids: + if id_map[i] == 0: + # the id hasn't appeared before. + id_map[i] = i + used_ids.append(i) + elif _in_list_of_lists(i, recent_ids): + # the id appeared in the conditional frames. + pass + else: + # the id is not in the conditional frames, but has been used. + new_id = _get_new_id(used_ids + ids, self.max_instances) + id_map[i] = new_id + used_ids.append(new_id) + + recent_ids.append(ids) + # Update the instance map. 
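+          # `id_map` is a dense lookup table indexed by the original instance
+          # id, so one numpy fancy-indexing step remaps every pixel at once.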
+ instance_map = id_map[instance_map] + + pred = _panoptic_map_from_semantic_instance_maps( + semantic_map, instance_map, self.panoptic_label_divisor) + else: + pred = _panoptic_map_from_rgb(pred_file, self.panoptic_label_divisor) + self.stq.update_state(gt, pred, video_id) + + return self.stq.result() + + def reset_states(self): + self.stq.reset_states() + + +@metric_registry.MetricRegistry.register('segmentation_and_tracking_quality') +class STQMetric(object): + """Metric class for the Segmentation and Tracking Quality (STQ).""" + + def __init__(self, config): + self.config = config + self.results_dir = config.task.metric.get('results_dir') + self.metric_names = ['AQ', 'IoU', 'STQ'] + self.per_sequence_metric_names = [ + 'ID_per_seq', 'Length_per_seq', 'AQ_per_seq', 'IoU_per_seq', + 'STQ_per_seq'] + self.eval = STQEvaluation( + annotation_dir=config.dataset.annotations_dir, + num_classes=config.dataset.num_classes - 1, # exclude void pixel label + class_has_instances_list=config.dataset.class_has_instances_list, + ignore_label=config.dataset.ignore_label, + panoptic_label_divisor=config.dataset.panoptic_label_divisor, + max_instances=config.task.max_instances_per_image, + num_cond_frames=len(config.task.proceeding_frames.split(','))) + self.reset_states() + + def reset_states(self): + self.metric_values = None + self.eval.reset_states() + # For saving predictions for metric evaluation. + self._local_pred_dir_obj = tempfile.TemporaryDirectory() + # For saving visualization images. + self._panoptic_local_vis_dir_obj = tempfile.TemporaryDirectory() + + def _write_predictions(self, frame_id, semantic_map, instance_map, outdir): + # When saving output for evaluation, change semantic ids back to the + # original class ids, and 0 back to 255 which is to be ignored. + semantic_map = np.where(semantic_map == 0, + self.config.dataset.ignore_label, semantic_map - 1) + output_file = os.path.join(outdir, f'{frame_id:06}.png') + _panoptic_map_to_rgb(semantic_map, instance_map, output_file) + + def _write_visualizations(self, frame_id, semantic_map, instance_map, outdir): + """Writes visualization images to a directory. + + Args: + frame_id: int, the frame id. + semantic_map: uint8 of shape (h, w). + instance_map: uint8 of shape (h, w). + outdir: directory to write visualization images to. + """ + sem = _visualize_id_map(semantic_map, _SEMANTIC_PALETTE) + ins = _visualize_id_map(instance_map, _INSTANCE_PALETTE) + + _write_to_png_file(sem, os.path.join(outdir, f'{frame_id:06}_s.png')) + _write_to_png_file(ins, os.path.join(outdir, f'{frame_id:06}_i.png')) + + def record_prediction(self, predictions, video_name, frame_ids, step): + """Records predictions. + + Args: + predictions: uint8 of shape (num_frames, h, w, 2), containing semantic + map and instance map. + video_name: str. Video name. + frame_ids: list of int, or 1-d np.array. Frame ids of predictions. + step: int. The checkpoint step, used to name sub-directories. + """ + pred_dir = os.path.join(self._local_pred_dir_obj.name, str(step), + video_name) + if not tf.io.gfile.exists(pred_dir): + tf.io.gfile.makedirs(pred_dir) + vis_dir = os.path.join(self._panoptic_local_vis_dir_obj.name, str(step), + video_name) + if not tf.io.gfile.exists(vis_dir): + tf.io.gfile.makedirs(vis_dir) + + # Write predictions to temporary directory. + for fid, id_maps in zip(frame_ids, predictions): + sem_map, ins_map = id_maps[..., 0], id_maps[..., 1] + + # Encode semantic map and instance_map into one image and write to file. 
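+      # `_write_predictions` restores the original semantic class ids and
+      # stores the panoptic map as RGB: semantic id in channel 0, instance id
+      # split over channels 1 and 2 (instance // 256, instance % 256).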
+ self._write_predictions(fid, sem_map, ins_map, pred_dir) + + # Build visualization images and save to vis_dir. + if self.results_dir is not None: + self. _write_visualizations(fid, sem_map, ins_map, vis_dir) + + if self.results_dir is not None: + # Copy visualization images to results dir. + results_dir = os.path.join(self.results_dir, str(step), video_name) + utils.copy_dir(vis_dir, results_dir) + + # TODO(lala) - delete this. + results_dir = os.path.join(self.results_dir, 'pred', video_name) + utils.copy_dir(pred_dir, results_dir) + + logging.info('Done writing out pngs for %s', video_name) + + def _evaluate(self, step): + """Evaluates with predictions for all images. + + Call this function from `self.result`. + + Args: + step: int. The checkpoint step being evaluated. + + Returns: + dict from metric name to float value. + """ + result_path = os.path.join(self._local_pred_dir_obj.name, str(step)) + stq_metric = self.eval.evaluate(result_path) + + if self.results_dir is not None: + # Write metrics to result_dir. + result_dir = os.path.join(self.results_dir, str(step)) + csv_name_global_path = os.path.join(result_dir, 'global_results.csv') + csv_name_per_sequence_path = os.path.join(result_dir, + 'per_sequence_results.csv') + + # Global results. + g_res = np.asarray([stq_metric[n] for n in self.metric_names]) + g_res_ = np.reshape(g_res, [1, len(g_res)]) + table_g = pd.DataFrame(data=g_res_, columns=self.metric_names) + with tf.io.gfile.GFile(csv_name_global_path, 'w') as f: + table_g.to_csv(f, index=False, float_format='%.3f') + logging.info('Global results saved in %s', csv_name_global_path) + + # Per sequence results. + table_seq = pd.DataFrame( + data=list(zip( + *[list(stq_metric[n]) for n in self.per_sequence_metric_names])), + columns=self.per_sequence_metric_names) + with tf.io.gfile.GFile(csv_name_per_sequence_path, 'w') as f: + table_seq.to_csv(f, index=False, float_format='%.3f') + logging.info('Per-sequence results saved in %s', + csv_name_per_sequence_path) + + return {n: stq_metric[n] for n in self.metric_names} + + def result(self, step): + """Return the metric values (and compute it if needed).""" + if self.metric_values is None: + self.metric_values = self._evaluate(step) + return self.metric_values
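Usage sketch (not part of the diff): the new STQuality class can be exercised on
its own in eager mode. This assumes the module path
metrics/segmentation_and_tracking_quality.py introduced above; the class ids, map
shape and max_instances_per_category value below are illustrative only.

    import tensorflow as tf
    from metrics import segmentation_and_tracking_quality as stq

    max_ins = 1000  # offset must be >= num_classes * max_ins
    metric = stq.STQuality(
        num_classes=3,
        things_list=[1, 2],
        ignore_label=255,
        max_instances_per_category=max_ins,
        offset=256 * 256 * 256)

    # Panoptic ids are encoded as semantic_id * max_ins + instance_id.
    y_true = tf.fill([4, 4], 1 * max_ins + 1)  # one ground-truth tube of class 1
    y_pred = tf.fill([4, 4], 1 * max_ins + 2)  # same class, different track id
    metric.update_state(y_true, y_pred, sequence_id=0)
    print(metric.result())  # keys: 'STQ', 'AQ', 'IoU', plus per-sequence lists

STQMetric in vps_metrics.py wraps this class via STQEvaluation, reading ground
truth and predictions back from the per-video PNG directories instead of taking
tensors directly.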