Source code for seqio.metrics

# Copyright 2026 The SeqIO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""MetricValue objects to wrap results being returned by metric funcitons."""

import dataclasses
import enum
import inspect
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

import clu.metrics
import flax
import jax
import jax.numpy as jnp
import numpy as np
from seqio import utils
import tensorflow.compat.v2 as tf


[docs] @dataclasses.dataclass class MetricValue: """A base method for the dataclasses that represent tensorboard values. Task `metric_fn`s should output `Mapping[str, MetricValue]` which will be written by a `Logger`. """
[docs] @dataclasses.dataclass class Scalar(MetricValue): """The default tensorflow value, used for creating time series graphs.""" value: Union[int, float]
[docs] @dataclasses.dataclass class Text(MetricValue): """Text to output to tensorboard, markdown is rendered by tensorboard.""" textdata: Union[str, bytes]
[docs] @dataclasses.dataclass class Image(MetricValue): """An image to output to tensorboard. The format for the image array should match the format expected for the data parameter described [here](https://www.tensorflow.org/api_docs/python/tf/summary/image). """ image: np.ndarray max_outputs: int = 3
[docs] @dataclasses.dataclass class Audio(MetricValue): """An audio example to output to tensorboard. The format for the audio array should match the format expected for the data parameter described [here](https://www.tensorflow.org/api_docs/python/tf/summary/audio). """ audiodata: np.ndarray sample_rate: int = 44100 max_outputs: int = 3
[docs] @dataclasses.dataclass class Histogram(MetricValue): """A histogram to output to tensorboard.""" values: np.ndarray bins: Optional[int] = None
[docs] @dataclasses.dataclass class Generic(MetricValue): """A raw tensor to output to tensorboard.""" tensor: np.ndarray metadata: tf.compat.v1.SummaryMetadata
[docs] class ModelOutputType(enum.IntEnum): """Model output types.""" PREDICTION = 1 SCORE = 2 PREDICTION_WITH_AUX = 3 SCORE_WITH_INTERMEDIATES = 4 @classmethod def to_str(cls, enm): return { cls.PREDICTION: "prediction", cls.SCORE: "score", cls.PREDICTION_WITH_AUX: "prediction_with_aux", cls.SCORE_WITH_INTERMEDIATES: "score_with_intermediates", }[enm]
MetricFnCallable = Callable[..., Mapping[str, Union[MetricValue, float]]]
[docs] @flax.struct.dataclass class Metric(clu.metrics.Metric): """Base Metric class for seqio evaluation.""" model_output_type: ModelOutputType
[docs] @classmethod def from_model_output( cls, inputs: Sequence[Mapping[str, Any]], model_output: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], features: Mapping[str, utils.Feature], target_field_name: str = "targets", mask: Optional[np.ndarray] = None, indices_2d: Optional[np.ndarray] = None, ) -> "Metric": """Creates a `seqio.Metric` from model outputs. Args: inputs: Examples in dataset. model_output: Model output computed by model functions. features: Output features defined in seqio.Task. target_field_name: Field name of the target sequence. mask: A boolean array to indicate which examples in the inputs are included for metric evaluation. indices_2d: 2d-indices of examples in the inputs/model_output. First dimension is shard id, the second is the example id within that shard. Returns: An instance of Metric. Raises: NotImplementedError: Must override from_model_output() """ raise NotImplementedError("Must override from_model_output()")
[docs] class CollectingMetric(clu.metrics.CollectingMetric): """CollectingMetric interface for seqio evaluation."""
[docs] @classmethod def from_model_output( # pylint:disable=missing-function-docstring cls, inputs: Sequence[Mapping[str, Any]], model_output: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], features: Mapping[str, utils.Feature], target_field_name: str = "targets", mask: Optional[np.ndarray] = None, indices_2d: Optional[np.ndarray] = None, ): del inputs, features, target_field_name num_examples = ( len(model_output[0]) if isinstance(model_output, tuple) else len(model_output) ) if mask is None: mask = jnp.ones((num_examples,), jnp.int32) if indices_2d is None: indices_2d = jnp.transpose( jnp.stack([ jnp.zeros((num_examples,), jnp.int32), jnp.arange(num_examples, dtype=jnp.int32), ]) ) return cls( values={ "model_output": model_output, "indices_2d": indices_2d, "mask": mask, } )
[docs] def actual_compute( self, task_dataset_as_numpy, task_output_features, target_field_name: str = "targets", cached_targets: Optional[List[str]] = None, ): """Implements the metric computation logics for CollectingMetric. Args: task_dataset_as_numpy: Examples in dataset. task_output_features: Output features defined in the seqio.Task. target_field_name: Field name of the target sequence. cached_targets: targets that have been cached by Evaluator and can be supplied here to save time of post-processing targets. Returns: A tuple of two items, first item is a dict of metric results, the second item is targets_and_inferences. Raises: NotImplementedError: Must override from_model_output() """ raise NotImplementedError("Must override from_model_output()")
# TODO(kehanghan): consider using CollectingMetric for LegacyMetric.
[docs] @flax.struct.dataclass class LegacyMetric(Metric): """Metric class for legacy use-case where metric fn is supplied.""" _metric_fn: MetricFnCallable _postprocess_fn: Callable[..., Any] metric_fn_kwargs: Dict[str, Any] targets_and_inferences: Dict[str, Any]
[docs] @classmethod def empty(cls, metric_fn, postprocess_fn) -> "LegacyMetric": pos_args = tuple( key for key, param in inspect.signature(metric_fn).parameters.items() if param.default == inspect.Parameter.empty and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ) if pos_args == ("targets", "scores"): model_output_type = ModelOutputType.SCORE elif pos_args == ("targets", "predictions"): model_output_type = ModelOutputType.PREDICTION elif pos_args == ("targets", "predictions", "aux_values"): model_output_type = ModelOutputType.PREDICTION_WITH_AUX else: raise ValueError( "Metric functions must have positional arguments matching either " "('targets', 'scores'), ('targets', 'predictions') or " "('targets', 'predictions', 'aux_values'). " f"Got: {pos_args}" ) return cls( _metric_fn=metric_fn, _postprocess_fn=postprocess_fn, model_output_type=model_output_type, metric_fn_kwargs={}, targets_and_inferences={}, )
[docs] def postprocess_fn( self, targets_or_predictions: Any, **postprocess_kwargs ) -> Any: """Applies the postprocessing to targets or predictions.""" if self._postprocess_fn: return self._postprocess_fn(targets_or_predictions, **postprocess_kwargs) return targets_or_predictions
[docs] def from_model_output( # pylint:disable=arguments-renamed self, inputs: Sequence[Mapping[str, Any]], model_output: Union[np.ndarray, Tuple[np.ndarray, np.ndarray]], features: Mapping[str, utils.Feature], target_field_name: str = "targets", mask: Optional[np.ndarray] = None, ) -> "LegacyMetric": if not self.metric_fn_kwargs.get("targets"): # Postprocesses the targets here. postprocessed_targets = [] for ex in inputs: pretokenized_target_field_name = target_field_name + "_pretokenized" if pretokenized_target_field_name in ex: target = ex[pretokenized_target_field_name] else: target = features[target_field_name].vocabulary.decode( list(ex[target_field_name]) ) if isinstance(target, bytes): target = target.decode("utf-8") postprocessed_targets.append( self.postprocess_fn(target, example=ex, is_target=True) ) self.metric_fn_kwargs["targets"] = postprocessed_targets self.targets_and_inferences["targets"] = postprocessed_targets if self.model_output_type == ModelOutputType.SCORE: self.metric_fn_kwargs["scores"] = model_output self.targets_and_inferences["score"] = model_output else: vocab = features[target_field_name].vocabulary if self.model_output_type == ModelOutputType.PREDICTION_WITH_AUX: self.metric_fn_kwargs["aux_values"] = model_output[1] self.targets_and_inferences["aux_value"] = model_output[1] predictions = [vocab.decode(tokens) for tokens in model_output[0]] elif self.model_output_type == ModelOutputType.PREDICTION: # Default behavior for top-1 decoding, model_output is a 2d array. # first dim is for batch, second is for sequence length. if isinstance(model_output, np.ndarray) and model_output.ndim == 2: predictions = [vocab.decode(tokens) for tokens in model_output] else: # In case of top-k decoding, model_output will be a 3d array # first dim is for batch, second is for num_decodes, third is for # sequence length. predictions = [] for sequences in model_output: predictions_for_one_example = [] for sequence in sequences: predictions_for_one_example.append(vocab.decode(sequence)) predictions.append(predictions_for_one_example) self.targets_and_inferences["output"] = predictions # Postprocesses the predictions here. postprocessed_predictions = [ self.postprocess_fn(p, example=ex, is_target=False) for ex, p in zip(inputs, predictions) ] self.metric_fn_kwargs["predictions"] = postprocessed_predictions self.targets_and_inferences["prediction"] = postprocessed_predictions return self
[docs] def compute(self): return self._metric_fn(**self.metric_fn_kwargs)
[docs] def remove_padding_examples(model_output, indices_2d, mask): """Removes padding examples indicated by the mask array. Args: model_output: model outputs of all the examples (including the padding ones). The padding examples are used to make sure during inference, the inference function receives full batch if the last batch does not enough examples. indices_2d: 2d indices of all the examples. mask: an array of booleans. 1 indicates valid example, 0 indicates padded example that needs to be removed. Returns: 2d-indices and model outputs of all the non-padding examples. """ indices_2d = indices_2d[mask == 1] model_output = jax.tree.map(lambda x: x[mask == 1], model_output) return indices_2d, model_output
[docs] def globally_sort_model_output(model_output, indices_2d): """Globally sorts model ouputs by the 2d indices of the examples. The sorting is done first by shard id (first index of the 2d-index) and then by example id (second index of the 2d-index). Args: model_output: model outputs of all the examples. indices_2d: 2d indices of all the examples. Returns: sorted model outputs. """ permutation = np.lexsort((indices_2d[:, 1], indices_2d[:, 0])) def _sort_by_permutation(x): return np.array([x[permutation[i]] for i in range(len(permutation))]) model_output = jax.tree.map(_sort_by_permutation, model_output) return model_output
[docs] @flax.struct.dataclass class PassthroughLegacyMetric(CollectingMetric): """Makes PassthroughLegacyMetric from metric functions."""
[docs] @classmethod def from_metric_fn( cls, metric_fn: MetricFnCallable, postprocess_fn: Optional[Callable[..., Any]] = None, ): """Creates `PassthroughLegacyMetric` from `metric_fn` and `postprocess_fn`. Example: ``` squad_cls = PassthroughLegacyMetric.from_metric_fn( metric_fn=t5_metrics.squd, postprocess_fn=t5_postprocessors.qa) ``` Args: metric_fn: Function used to compute metric. postprocess_fn: Function used to process targets (vocab decoded) and predictions (vocab decoded) before feeding into metric_fn. Returns: A `Metric` that calls `metric_fn` and `postprocess_fn` in its `.from_model_output()`. """ def _get_model_output_type() -> ModelOutputType: pos_args = tuple( key for key, param in inspect.signature(metric_fn).parameters.items() if param.default == inspect.Parameter.empty and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD ) if pos_args == ("targets", "scores"): model_output_type = ModelOutputType.SCORE elif pos_args == ("targets", "predictions"): model_output_type = ModelOutputType.PREDICTION elif pos_args == ("targets", "predictions", "aux_values"): model_output_type = ModelOutputType.PREDICTION_WITH_AUX else: raise ValueError( "Metric functions must have positional arguments matching either " "('targets', 'scores'), ('targets', 'predictions') or " "('targets', 'predictions', 'aux_values'). " f"Got: {pos_args}" ) return model_output_type @flax.struct.dataclass class FromMetricFun(cls): """Wrapper PassthroughLegacyMetric class that runs metric_fn.""" model_output_type: ModelOutputType = _get_model_output_type() @classmethod def postprocess( cls, targets_or_predictions: Any, **postprocess_kwargs ) -> Any: """Applies the postprocessing to targets or predictions.""" if postprocess_fn: return postprocess_fn(targets_or_predictions, **postprocess_kwargs) return targets_or_predictions def postprocess_targets( self, task_dataset_as_numpy, task_output_features, target_field_name: str = "targets", ): """Applies the postprocessing to targets.""" # Postprocesses the targets here. postprocessed_targets = [] for ex in task_dataset_as_numpy: pretokenized_target_field_name = target_field_name + "_pretokenized" if pretokenized_target_field_name in ex: target = ex[pretokenized_target_field_name] else: target = task_output_features[target_field_name].vocabulary.decode( list(ex[target_field_name]) ) if isinstance(target, bytes): target = target.decode("utf-8") postprocessed_targets.append( type(self).postprocess(target, example=ex, is_target=True) ) return postprocessed_targets def actual_compute( self, task_dataset_as_numpy, task_output_features, target_field_name: str = "targets", cached_targets: Optional[List[str]] = None, ): # Postprocesses the targets here. if not cached_targets: postprocessed_targets = self.postprocess_targets( task_dataset_as_numpy, task_output_features, target_field_name ) else: postprocessed_targets = cached_targets metric_fn_kwargs, targets_and_inferences = {}, {} metric_fn_kwargs["targets"] = postprocessed_targets targets_and_inferences["targets"] = postprocessed_targets # We process the model outputs here by the steps below. # Step 1: removes padded examples using mask. indices_2d, model_output = remove_padding_examples( self.values["model_output"], self.values["indices_2d"], self.values["mask"], ) assert len(postprocessed_targets) == len(indices_2d) # Step 2: sorts the model outputs by 2d-indices, namely (shard_id, # index_within_shard) to align with targets. model_output = globally_sort_model_output(model_output, indices_2d) if type(self).model_output_type == ModelOutputType.SCORE: metric_fn_kwargs["scores"] = model_output targets_and_inferences["score"] = model_output else: vocab = task_output_features[target_field_name].vocabulary if ( type(self).model_output_type == ModelOutputType.PREDICTION_WITH_AUX ): metric_fn_kwargs["aux_values"] = model_output[1] targets_and_inferences["aux_value"] = model_output[1] predictions = [vocab.decode(tokens) for tokens in model_output[0]] elif type(self).model_output_type == ModelOutputType.PREDICTION: # Default behavior for top-1 decoding, model_output is a 2d array. # first dim is for batch, second is for sequence length. if isinstance(model_output, np.ndarray) and model_output.ndim == 2: predictions = [vocab.decode(tokens) for tokens in model_output] elif ( isinstance(model_output, np.ndarray) and model_output.ndim == 3 ): # In case of top-k decoding, model_output will be a 3d array # first dim is for batch, second is for num_decodes, third is for # sequence length. predictions = [] for sequences in model_output: predictions_for_one_example = [] for sequence in sequences: predictions_for_one_example.append(vocab.decode(sequence)) predictions.append(predictions_for_one_example) else: # If neither 2d or 3d, assume that model_output is already # decoded. predictions = model_output targets_and_inferences["output"] = predictions # Postprocesses the predictions here. postprocessed_predictions = [ type(self).postprocess(p, example=ex, is_target=False) for ex, p in zip(task_dataset_as_numpy, predictions) ] metric_fn_kwargs["predictions"] = postprocessed_predictions targets_and_inferences["prediction"] = postprocessed_predictions return metric_fn(**metric_fn_kwargs), targets_and_inferences return FromMetricFun