# Source code for seqio.dataset_providers

# Copyright 2026 The SeqIO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes for data loading and processing.

Defines DataSources, Tasks, TaskRegistry, Mixture, and MixtureRegistry.
"""

from __future__ import annotations

import abc
import collections
import dataclasses
import functools
import inspect
import json
import numbers
import operator
import os
import re
from typing import Any, Callable, Iterable, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, Type, Union

from absl import logging
import clu.metrics
import editdistance
import numpy as np
from packaging import version as version_lib
import pyglove as pg
from seqio import metrics as metrics_lib
from seqio import preprocessors as seqio_preprocessors
from seqio import task_registry_provenance_tracking
from seqio import utils
from seqio.feature_converters import FeatureConverter
from seqio.vocabularies import PassThroughVocabulary
from seqio.vocabularies import Vocabulary
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import typing_extensions


# Default names of the input and target features.
_DEFAULT_FEATURE_KEYS = ["inputs", "targets"]

# A valid task name is a non-empty run of word characters, '.', ':' and '#'.
_VALID_TASK_NAME_REGEX = re.compile(r"^[\w\d\.\:_#]+$")
# Presumably an upper bound on example count for in-memory caching; the
# consumer is not visible in this section of the file.
_MAX_EXAMPLES_TO_MEM_CACHE: int = 10_000
# Default `shuffle_buffer_size` used by `Task`.
SHUFFLE_BUFFER_SIZE: int = 1_000

# Callable that builds a dataset of records from one file path/pattern or an
# iterable of them (the default used by `TFExampleDataSource` is
# `tf.data.TFRecordDataset`).
DatasetReaderType = Callable[[Union[str, Iterable[str]]], tf.data.Dataset]
# Callable that decodes a serialized record into a mapping of features, as
# expected by `ProtoDataSource`'s `decode_proto_fn`.
DecodeFnType = Callable[..., Mapping[str, tf.train.Feature]]
# Alias of `utils.Feature` used throughout this module.
Feature = utils.Feature


@dataclasses.dataclass(frozen=True)
class ContinuousFeature(Feature):
  """A container for multi-modal output features of data providers."""

  vocabulary: Vocabulary = dataclasses.field(
      default_factory=lambda: PassThroughVocabulary(size=0)
  )
  add_eos: bool = False


@dataclasses.dataclass(frozen=True)
class ShardInfo:
  """A container for specifying sharding info.

  Instances are immutable and hashable.

  Attributes:
    index: which shard to read; `FileDataSource` passes it as the `index`
      argument of `tf.data.Dataset.shard`.
    num_shards: total number of shards the data is divided into.
  """

  index: int
  num_shards: int
@dataclasses.dataclass(frozen=True)
class SourceInfo:
  """Information about the source location of a class or function.

  Attributes:
    file_path: where on disk the source code is located.
    line_number: the line number in the file where the class/function/etc is
      defined.
  """

  file_path: Optional[str] = None
  line_number: Optional[int] = None

  @classmethod
  @functools.lru_cache(maxsize=None)
  def for_class(cls, klass) -> "SourceInfo":
    """Returns info about where the given class was defined.

    Lookups never raise: any field that cannot be determined (built-in
    classes, dynamically created classes, classes defined in `__main__`
    without a backing file) is left as `None`.

    Args:
      klass: the class to locate.

    Returns:
      A `SourceInfo` (of type `cls`) with whatever could be determined.
    """
    try:
      source_file = inspect.getsourcefile(klass)
    except (TypeError, OSError):
      # TypeError: built-in class. OSError: `__main__` class with no file.
      source_file = None
    try:
      _, line_number = inspect.getsourcelines(klass)
    except (TypeError, OSError):
      # TypeError: built-in class. OSError: the class definition cannot be
      # found in the source text (e.g. created with `type(...)`).
      line_number = None
    return cls(file_path=source_file, line_number=line_number)

  def has_meaningful_info(self) -> bool:
    """Returns whether a non-empty source file path is known."""
    return bool(self.file_path)
class DatasetProviderBase(metaclass=abc.ABCMeta):
  """Abstract base for classes that provide a tf.data.Dataset.

  Concrete providers implement `output_features`, `splits`, `get_dataset` and
  `num_input_examples`.
  """

  @property
  @abc.abstractmethod
  def output_features(self) -> Mapping[str, Feature]:
    """Mapping from feature name to its `Feature` specification."""
    raise NotImplementedError

  @property
  @abc.abstractmethod
  def splits(self) -> Sequence[str]:
    """Names of the splits this provider can serve."""
    raise NotImplementedError

  @abc.abstractmethod
  def get_dataset(
      self,
      sequence_length: Optional[Mapping[str, int]] = None,
      split: str = tfds.Split.TRAIN,
      use_cached: bool = False,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      num_epochs: Optional[int] = 1,
  ) -> tf.data.Dataset:
    """Returns the requested tf.data.Dataset."""
    raise NotImplementedError

  @abc.abstractmethod
  def num_input_examples(self, split: str) -> Optional[int]:
    """Number of raw input examples in `split`; may be None."""
    raise NotImplementedError
class DatasetProviderRegistry(object):
  """Base for registry of data providers.

  Subclasses must define the `_REGISTRY` and `_PROVIDER_TYPE` class variables
  and must wrap the `get` method to override the return type for pytype.

  TODO(adarob): Remove the need to override `get`.
  """

  # Class variables must be defined in subclasses.
  _REGISTRY: MutableMapping[str, DatasetProviderBase]
  _PROVIDER_TYPE: Type[DatasetProviderBase]

  @classmethod
  def add_provider(cls, name: str, provider):
    """Registers an already-constructed provider instance under `name`.

    Raises:
      ValueError: if `name` is already registered, or `provider` is not an
        instance of `_PROVIDER_TYPE`.
    """
    if name in cls._REGISTRY:
      raise ValueError("Attempting to register duplicate provider: %s" % name)
    expected_type = cls._PROVIDER_TYPE
    if not isinstance(provider, expected_type):
      raise ValueError(
          "Attempting to register a class of an invalid type. "
          "Expecting instance of %s, got %s"
          % (expected_type, type(provider).__name__)
      )
    cls._REGISTRY[name] = provider
    # The frame handed over is this method's own frame.
    task_registry_provenance_tracking.maybe_record_provenance(
        frame=inspect.currentframe(),
        name=name,
        provider_type=provider.__class__.__name__,
    )

  @classmethod
  def add(cls, name: str, provider_cls, provider_kwargs):
    """Constructs `provider_cls(**provider_kwargs)` and registers it.

    Returns:
      The newly constructed provider.

    Raises:
      ValueError: if `provider_cls` is not a subclass of `_PROVIDER_TYPE`.
    """
    expected_type = cls._PROVIDER_TYPE
    if not issubclass(provider_cls, expected_type):
      raise ValueError(
          "Attempting to register a class of an invalid type. "
          "Expecting instance of %s, got %s" % (expected_type, provider_cls)
      )
    new_provider = provider_cls(**provider_kwargs)  # pytype: disable=wrong-arg-types # dynamic-method-lookup
    cls.add_provider(name, new_provider)
    return new_provider

  @classmethod
  def remove(cls, name):
    """Drops `name` from the registry; unknown names are ignored."""
    cls._REGISTRY.pop(name, None)

  @classmethod
  def get(cls, name):
    """Looks up a registered provider, raising ValueError if absent."""
    registry = cls._REGISTRY
    if name not in registry:
      raise ValueError("Provider name not registered: %s" % name)
    return registry[name]

  @classmethod
  def names(cls):
    """Returns a view of every registered provider name."""
    return cls._REGISTRY.keys()

  @classmethod
  def reset(cls):
    """Replaces the registry with a new, empty one."""
    cls._REGISTRY = {}

  @classmethod
  def get_dataset(
      cls,
      name,
      sequence_length,
      split,
      use_cached=False,
      shuffle=True,
      seed=None,
      shard_info=None,
      num_epochs=1,
  ):
    """Forwards to `get_dataset` of the provider registered as `name`."""
    provider = cls.get(name)
    return provider.get_dataset(
        sequence_length=sequence_length,
        split=split,
        use_cached=use_cached,
        shuffle=shuffle,
        seed=seed,
        shard_info=shard_info,
        num_epochs=num_epochs,
    )
# =============================== DataSources ==================================
class DataSourceInterface(typing_extensions.Protocol):
  """Interface for DataSource.

  Structural (duck-typed) description of the members a data source exposes.
  """

  def num_input_examples(self, split: str) -> int:
    """Number of raw input examples in `split`."""
    ...

  @property
  def caching_permitted(self) -> bool:
    """Whether the source may be cached."""
    ...

  @property
  def splits(self) -> Sequence[str]:
    """Split names the source provides."""
    ...

  @property
  def supports_arbitrary_sharding(self) -> bool:
    """Whether sharding beyond `list_shards` is supported."""
    ...

  @property
  def output_features(self) -> Mapping[str, Feature]:
    """Mapping from feature name to `Feature`."""
    ...

  def list_shards(self, split: str) -> Sequence[str]:
    """String identifiers of the input shards of `split`."""
    ...

  def get_dataset(
      self,
      split: str,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
  ) -> tf.data.Dataset:
    """Returns the raw dataset for `split`."""
    ...
class DataSource(DatasetProviderBase):
  """A `DatasetProvider` that provides raw data from an input source.

  Inherits all abstract methods and properties of `DatasetProviderBase` except
  those overridden below.
  """

  def __init__(
      self,
      splits: Iterable[str],
      num_input_examples: Optional[Mapping[str, int]] = None,
      caching_permitted: bool = True,
      performs_internal_shuffling: bool = False,
  ):
    """DataSource constructor.

    Args:
      splits: split names served by this source; stored as a tuple.
      num_input_examples: optional mapping from split name to its number of
        input examples; a dict copy is stored.
      caching_permitted: whether this data source may be cached.
      performs_internal_shuffling: whether this data source can shuffle itself
        without `ds.shuffle()`.
    """
    self._splits = tuple(splits)
    if num_input_examples is None:
      self._num_input_examples = None
    else:
      self._num_input_examples = dict(num_input_examples)
    self._caching_permitted = caching_permitted
    self._performs_internal_shuffling = performs_internal_shuffling

  @property
  def caching_permitted(self) -> bool:
    """Indicates whether this data source may be cached.

    Caching may be prohibited for the sake of data versioning rigor or as a
    matter of policy for certain datasets.
    """
    return self._caching_permitted

  @property
  def splits(self) -> Sequence[str]:
    """The tuple of split names given at construction."""
    return self._splits

  @property
  @abc.abstractmethod
  def supports_arbitrary_sharding(self) -> bool:
    """Whether supports sharding beyond those available in `list_shards`."""
    raise NotImplementedError

  @property
  def output_features(self) -> Mapping[str, Feature]:
    """Override unused property of `DatasetProviderBase`."""
    raise NotImplementedError

  @property
  def performs_internal_shuffling(self) -> bool:
    """Indicates whether this data source performs internal shuffling.

    Some datasets may provide internal shuffling mechanisms that could allow
    the dataset to be shuffled without calling ds.shuffle().
    """
    return self._performs_internal_shuffling

  @abc.abstractmethod
  def list_shards(self, split: str) -> Sequence[str]:
    """Returns string identifiers of input shards."""
    raise NotImplementedError

  @abc.abstractmethod
  def get_dataset(
      self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
      split: str = tfds.Split.TRAIN,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      *,  # remaining args are out of order from parent
      sequence_length: Optional[Mapping[str, int]] = None,  # Unused
      use_cached: bool = False,  # Unused
      num_epochs: Optional[int] = 1,  # Unused
  ) -> tf.data.Dataset:
    """Overrides base class to add shard identifier and remove use_cached.

    Args:
      split: string, the split to return.
      shuffle: bool, whether to shuffle the input source.
      seed: tf.int64 scalar tf.Tensor (or None) for shuffling input source.
      shard_info: optional specification for loading a shard of the split.
      sequence_length: Unused
      use_cached: Unused
      num_epochs: Unused
    """
    raise NotImplementedError

  def num_input_examples(self, split: str) -> Optional[int]:  # pytype: disable=signature-mismatch  # overriding-return-type-checks
    """Returns the size of `split`, or None when no counts were given.

    Raises KeyError if counts were given but `split` is not among them.
    """
    counts = self._num_input_examples
    return None if counts is None else counts[split]
def _validate_args(fn, expected_args: Sequence[str]):
  """Ensure function/protocol is callable with exactly expected args.

  Args:
    fn: the callable whose signature is inspected.
    expected_args: parameter names that must appear, in this order, as the
      leading parameters of `fn`.

  Raises:
    ValueError: if `fn` does not start with `expected_args`, or if the names of
      its parameters without defaults are not a prefix of `expected_args`.
  """
  params = tuple(inspect.signature(fn).parameters.values())
  actual_args = tuple(p.name for p in params)
  expected_args = tuple(expected_args)
  if actual_args[: len(expected_args)] != expected_args:
    raise ValueError(
        "'%s' must have initial args %s, got: %s"
        % (utils.function_name(fn), expected_args, actual_args)
    )
  # Compare against the sentinel by identity: `==` would dispatch to the
  # default value's own `__eq__` (e.g. numpy arrays compare elementwise and
  # cannot be used as a bool).
  actual_nondefault_args = tuple(
      p.name for p in params if p.default is inspect.Parameter.empty
  )
  if actual_nondefault_args != expected_args[: len(actual_nondefault_args)]:
    raise ValueError(
        "'%s' may only have positional args %s, got: %s"
        % (utils.function_name(fn), expected_args, actual_nondefault_args)
    )


class DatasetFnCallable(typing_extensions.Protocol):
  """Signature expected of a `FunctionDataSource` dataset function."""

  def __call__(
      self, split: str, shuffle_files: bool, seed: Optional[int] = None
  ) -> tf.data.Dataset:
    ...
class FunctionDataSource(DataSource):
  """A `DataSource` that uses a function to provide the input data.

  This source is not recommended when shuffling is required unless it is
  cached/materialized in advance. Using this source without caching for
  training will result in insufficient shuffling and lead to repeated data on
  restarts.
  """

  def __init__(
      self,
      dataset_fn: DatasetFnCallable,
      splits: Iterable[str],
      num_input_examples: Optional[Mapping[str, int]] = None,
      caching_permitted: bool = True,
  ):
    """FunctionDataSource constructor.

    Args:
      dataset_fn: a function with the signature
        `dataset_fn(split, shuffle_files)` (optionally also accepting `seed`)
        that returns a `tf.data.Dataset`. The signature is validated here.
      splits: an iterable of applicable string split names.
      num_input_examples: dict or None, an optional dictionary mapping split to
        its size in number of input examples (before preprocessing). The
        `num_input_examples` method will return None if not provided.
      caching_permitted: indicates whether this data source may be cached.
        Default True.
    """
    # Reject incompatible callables before storing anything.
    _validate_args(dataset_fn, ["split", "shuffle_files"])
    self._dataset_fn = dataset_fn
    super().__init__(
        splits=splits,
        num_input_examples=num_input_examples,
        caching_permitted=caching_permitted,
    )

  @property
  def supports_arbitrary_sharding(self) -> bool:
    """Always False: only a single shard (the whole split) exists."""
    return False

  def __repr__(self):
    fn_name = utils.function_name(self._dataset_fn)
    return (
        f"{self.__class__.__name__}("
        f"dataset_fn={fn_name},"
        f" splits={self.splits},"
        f" num_input_examples={self._num_input_examples},"
        f" caching_permitted={self.caching_permitted})"
    )

  def get_dataset(
      self,
      split: str = tfds.Split.TRAIN,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      *,  # remaining args are out of order from parent
      sequence_length: Optional[Mapping[str, int]] = None,  # Unused
      use_cached: bool = False,  # Unused
      num_epochs: Optional[int] = 1,  # Unused
  ) -> tf.data.Dataset:
    """Calls the wrapped `dataset_fn` for `split`.

    Raises:
      ValueError: if more than one shard is requested, or if `seed` is given
        but `dataset_fn` does not accept a `seed` argument.
    """
    if shard_info and shard_info.num_shards > 1:
      raise ValueError(
          "`FunctionDataSource` does not support low-level sharding. Use "
          "tf.data.Dataset.shard instead."
      )
    if shuffle:
      logging.warning(
          "Using an uncached FunctionDataset for training is not recommended "
          "since it often results in insufficient shuffling on restarts, "
          "resulting in overfitting. It is highly recommended that you cache "
          "this task before training with it or use a data source that "
          "supports lower-level shuffling (e.g., FileDataSource)."
      )
    if seed is not None:
      # Only require a `seed` parameter when a seed is actually supplied.
      _validate_args(self._dataset_fn, ["split", "shuffle_files", "seed"])
      return self._dataset_fn(split=split, shuffle_files=shuffle, seed=seed)
    return self._dataset_fn(split=split, shuffle_files=shuffle)

  def list_shards(self, split: str) -> Sequence[str]:
    """The split name itself is the only shard identifier."""
    return [split]
class TfdsDataSource(DataSource):
  """A `DataSource` that uses TensorFlow Datasets to provide the input data."""

  def __init__(
      self,
      tfds_name: Optional[str] = None,
      tfds_data_dir: Optional[str] = None,
      splits: Optional[
          Union[Iterable[str], Mapping[str, str], Mapping[str, utils.TfdsSplit]]
      ] = None,
      caching_permitted: bool = True,
      decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
      tfds_builder_kwargs: Optional[dict[str, Any]] = None,
      read_only: bool = False,
  ):
    """TfdsDataSource constructor.

    Args:
      tfds_name: The name and version number of a TFDS dataset, optionally with
        a config. If `tfds_name` is not specified then `splits` values must be
        instances of `TfdsSplit`.
      tfds_data_dir: An optional path to a specific TFDS data directory to use.
        If provided `tfds_name` must be a valid dataset in the directory. If
        `tfds_name` is empty `tfds_data_dir` must point to the directory with
        one dataset.
      splits: an iterable of allowable string split names, a mapping from
        allowable canonical splits (e.g., 'validation') to TFDS splits or
        slices (e.g., 'train[:1%]') or to `TfdsSplit`s (e.g.
        `TfdsSplit(dataset='mnist', split='train')`), or None. Any `Mapping`
        is accepted and copied into a dict. The default, None, uses all
        available splits from the TFDS dataset info. If `TfdsSplit` are used
        then `tfds_name` must be empty.
      caching_permitted: indicates whether this data source may be cached.
        Default True.
      decoders: dict (optional), mapping from features to tfds.decode.Decoders,
        such as tfds.decode.SkipDecoding() for skipping image byte decoding.
      tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed to
        the `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
        `tfds.builder()`.
      read_only: whether `get_dataset` can trigger the generation of a dataset.
    """
    if isinstance(splits, Mapping):
      # Accept any mapping (not only `dict`) so its split values are kept.
      splits = dict(splits)
    elif splits:
      # A plain iterable of names maps each canonical split onto itself.
      splits = {k: k for k in splits}
    self._tfds_dataset = utils.LazyTfdsLoader(
        tfds_name,
        data_dir=tfds_data_dir,
        split_map=splits if isinstance(splits, dict) else None,
        decoders=decoders,
        builder_kwargs=tfds_builder_kwargs,
        read_only=read_only,
    )
    # If splits are not provided, we pass an empty tuple and use the lazy
    # lookup in the `splits` property.
    super().__init__(splits=splits or (), caching_permitted=caching_permitted)

  @property
  def splits(self):
    """Overrides since we can't call `info.splits` until after init."""
    return self._splits or self.tfds_dataset.info.splits

  @property
  def tfds_dataset(self) -> utils.LazyTfdsLoader:
    """The lazily-initialized TFDS loader backing this source."""
    return self._tfds_dataset

  @property
  def supports_arbitrary_sharding(self) -> bool:
    return False

  def __str__(self):
    return f"{self.__class__.__name__}(tfds_dataset={str(self.tfds_dataset)})"

  def __repr__(self):
    return (
        f"{self.__class__.__name__}(tfds_dataset={str(self.tfds_dataset)},"
        f" splits={self.splits}, caching_permitted={self.caching_permitted})"
    )

  def get_dataset(
      self,
      split: Optional[str] = None,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      *,  # remaining args are out of order from parent
      sequence_length: Optional[Mapping[str, int]] = None,  # Unused
      use_cached: bool = False,  # Unused
      num_epochs: Optional[int] = 1,  # Unused
  ) -> tf.data.Dataset:
    """Loads `split` (TRAIN when None) through the lazy TFDS loader."""
    if split is None:
      split = tfds.Split.TRAIN
    return self.tfds_dataset.load(
        split, shuffle_files=shuffle, seed=seed, shard_info=shard_info
    )

  def num_input_examples(self, split: str) -> Optional[int]:
    """Overrides since we can't call `info.splits` until after init."""
    return self.tfds_dataset.size(split)

  @functools.lru_cache(maxsize=1024)
  def list_shards(self, split: str) -> Sequence[str]:
    """Returns the file names backing `split`."""

    def _get_filename(info):
      if isinstance(info, dict):  # this is true for unit tests
        return info["filename"]
      return info.filename  # TFDS FileInstruction

    return [_get_filename(info) for info in self.tfds_dataset.files(split)]
class FileDataSource(DataSource):
  """A `DataSource` that reads a file to provide the input dataset."""

  def __init__(
      self,
      read_file_fn: Callable[[tf.data.Dataset], tf.data.Dataset],
      split_to_filepattern: Mapping[str, Union[str, Iterable[str]]],
      num_input_examples: Optional[Mapping[str, int]] = None,
      caching_permitted: bool = True,
      file_shuffle_buffer_size: Optional[int] = None,
      cycle_length: int = 16,
      block_length: int = 16,
      performs_internal_shuffling: bool = False,
  ):
    """FileDataSource constructor.

    Args:
      read_file_fn: a callable used as the `map_func` of
        `tf.data.Dataset.interleave` over the file paths, i.e. it receives a
        file path and returns a dataset of its records, e.g.,
        `tf.data.TFRecordDataset`.
      split_to_filepattern: a mapping from split names to filepatterns to be
        expanded with glob.
      num_input_examples: dict or None, an optional dictionary mapping split to
        its size in number of input examples (before preprocessing). The
        `num_input_examples` method will return None if not provided.
      caching_permitted: indicates whether this data source may be cached.
        Default True.
      file_shuffle_buffer_size: The buffer size to shuffle files when needed.
        If None, the number of files is used as buffer size for a perfect
        shuffle (default and recommended). A value of 16 may be explicitly set
        to replicate earlier behavior.
      cycle_length: The cycle_length to pass to tf.data.Dataset.interleave.
      block_length: The block_length to pass to tf.data.Dataset.interleave.
      performs_internal_shuffling: Allow enclosing task to call get_dataset
        with shuffle_buffer_size=None. In this case, only filename shuffling
        will be performed when shuffle==True.
    """
    self._split_to_filepattern = split_to_filepattern
    self._reader = read_file_fn
    self._file_shuffle_buffer_size = file_shuffle_buffer_size
    self._cycle_length = cycle_length
    self._block_length = block_length
    super().__init__(
        splits=split_to_filepattern.keys(),
        num_input_examples=num_input_examples,
        caching_permitted=caching_permitted,
        performs_internal_shuffling=performs_internal_shuffling,
    )

  @property
  def supports_arbitrary_sharding(self) -> bool:
    return False

  def __str__(self):
    return f"{self.__class__.__name__}({self._split_to_filepattern})"

  def __repr__(self):
    return (
        f"{self.__class__.__name__}("
        f"split_to_filepattern={self._split_to_filepattern},"
        f" num_input_examples={self._num_input_examples},"
        f" caching_permitted={self._caching_permitted},"
        f" file_shuffle_buffer_size={self._file_shuffle_buffer_size},"
        f" cycle_length={self._cycle_length},"
        f" block_length={self._block_length})"
    )

  def get_dataset(
      self,
      split: str = tfds.Split.TRAIN,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      *,  # remaining args are out of order from parent
      sequence_length: Optional[Mapping[str, int]] = None,  # Unused
      use_cached: bool = False,  # Unused
      num_epochs: Optional[int] = 1,  # Unused
  ) -> tf.data.Dataset:
    """Reads the files of `split`, optionally sharded and file-shuffled.

    Raises:
      ValueError: if no files match the split's pattern, or there are fewer
        files than requested shards.
    """
    filenames = self.list_shards(split)
    if not filenames:
      raise ValueError(
          "No file is found for the file pattern: "
          f"{self._split_to_filepattern[split]}."
      )
    n_files = len(filenames)
    paths_ds = tf.data.Dataset.from_tensor_slices(
        np.array(filenames, dtype=str)
    )

    # Sharding is done at file granularity, so each shard needs >= 1 file.
    if shard_info:
      if n_files < shard_info.num_shards:
        raise ValueError(
            f"Dataset has too few files to shard. {n_files} files vs "
            f"{shard_info.num_shards} shards requested."
        )
      paths_ds = paths_ds.shard(shard_info.num_shards, shard_info.index)

    if shuffle:
      explicit_buffer = self._file_shuffle_buffer_size
      if explicit_buffer:
        logging.warning(
            (
                "`file_shuffle_buffer_size` is explicitly set to %d; this may"
                " lead to an imperfect file shuffle. Leave"
                " `file_shuffle_buffer_size` unset for a perfect shuffle."
            ),
            explicit_buffer,
        )
      # A buffer as large as the file list gives a perfect shuffle.
      paths_ds = paths_ds.shuffle(
          buffer_size=explicit_buffer or n_files, seed=seed
      )

    return paths_ds.interleave(
        self._reader,
        cycle_length=self._cycle_length,
        block_length=self._block_length,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

  def list_shards(self, split: str) -> Sequence[str]:
    """Expands the split's filepattern(s) into a list of files."""
    return utils.list_files(self._split_to_filepattern[split])
class TextLineDataSource(FileDataSource):
  """A `FileDataSource` that reads lines of text from a file as input."""

  def __init__(
      self,
      split_to_filepattern: Mapping[str, Union[str, Iterable[str]]],
      skip_header_lines: int = 0,
      num_input_examples: Optional[Mapping[str, int]] = None,
      caching_permitted: bool = True,
      file_shuffle_buffer_size: Optional[int] = None,
      cycle_length: int = 16,
      block_length: int = 16,
      performs_internal_shuffling: bool = False,
  ):
    """TextLineDataSource constructor.

    Args:
      split_to_filepattern: a mapping from split names to filepatterns to be
        expanded with glob.
      skip_header_lines: int, number of header lines to skip in each source
        file.
      num_input_examples: dict or None, an optional dictionary mapping split to
        its size in number of input examples (before preprocessing). The
        `num_input_examples` method will return None if not provided.
      caching_permitted: indicates whether this data source may be cached.
        Default True.
      file_shuffle_buffer_size: The buffer size to shuffle files when needed.
        If None, the number of files is used as buffer size for a perfect
        shuffle (default and recommended). A value of 16 may be explicitly set
        to replicate earlier behavior.
      cycle_length: The cycle_length to pass to tf.data.Dataset.interleave.
      block_length: The block_length to pass to tf.data.Dataset.interleave.
      performs_internal_shuffling: forwarded to `FileDataSource`; allows the
        enclosing task to call get_dataset with shuffle_buffer_size=None, in
        which case only filename shuffling is performed when shuffle==True.
        Defaults to False, as before this argument existed.
    """
    # Used during caching.
    self._skip_header_lines = skip_header_lines

    def read_file_fn(filepattern):
      # Header lines are dropped per file, before files are interleaved.
      return tf.data.TextLineDataset(filepattern).skip(skip_header_lines)

    super().__init__(
        read_file_fn=read_file_fn,
        split_to_filepattern=split_to_filepattern,
        num_input_examples=num_input_examples,
        caching_permitted=caching_permitted,
        file_shuffle_buffer_size=file_shuffle_buffer_size,
        cycle_length=cycle_length,
        block_length=block_length,
        performs_internal_shuffling=performs_internal_shuffling,
    )
class TFExampleDataSource(FileDataSource):
  """A `FileDataSource` that reads files of tf.train.Example protos as input."""

  def __init__(
      self,
      split_to_filepattern: Mapping[str, Union[str, Iterable[str]]],
      feature_description: Mapping[
          str,
          tf.io.FixedLenFeature | tf.io.VarLenFeature | tf.io.RaggedFeature,
      ],
      reader_cls: DatasetReaderType = tf.data.TFRecordDataset,
      num_input_examples: Optional[Mapping[str, int]] = None,
      caching_permitted: bool = True,
      file_shuffle_buffer_size: Optional[int] = None,
      cycle_length: int = 16,
      block_length: int = 16,
      performs_internal_shuffling: bool = False,
  ):
    """TFExampleDataSource constructor.

    Args:
      split_to_filepattern: dict of string (split name) to either string
        (filename or filepattern) or list of strings (filenames or
        filepatterns).
      feature_description: dict, a mapping of string feature keys to
        `tf.io.FixedLenFeature`, `tf.io.VarLenFeature`, or
        `tf.io.RaggedFeature` values.
      reader_cls: `tf.data.Dataset`, a dataset class to read the input files.
      num_input_examples: dict or None, an optional dictionary mapping split to
        its size in number of input examples (before preprocessing). The
        `num_input_examples` method will return None if not provided.
      caching_permitted: indicates whether this data source may be cached.
        Default True.
      file_shuffle_buffer_size: The buffer size to shuffle files when needed.
        If None, the number of files is used as buffer size for a perfect
        shuffle (default and recommended). A value of 16 may be explicitly set
        to replicate earlier behavior.
      cycle_length: The cycle_length to pass to tf.data.Dataset.interleave.
      block_length: The block_length to pass to tf.data.Dataset.interleave.
      performs_internal_shuffling: Allow enclosing task to call get_dataset
        with shuffle_buffer_size=None. In this case, only filename shuffling
        will be performed when shuffle==True.
    """

    def parse_fn(*args):
      # Some readers emit more than one value per record; the last one is
      # taken as the serialized proto.
      serialized = args[-1]
      return tf.io.parse_single_example(serialized, feature_description)

    def read_file_fn(filepattern):
      records = reader_cls(filepattern)
      return records.map(
          parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
      )

    self.reader_cls = reader_cls
    self.parse_fn = parse_fn
    self.feature_description = feature_description
    super().__init__(
        read_file_fn=read_file_fn,
        split_to_filepattern=split_to_filepattern,
        num_input_examples=num_input_examples,
        caching_permitted=caching_permitted,
        file_shuffle_buffer_size=file_shuffle_buffer_size,
        cycle_length=cycle_length,
        block_length=block_length,
        performs_internal_shuffling=performs_internal_shuffling,
    )

  def __str__(self):
    return (
        f"{self.__class__.__name__}("
        f"split_to_filepattern={self._split_to_filepattern},"
        f" feature_description={self.feature_description})"
    )

  def __repr__(self):
    return (
        f"{self.__class__.__name__}("
        f"split_to_filepattern={self._split_to_filepattern},"
        f" feature_description={self.feature_description},"
        f" reader_cls={self.reader_cls},"
        f" num_input_examples={self._num_input_examples},"
        f" caching_permitted={self._caching_permitted},"
        f" file_shuffle_buffer_size={self._file_shuffle_buffer_size},"
        f" cycle_length={self._cycle_length},"
        f" block_length={self._block_length})"
    )
class ProtoDataSource(FileDataSource):
  """A `FileDataSource` that reads files of arbitrary protos as input."""

  def __init__(
      self,
      split_to_filepattern: Mapping[str, Union[str, Iterable[str]]],
      decode_proto_fn: DecodeFnType,
      reader_cls: DatasetReaderType = tf.data.TFRecordDataset,
      num_input_examples: Optional[Mapping[str, int]] = None,
      caching_permitted: bool = True,
      file_shuffle_buffer_size: Optional[int] = None,
      cycle_length: int = 16,
      block_length: int = 16,
      performs_internal_shuffling: bool = False,
  ):
    """ProtoDataSource constructor.

    Args:
      split_to_filepattern: dict of string (split name) to either string
        (filename or filepattern) or list of strings (filenames or
        filepatterns).
      decode_proto_fn: a callable to parse a serialized proto to features.
      reader_cls: `tf.data.Dataset`, a dataset class to read the input files.
      num_input_examples: dict or None, an optional dictionary mapping split to
        its size in number of input examples (before preprocessing). The
        `num_input_examples` method will return None if not provided.
      caching_permitted: indicates whether this data source may be cached.
        Default True.
      file_shuffle_buffer_size: The buffer size to shuffle files when needed.
        If None, the number of files is used as buffer size for a perfect
        shuffle (default and recommended). A value of 16 may be explicitly set
        to replicate earlier behavior.
      cycle_length: The cycle_length to pass to tf.data.Dataset.interleave.
      block_length: The block_length to pass to tf.data.Dataset.interleave.
      performs_internal_shuffling: forwarded to `FileDataSource`; allows the
        enclosing task to call get_dataset with shuffle_buffer_size=None, in
        which case only filename shuffling is performed when shuffle==True.
        Defaults to False, as before this argument existed.
    """

    def read_file_fn(filepattern: Union[str, Iterable[str]]):
      return reader_cls(filepattern).map(
          decode_proto_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE
      )

    self.reader_cls = reader_cls
    self.decode_proto_fn = decode_proto_fn
    super().__init__(
        read_file_fn=read_file_fn,
        split_to_filepattern=split_to_filepattern,
        num_input_examples=num_input_examples,
        caching_permitted=caching_permitted,
        file_shuffle_buffer_size=file_shuffle_buffer_size,
        cycle_length=cycle_length,
        block_length=block_length,
        performs_internal_shuffling=performs_internal_shuffling,
    )
# ========================== Offline Caching Helpers ===========================


def _rename_plaintext_to_pretokenized(
    dataset: tf.data.Dataset,
) -> tf.data.Dataset:
  """Rename cached _plaintext features to new _pretokenized standard.

  A key such as "inputs_plaintext" becomes "inputs_pretokenized"; all other
  keys pass through unchanged.
  """

  def _modern_key(old_key):
    if old_key.endswith("_plaintext"):
      return old_key[: -len("plaintext")] + "pretokenized"
    return old_key

  def _rename(inputs):
    return {_modern_key(k): v for k, v in inputs.items()}

  return dataset.map(_rename, num_parallel_calls=tf.data.experimental.AUTOTUNE)


class _CachedDataSource(FileDataSource):
  """A `FileDataSource` for reading datasets cached offline."""

  def __init__(
      self,
      cache_dir: str,
      split: str,
      file_shuffle_buffer_size: Optional[int] = None,
      cycle_length: int = 16,
      block_length: int = 16,
  ):
    """Reads the info and stats JSON of a cached split and configures parsing.

    Args:
      cache_dir: directory the split was cached to.
      split: name of the cached split.
      file_shuffle_buffer_size: forwarded to `FileDataSource`.
      cycle_length: forwarded to `FileDataSource`.
      block_length: forwarded to `FileDataSource`.
    """
    with tf.io.gfile.GFile(utils.get_cached_info_path(cache_dir, split)) as f:
      split_info = json.load(f)
    features = split_info["features"]

    with tf.io.gfile.GFile(utils.get_cached_stats_path(cache_dir, split)) as f:
      stats = json.load(f)

    cached_by = version_lib.Version(split_info.get("seqio_version", "0.pre"))
    if cached_by < version_lib.Version("0.0.0"):
      # Assume that all int64 features are really int32.
      for name, feat in features.items():
        if feat["dtype"] == "int64":
          logging.info("Casting cached '%s' to int32.", name)
          feat["dtype"] = "int32"

    def _feature_config(
        key: str,
        shape,
        dtype: str,
    ) -> Union[
        tf.io.FixedLenFeature, tf.io.FixedLenSequenceFeature, tf.io.RaggedFeature
    ]:
      """Builds the parsing spec for one cached feature."""
      if dtype in ("int32", "bool"):
        # int32 and bool are stored as int64 in the tf.train.Example protobuf.
        # TODO(adarob): Support other conversions.
        dtype = "int64"

      # Any unknown inner dimension means the feature is parsed as ragged.
      if shape and any(dim is None for dim in shape[1:]):
        partitions = []
        ragged_idx = 0
        for dim in shape[1:]:
          if dim is not None:
            partitions.append(tf.io.RaggedFeature.UniformRowLength(dim))
            continue
          partitions.append(
              tf.io.RaggedFeature.RowLengths(
                  utils.tfexample_ragged_length_key(key, ragged_idx)
              )
          )
          ragged_idx += 1
        return tf.io.RaggedFeature(
            value_key=key, partitions=partitions, dtype=dtype
        )

      # Use `FixedLenSequenceFeature` for sequences with variable length.
      if shape and shape[0] is None:
        return tf.io.FixedLenSequenceFeature(
            shape[1:], dtype, allow_missing=True
        )
      return tf.io.FixedLenFeature(shape, dtype)

    feature_description = {
        feat: _feature_config(feat, **desc) for feat, desc in features.items()
    }

    def read_file_fn(filepattern):
      ds = tf.data.TFRecordDataset(filepattern)
      ds = ds.map(
          lambda pb: tf.io.parse_single_example(pb, feature_description),
          num_parallel_calls=tf.data.experimental.AUTOTUNE,
      )
      # Cast features back to the types from the info JSON since some features
      # must be cast for storage (e.g., int32 is stored as int64).
      ds = ds.map(
          lambda x: {k: tf.cast(v, features[k]["dtype"]) for k, v in x.items()},
          num_parallel_calls=tf.data.experimental.AUTOTUNE,
      )
      # Legacy cached datasets may use old "_plaintext" suffix. Rename to
      # "_pretokenized".
      return _rename_plaintext_to_pretokenized(ds)

    shard_pattern = "%s-*-of-*%d" % (
        utils.get_cached_tfrecord_prefix(cache_dir, split),
        split_info["num_shards"],
    )
    super().__init__(
        read_file_fn=read_file_fn,
        split_to_filepattern={split: shard_pattern},
        num_input_examples={split: stats["examples"]},
        file_shuffle_buffer_size=file_shuffle_buffer_size,
        cycle_length=cycle_length,
        block_length=block_length,
    )
class CacheDatasetPlaceholder(object):
  """A placeholder to signal when in the pipeline offline caching will occur.

  It only carries configuration; calling it as a preprocessor is an error.
  """

  def __init__(
      self,
      required: bool = False,
      file_shuffle_buffer_size: Optional[int] = None,
  ):
    """CacheDatasetPlaceholder constructor.

    Args:
      required: whether the dataset must be accessed in its cached form, and
        on-the-fly preprocessing is disallowed.
      file_shuffle_buffer_size: The buffer size to shuffle files when needed.
        If None, the number of files is used as buffer size for a perfect
        shuffle (default and recommended). A value of 16 may be explicitly set
        to replicate earlier behavior.
    """
    self._required = required
    self._file_shuffle_buffer_size = file_shuffle_buffer_size

  @property
  def required(self):
    """Whether only the cached form of the dataset may be used."""
    return self._required

  @property
  def file_shuffle_buffer_size(self):
    """The configured file shuffle buffer size (None means all files)."""
    return self._file_shuffle_buffer_size

  def __call__(self, dataset):
    """Always raises; the placeholder is never meant to be applied."""
    del dataset  # Unused.
    raise RuntimeError("`CacheDatasetPlaceholder` should never be called.")
# ================================ Tasks ======================================= MetricFnCallable = metrics_lib.MetricFnCallable
class Task(DatasetProviderBase):
  """A class to manage a dataset and its related metrics."""

  def __init__(
      self,
      name: str,
      source: DataSource,
      output_features: Mapping[str, Feature],
      preprocessors: Optional[Sequence[Callable[..., tf.data.Dataset]]] = None,
      postprocess_fn: Optional[Callable[..., Any]] = None,
      metric_fns: Optional[Sequence[MetricFnCallable]] = None,
      metric_objs: Optional[Sequence[metrics_lib.Metric]] = None,
      shuffle_buffer_size: Optional[int] = SHUFFLE_BUFFER_SIZE,
      source_info: Optional[SourceInfo] = None,
  ):
    """Task constructor.

    Args:
      name: a unique name for the Task.
      source: a `DataSource` that provides a raw `tf.data.Dataset`.
      output_features: dict(str, Feature), output features of the Task to be
        passed to the model. After preprocessing, examples will be validated to
        ensure they include features that match this specification. Note that
        additional features may be included (e.g., for evaluation), but they
        will not be passed to the model.
      preprocessors: list(callable), an optional list of functions that receive
        a tf.data.Dataset and return a tf.data.Dataset. These will be executed
        sequentially and the final dataset must include features matching
        `output_features`.
      postprocess_fn: callable, an optional function that receives decoded model
        outputs and converts them to a form that is ready for evaluation using
        the metric functions in `metric_fns`.
      metric_fns: list(callable), an optional list of metric functions. Be
        aware that `metric_fns` are being deprecated, please use `metric_objs`
        instead. The metric functions must have a signature that matches one of
        three possible forms: (1) `(targets, scores)` where `scores` refers to
        the score the model assigned the target sequence, given the input, (2)
        `(targets, predictions)`, (3) `(targets, predictions, aux_values)` where
        `aux_values` refers to a dictionary of auxiliary values that the model
        assigned to each sequence.
      metric_objs: list(clu Metric instances), an optional list of clu Metric
        objects.
      shuffle_buffer_size: an optional integer to set the shuffle buffer size.
        If None, shuffling will be disallowed.
      source_info: optional metadata about where this `Task` was defined.

    Raises:
      ValueError: if `name` does not match `_VALID_TASK_NAME_REGEX`, if more
        than one `CacheDatasetPlaceholder` appears in `preprocessors`, or if a
        placeholder is present but `source` does not permit caching.
    """
    if not _VALID_TASK_NAME_REGEX.match(name):
      raise ValueError(
          "Task name '%s' contains invalid characters. Must match regex: %s"
          % (name, _VALID_TASK_NAME_REGEX.pattern)
      )

    # Capture constructor arguments and use them lazily to speed up
    # Task initialization in case many Tasks are being created that are unused.
    self._metric_objs_constructor_args = metric_objs or []
    self._metric_fn_constructor_args = metric_fns or []
    self._name = name
    self._source = source
    self._source_info = source_info
    self._preprocessor_constructor_args = preprocessors or ()

    # Index of the (single) CacheDatasetPlaceholder in `preprocessors`, if any.
    self._cache_step_idx: Optional[int] = None
    # NOTE: the misspelled attribute name is kept for backward compatibility.
    self._cache_dataset_placerholder: Optional[CacheDatasetPlaceholder] = None
    for i, p in enumerate(preprocessors or []):
      if isinstance(p, CacheDatasetPlaceholder):
        if self._cache_step_idx is not None:
          raise ValueError(
              "`CacheDatasetPlaceholder` can appear at most once in the "
              f"preprocessing pipeline. Found multiple in '{name}'."
          )
        self._cache_step_idx = i
        self._cache_dataset_placerholder = p
    if self._cache_step_idx is not None:
      if not self.source.caching_permitted:
        raise ValueError(
            f"Caching was requested for '{self.name}', but the underlying data "
            "source prohibits caching. Please remove `CacheDatasetPlaceholder` "
            "and try again."
        )

    self._postprocess_fn = postprocess_fn

    self._cache_dir = None
    # Per-split cache statistics, loaded lazily by `get_cached_stats`.
    self._stats = {}
    self._shuffle_buffer_size = shuffle_buffer_size

    self._output_features = collections.OrderedDict(
        sorted(list(output_features.items()))
    )

  @property
  def name(self) -> str:
    return self._name

  def __str__(self):
    return f"Task(name={self.name}, source={str(self.source)})"

  @property
  def source_info(self) -> Optional[SourceInfo]:
    return self._source_info

  @functools.cached_property
  def metric_objs(self) -> Sequence[metrics_lib.Metric]:
    """List of all metric objects, including wrapped legacy metric functions."""
    # Copy list to prevent callers from directly modifying by accessing public
    # attribute.
    to_return = list(self._metric_objs_constructor_args)
    if self.metric_fns:
      to_return += [
          metrics_lib.PassthroughLegacyMetric.from_metric_fn(
              mf, self._postprocess_fn
          ).empty()
          for mf in self.metric_fns
      ]
    return to_return

  @functools.cached_property
  def _all_metric_fns(
      self,
  ) -> Tuple[
      List[MetricFnCallable],
      List[MetricFnCallable],
      List[MetricFnCallable],
  ]:
    """Creates all metric functions {predict,score,predict_with_aux}_metric_fns.

    Validation of metric functions, which depend on slow `inspect` calls to
    help catch common errors, is deferred slightly:
      1) only validate the Tasks that are used, and
      2) as a result, to improve loading time.
    If/when the module-level TaskRegistry.add pattern is turned down,
    validation can probably be made eager again.

    Returns:
      tuple: predict_metric_fns, score_metric_fns, predict_with_aux_metric_fns.

    Raises:
      ValueError if metric functions don't have positional arguments matching
      (targets, scores), (targets, predictions), or
      (targets, predictions, aux_values)
    """
    predict_fns = []
    score_fns = []
    predict_with_aux_fns = []

    for metric_fn in self._metric_fn_constructor_args:
      # `Parameter.empty` is a sentinel and must be compared by identity: an
      # `==` comparison against a default with elementwise equality (e.g. a
      # numpy array) would not produce a plain bool.
      pos_args = tuple(
          key
          for key, param in inspect.signature(metric_fn).parameters.items()
          if param.default is inspect.Parameter.empty
          and param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
      )
      if pos_args == ("targets", "predictions"):
        predict_fns.append(metric_fn)
      elif pos_args == ("targets", "scores"):
        score_fns.append(metric_fn)
      elif pos_args == ("targets", "predictions", "aux_values"):
        predict_with_aux_fns.append(metric_fn)
      else:
        raise ValueError(
            "Metric functions must have positional arguments matching either "
            "('targets', 'scores'), ('targets', 'predictions') or "
            "('targets', 'predictions', 'aux_values'). "
            f"Got: {pos_args}"
        )

    return predict_fns, score_fns, predict_with_aux_fns

  @property
  def metric_fns(self) -> Sequence[MetricFnCallable]:
    """List of all metric functions."""
    predict_fns, score_fns, predict_with_aux_fns = self._all_metric_fns
    return predict_fns + score_fns + predict_with_aux_fns  # pytype: disable=unsupported-operands

  @property
  def predict_metric_fns(self) -> Sequence[MetricFnCallable]:
    """List of metric functions that use model predictions."""
    return self._all_metric_fns[0]

  @property
  def score_metric_fns(self) -> Sequence[MetricFnCallable]:
    """List of metric functions that use log likelihood scores."""
    return self._all_metric_fns[1]

  @property
  def predict_with_aux_metric_fns(self) -> Sequence[MetricFnCallable]:
    """List of metric functions that use model predictions with aux values."""
    # A plain property, like its siblings: `_all_metric_fns` is already cached.
    return self._all_metric_fns[2]

  @property
  def output_features(self) -> Mapping[str, Feature]:
    return self._output_features

  @property
  def splits(self) -> Sequence[str]:
    """Splits provided by the source; raises ValueError if there are none."""
    s = self.source.splits
    if not s:
      raise ValueError(f"Task {self.name} has no splits")
    return s

  @property
  def source(self) -> DataSource:
    return self._source

  def _validate_preprocessors(self):
    """Validates that some common errors are not made with preprocessors.

    Raises:
      ValueError if caching is improperly requested.
    """
    if self._cache_step_idx is not None:
      for prep in self._preprocessor_constructor_args[: self._cache_step_idx]:
        prep_args = inspect.signature(prep).parameters.keys()
        if "sequence_length" in prep_args:
          raise ValueError(
              f"'{utils.function_name(prep)}' has a `sequence_length` argument"
              f" but occurs before `CacheDatasetPlaceholder` in '{self.name}'."
              " This is not allowed since the sequence length is specified at"
              " run time."
          )
        if "seed" in prep_args or "seeds" in prep_args:
          logging.warning(
              (
                  "'%s' has a `seed(s)` argument but occurs before "
                  "`CacheDatasetPlaceholder` in '%s'. This is not recommended "
                  "since the same samples will be used each epoch when reading "
                  "from the cache."
              ),
              utils.function_name(prep),
              self.name,
          )

  @functools.cached_property
  def preprocessors(self) -> Sequence[Callable[..., tf.data.Dataset]]:
    # Validation of preprocessors, which depends on slow `inspect` calls to
    # help catch common errors, is deferred slightly:
    #   1) only validate the Tasks that are used, and
    #   2) as a result, to improve loading time.
    # If/when the module-level TaskRegistry.add pattern is turned down,
    # validation can probably be made eager again.
    self._validate_preprocessors()
    return self._preprocessor_constructor_args

  @property
  def postprocessor(self) -> Optional[Callable[..., Any]]:
    return self._postprocess_fn

  @property
  def shuffle_buffer_size(self) -> Optional[int]:
    return self._shuffle_buffer_size

  def replace(self, **kwargs):
    """Create a new variant of the current task using properties in kwargs.

    Unspecified properties are copied from this task.

    Raises:
      ValueError: if `kwargs` contains a key that is not a replaceable property.
    """
    properties = [
        "name",
        "source",
        "output_features",
        "preprocessors",
        "postprocess_fn",
        "metric_fns",
        "metric_objs",
        "shuffle_buffer_size",
    ]
    unknown_keys = set(kwargs) - set(properties)
    if unknown_keys:
      raise ValueError(
          "Expected keys of kwargs argument task.replace to be one of"
          f" {properties}. However, there were keys in kwargs that are not in"
          f" this set: {unknown_keys}"
      )
    task_kwargs = dict(kwargs)
    for key in properties:
      if key in kwargs:
        continue
      if key == "postprocess_fn":
        task_kwargs["postprocess_fn"] = self.postprocessor
      elif key == "preprocessors":
        # Use the raw constructor arguments rather than `self.preprocessors` so
        # that any extra behavior added to that property (currently
        # validation) is not triggered here.
        task_kwargs["preprocessors"] = self._preprocessor_constructor_args
      elif key == "metric_fns":
        task_kwargs["metric_fns"] = self._metric_fn_constructor_args
      elif key == "metric_objs":
        task_kwargs["metric_objs"] = self._metric_objs_constructor_args
      else:
        task_kwargs[key] = getattr(self, key)
    return Task(**task_kwargs)

  def num_input_examples(self, split: str) -> Optional[int]:  # pytype: disable=signature-mismatch  # overriding-return-type-checks
    return self.source.num_input_examples(split)

  def _preprocess_dataset(
      self,
      dataset: tf.data.Dataset,
      preprocessors: Sequence[Callable[..., tf.data.Dataset]],
      sequence_length: Optional[Mapping[str, int]] = None,
  ) -> tf.data.Dataset:
    """Sequentially applies preprocessors."""
    for prep_fn in preprocessors:
      # Only pass the kwargs the preprocessor accepts (assumed behavior of
      # `utils.add_kwargs_to_transform`).
      prep_fn = utils.add_kwargs_to_transform(
          prep_fn,
          sequence_length=sequence_length,
          output_features=self.output_features,
      )
      dataset = prep_fn(dataset)
    return dataset

  def _validate_preprocessing(
      self, dataset: tf.data.Dataset
  ) -> tf.data.Dataset:
    """Validates preprocessed dataset, raising Exceptions if needed.

    Args:
      dataset: a tf.data.Dataset to validate.

    Returns:
      a validated tf.data.Dataset.

    Raises:
      ValueError: if a required output feature is missing, or a present
        feature has the wrong dtype or rank.
    """
    actual_specs = dataset.element_spec
    for feat, feat_spec in self.output_features.items():
      if feat not in actual_specs:
        if feat_spec.required:
          raise ValueError(
              "Task dataset is missing expected output feature after "
              f"preprocessing: {feat}"
          )
        else:
          # It's ok that this feature does not exist.
          continue
      actual_spec = actual_specs[feat]
      if feat_spec.dtype != actual_spec.dtype:
        raise ValueError(
            f"Task dataset has incorrect type for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.dtype.name}, expected "
            f"{feat_spec.dtype.name}"
        )
      if feat_spec.rank != actual_spec.shape.rank:
        raise ValueError(
            f"Task dataset has incorrect rank for feature '{feat}' after "
            f"preprocessing: Got {actual_spec.shape.rank}, expected "
            f"{feat_spec.rank}"
        )

    return dataset

  def _trim_output_features(
      self,
      dataset: tf.data.Dataset,
      sequence_length: Optional[Mapping[str, Union[int, Sequence[int]]]],
  ) -> tf.data.Dataset:
    """Trim output features to sequence length."""
    return utils.trim_dataset(dataset, sequence_length, self.output_features)

  def preprocess_precache(
      self, dataset: tf.data.Dataset, seed: Optional[int] = None
  ) -> tf.data.Dataset:
    """Runs preprocessing steps before the optional CacheDatasetPlaceholder."""
    if not self.supports_caching:
      return dataset

    with utils.map_seed_manager(seed):
      return self._preprocess_dataset(
          dataset,
          self.preprocessors[: self._cache_step_idx],
      )

  def preprocess_postcache(
      self,
      dataset: tf.data.Dataset,
      sequence_length: Optional[Mapping[str, int]],
      seed: Optional[int] = None,
  ) -> tf.data.Dataset:
    """Runs preprocessing steps after the optional CacheDatasetPlaceholder.

    Args:
      dataset: a tf.data.Dataset
      sequence_length: dict mapping feature key to int length for that feature.
        If None, the features will not be truncated.
      seed: an optional random seed for deterministic preprocessing.

    Returns:
      a tf.data.Dataset
    """
    start_idx = 0
    if self.supports_caching:
      # Skip a sufficient number of seeds to avoid duplicating any from
      # pre-cache preprocessing.
      seed = None if seed is None else seed + 42 * self._cache_step_idx
      start_idx = self._cache_step_idx + 1
    with utils.map_seed_manager(seed):
      dataset = self._preprocess_dataset(
          dataset,
          self.preprocessors[start_idx:],
          sequence_length=sequence_length,
      )
    return dataset

  @property
  def cache_dir(self) -> Optional[str]:
    """Returns the cache directory (or None), initializing if needed."""
    if not self._cache_dir:
      # See if cached data exists in any of the cache directories.
      potential_cache_dirs = [
          os.path.join(d, utils.get_task_dir_from_name(self.name))
          for d in utils.get_global_cache_dirs()
      ]
      for cache_dir in potential_cache_dirs:
        try:
          if tf.io.gfile.exists(os.path.join(cache_dir, "COMPLETED")):
            self._cache_dir = cache_dir
            logging.info("'%s' is cached at %s.", self.name, self._cache_dir)
            break
        except tf.errors.PermissionDeniedError:
          logging.warning(
              "Task %s: Permission denied for global cache folder: %s",
              self.name,
              cache_dir,
          )
        except tf.errors.FailedPreconditionError as e:
          logging.warning(
              (
                  "Task %s: Failed precondition for global cache folder: "
                  "%s with %r"
              ),
              self.name,
              cache_dir,
              e,
          )

      if not self._cache_dir:
        logging.info(
            "'%s' does not exist in any task cache directories (searched %s).",
            self.name,
            potential_cache_dirs,
        )
      logging.info(
          "Using cache directory %s for '%s'.", self._cache_dir, self.name
      )
    return self._cache_dir

  @property
  def supports_caching(self) -> bool:
    """Whether or not this task supports offline caching."""
    return self._cache_step_idx is not None

  @property
  def requires_caching(self) -> bool:
    """Whether or not this task requires offline caching."""
    return (
        self._cache_dataset_placerholder is not None
        and self._cache_dataset_placerholder.required
    )

  def assert_cached(self) -> None:
    """Raises an assertion error if cached dataset does not exist."""
    assert (
        self.cache_dir
    ), f"'{self.name}' does not exist in any of the task cache directories."

  def get_cached_stats(
      self, split: str = tfds.Split.TRAIN
  ) -> Mapping[str, Union[int, float]]:
    """Returns basic statistics for cached dataset, loading them on first use.

    Raises:
      ValueError: if no stats file exists for `split` in the cache directory.
    """
    self.assert_cached()
    if split not in self._stats:
      stats_path = utils.get_cached_stats_path(self.cache_dir, split)
      if not tf.io.gfile.exists(stats_path):
        raise ValueError(
            "Stats do not exist for '%s' split: %s" % (self.name, split)
        )
      with tf.io.gfile.GFile(stats_path) as f:
        self._stats[split] = json.load(f)
    return self._stats[split]

  def get_dataset(
      self,  # pytype: disable=signature-mismatch  # overriding-default-value-checks
      sequence_length: Optional[Mapping[str, int]] = None,
      split: str = tfds.Split.TRAIN,
      use_cached: bool = False,
      shuffle: bool = True,
      shuffle_buffer_size: Optional[int] = None,  # Unique to Task
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      num_epochs: Optional[int] = 1,
      trim_output_features: bool = True,  # Unique to Task
      try_in_mem_cache: bool = True,
  ) -> tf.data.Dataset:
    """Returns a tf.data.Dataset from cache or generated on the fly.

    Args:
      sequence_length: dict mapping feature key to maximum int length for that
        feature. If longer after preprocessing, the feature will be truncated.
        May be set to None to avoid truncation.
      split: string, the split to return.
      use_cached: bool, whether to use the cached dataset instead of processing
        it on the fly. Defaults to False.
      shuffle: bool, whether to shuffle the dataset. Only used when generating
        on the fly (use_cached=False).
      shuffle_buffer_size: an integer or None to use task-specific buffer size.
      seed: tf.int64 scalar tf.Tensor (or None) for shuffling tf.data.
      shard_info: optional specification for loading a shard of the split. If
        the Task's DataSource contains at least the number of shards in the
        specification, it will be passed the shard info to avoid loading the
        full source dataset. Otherwise, the full source dataset will be loaded
        and sharded at the individual examples.
      num_epochs: the number of times to iterate through the dataset, or `None`
        to repeat indefinitely. Note that the repeat occurs in the pipeline
        after offline caching, but before applying potentially stochastic
        post-cache preprocessors and is therefore typically preferred to calling
        `repeat()` on the returned dataset. Defaults to `1`.
      trim_output_features: If True, it trims output features to be less than
        the length given by `sequence_length`.
      try_in_mem_cache: If True, caches sufficiently small datasets in memory
        for efficiency.

    Returns:
      A tf.data.Dataset.

    Raises:
      ValueError: if the task requires caching but `use_cached` is False, or if
        shuffling is requested but disallowed for this task.
    """
    if use_cached and not self.supports_caching:
      logging.warning(
          (
              "Task '%s' does not support caching. Switching to on-the-fly "
              "preprocessing."
          ),
          self.name,
      )
      use_cached = False
    elif self.requires_caching and not use_cached:
      raise ValueError(
          f"Task '{self.name}' requires caching, but was called with "
          "`use_cached=False`."
      )

    if use_cached:
      source = self._get_cached_source(split)
    else:
      source = self.source

    if source.supports_arbitrary_sharding:
      shard_data_source = True
    elif shard_info:
      # Whether we should shard at source or on the examples from the source.
      shard_data_source = (
          len(source.list_shards(split=split)) >= shard_info.num_shards
      )
      logging.info(
          "Sharding at the %s: %d of %d",
          "data source" if shard_data_source else "examples",
          shard_info.index + 1,
          shard_info.num_shards,
      )
    else:
      # Call get_dataset on the source without a shard_info.
      shard_data_source = True
      shard_info = None

    if shard_data_source:
      ds = source.get_dataset(
          split=split, shuffle=shuffle, seed=seed, shard_info=shard_info
      )
    else:
      ds = source.get_dataset(split=split, shuffle=shuffle, seed=seed)
      ds = ds.shard(shard_info.num_shards, shard_info.index)

    num_shards = shard_info.num_shards if shard_info else 1
    if try_in_mem_cache:
      max_mem_cache_examples = _MAX_EXAMPLES_TO_MEM_CACHE * num_shards
      if (
          use_cached
          and self.get_cached_stats(split)["examples"] < max_mem_cache_examples
      ):
        is_small = True
      else:
        # Query the source once; the count may be None/0 when unknown.
        num_examples = source.num_input_examples(split)
        is_small = bool(num_examples) and num_examples < max_mem_cache_examples
      if is_small:
        logging.info(
            "Automatically caching small dataset in memory: '%s:%s'",
            self.name,
            split,
        )
        ds = ds.cache()

    if not use_cached:
      ds = self.preprocess_precache(ds, seed=seed)

    # We repeat before calling any (potentially) stochastic post-cache
    # preprocessing in order to take new samples each epoch.
    if num_epochs != 1:
      ds = ds.repeat(num_epochs)

    # Post cache processing.
    ds = self.preprocess_postcache(
        ds, sequence_length=sequence_length, seed=seed
    )
    ds = self._validate_preprocessing(ds)
    if trim_output_features:
      ds = self._trim_output_features(ds, sequence_length=sequence_length)
    if shuffle:
      if self._shuffle_buffer_size is None:
        if not self.source.performs_internal_shuffling:
          raise ValueError(
              f"Shuffling is disallowed for Task '{self.name}' since its "
              "`shuffle_buffer_size` was set to `None` on construction."
          )
      else:
        shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size
        # Shuffle before mixing since preprocessor can output multiple
        # (correlated) examples per input.
        ds = ds.shuffle(shuffle_buffer_size, seed=seed)

    return ds.prefetch(tf.data.experimental.AUTOTUNE)

  def _get_cached_source(
      self, split: str, file_shuffle_buffer_size: Optional[int] = None
  ) -> _CachedDataSource:
    """Returns a DataSource to read cached files for split."""
    self.assert_cached()
    # Fall back to the placeholder's setting when no explicit size is given.
    file_shuffle_buffer_size = (
        file_shuffle_buffer_size
        or self._cache_dataset_placerholder.file_shuffle_buffer_size
    )
    return _CachedDataSource(
        cache_dir=self.cache_dir,
        split=split,
        file_shuffle_buffer_size=file_shuffle_buffer_size,
    )

  def postprocess_fn(
      self, decoded_model_output: Any, **postprocess_kwargs
  ) -> Any:
    """Returns the model output after applying the postprocess function."""
    if self._postprocess_fn:
      return self._postprocess_fn(decoded_model_output, **postprocess_kwargs)
    return decoded_model_output
class TaskRegistry(DatasetProviderRegistry):
  """Registry whose providers are `Task` objects."""

  _REGISTRY = {}
  _PROVIDER_TYPE = Task

  # pylint: disable=arguments-renamed
  @classmethod
  def add(
      cls,
      name: str,
      source: DataSourceInterface,
      output_features: Mapping[str, Feature],
      preprocessors: Optional[Sequence[Callable[..., tf.data.Dataset]]] = None,
      postprocess_fn: Optional[Callable[..., Any]] = None,
      metric_fns: Optional[Sequence[MetricFnCallable]] = None,
      metric_objs: Optional[Sequence[clu.metrics.Metric]] = None,
      task_cls: Type[Task] = Task,
      source_info: Optional[SourceInfo] = None,
      **kwargs,
  ) -> Task:
    """See `Task` constructor for docstring.

    Collects the `Task` keyword arguments (plus any extra `kwargs`) and
    delegates to `DatasetProviderRegistry.add` with `task_cls` as the
    provider class.
    """
    task_constructor_kwargs = dict(
        name=name,
        source=source,
        output_features=output_features,
        preprocessors=preprocessors,
        postprocess_fn=postprocess_fn,
        metric_fns=metric_fns,
        metric_objs=metric_objs,
        source_info=source_info,
        **kwargs,
    )
    return super().add(
        name,
        provider_cls=task_cls,
        provider_kwargs=task_constructor_kwargs,
    )

  # pylint: enable=arguments-renamed

  @classmethod
  def get(cls, name) -> Task:
    """Looks up `name` via `DatasetProviderRegistry.get`."""
    return super().get(name)
# ================================ Mixtures ==================================== SampleFn = Callable[ [Sequence[tf.data.Dataset], Sequence[float], Optional[int]], tf.data.Dataset ] MixtureRate = Union[int, float, Callable[[Union[Task, "Mixture"]], float]] SubtaskOrName = Union[Task, "Mixture", str]
class Mixture(DatasetProviderBase):
  """Class for mixing multiple tasks."""

  def __init__(
      self,
      name: str,
      tasks: Union[
          Sequence[SubtaskOrName], Sequence[Tuple[SubtaskOrName, MixtureRate]]
      ],
      default_rate: Optional[MixtureRate] = None,
      sample_fn: SampleFn = functools.partial(
          tf.data.Dataset.sample_from_datasets, stop_on_empty_dataset=True
      ),
      source_info: Optional[SourceInfo] = None,
  ):
    """Mixture constructor.

    A mixture specifies a set of tasks with associated mixing rates.

    Mixing happens on preprocessed tokenized examples.

    The mixing rates represent relative numbers of examples to use from their
    associated tasks.  Setting the mixing rates to be equal to the numbers of
    examples in the tasks will result in each task going through an epoch in
    about the same amount of time - i.e. all examples are sampled equally across
    all tasks.

    Rates can be expressed either as absolute numbers or as functions that
    receive the Task as an argument.

    Args:
      name: string, a unique name for the Mixture.
      tasks: a list where each element is either a Task/Mixture or string
        (task/mixture name) or a pair whose first element is the Task/Mixture or
        name and whose second element is either a float (rate) or a function
        from Task to float.
      default_rate: a float or a function from Task to float. This specifies the
        default rate if rates are not provided in the `tasks` argument.
      sample_fn: SampleFn callable that implements sampling logic to interleave
        multiple datasets into a single dataset.
      source_info: optional metadata about where this `Mixture` was defined.

    Raises:
      ValueError: if a bare task is given without `default_rate`, if the
        mixture resolves to no Tasks, or if its Tasks have different output
        feature keys.
    """
    # Keyed by the direct member's name (a Task or a sub-Mixture).
    self._task_to_rate = {}
    self._task_map = {}
    self._tasks = []
    self._sub_mixtures = []
    self._name = name
    self._sample_fn = sample_fn
    self._source_info = source_info
    for t in tasks:
      if isinstance(t, (str, Task, Mixture)):
        task_or_name = t
        rate = default_rate
        if default_rate is None:
          raise ValueError("need a rate for each task")
      else:
        task_or_name, rate = t

      if isinstance(task_or_name, str):
        task_name: str = task_or_name
        is_task = task_name in TaskRegistry.names()
        subtask = (
            TaskRegistry.get(task_name)
            if is_task
            else MixtureRegistry.get(task_name)
        )
      else:
        subtask = task_or_name
        task_name = subtask.name
        is_task = isinstance(subtask, Task)
      if is_task:
        self._tasks.append(subtask)
      else:
        self._sub_mixtures.append(subtask)
      self._task_to_rate[task_name] = rate
      self._task_map[task_name] = subtask

    if not self.tasks:
      raise ValueError(f"Mixture, {self.name}, does not contain any Tasks.")

    if len(set(tuple(t.output_features) for t in self.tasks)) != 1:
      task_name_outputs = "\n".join(
          [t.name + ": " + str(tuple(t.output_features)) for t in self.tasks]
      )
      raise ValueError(
          f"Mixture, '{self.name}' contains Tasks with different output "
          f"features:\n {task_name_outputs}"
      )

  @property
  def name(self) -> str:
    return self._name

  @property
  def source_info(self) -> Optional[SourceInfo]:
    return self._source_info

  @property
  def tasks(self) -> list[Task]:
    """All Tasks in this mixture, recursing into sub-mixtures.

    Duplicates are removed and the result is sorted by task name.
    """
    # Accumulate into a set instead of `sum(lists, start)`, which re-copies
    # the growing list for every sub-mixture.
    unique_tasks = set(self._tasks)
    for mix in self._sub_mixtures:
      unique_tasks.update(mix.tasks)
    return sorted(unique_tasks, key=operator.attrgetter("name"))

  @property
  def total_rate(self) -> float:
    """Sum of the rates of the direct members (tasks and sub-mixtures)."""
    return sum(
        float(rate(self._task_map[name]) if callable(rate) else rate)
        for name, rate in self.rate_per_task_name.items()
    )

  @property
  def rate_per_task_name(self) -> Mapping[str, MixtureRate]:
    """Returns the rate for each task.

    Note that sub-mixtures are included as tasks and that the tasks part of
    these sub-mixtures are not in the mapping.
    """
    return self._task_to_rate

  def get_rate(self, task: Task) -> float:
    """Computes the mixing rate for the given task.

    Contributions from every sub-mixture containing `task` (scaled by that
    sub-mixture's share of its own total rate) are added to any rate given
    directly for `task` in this mixture.
    """
    value = 0.0

    for mix in self._sub_mixtures:
      if task in mix.tasks:
        rate = self._get_submixture_rate(mix)
        value += rate * mix.get_rate(task) / mix.total_rate

    if task.name in self.rate_per_task_name:
      rate = self.rate_per_task_name[task.name]
      value += float(rate(task) if callable(rate) else rate)

    return value

  def _get_submixture_rate(self, mix: "Mixture") -> float:
    """Returns the rate for a sub mixture by name."""
    rate = self.rate_per_task_name[mix.name]
    if not isinstance(rate, numbers.Number):
      raise ValueError(
          f"'rate' for sub-mixture {repr(mix.name)} must be a number."
      )
    return float(rate)

  def num_input_examples(self, split: str) -> int:
    """Total input examples for `split` across tasks that provide it."""
    return sum(
        t.num_input_examples(split) for t in self.tasks if split in t.splits
    )

  @property
  def splits(self) -> Sequence[str]:
    """Union of the splits of all tasks (order unspecified)."""
    splits = set()
    for task in self.tasks:
      splits.update(task.splits)
    return tuple(splits)

  @property
  def output_features(self) -> Mapping[str, Feature]:
    # We require all tasks to have the same output_features in __init__
    # so we can just get the output_features for the 0th task
    return self.tasks[0].output_features

  def _check_compatible_features(self) -> None:
    """Throw Exception if features across tasks have different vocabs or dtypes."""
    for name, feature in self.tasks[0].output_features.items():
      for task in self.tasks[1:]:
        task_feature = task.output_features[name]
        if (
            hasattr(feature, "vocabulary")
            and task_feature.vocabulary != feature.vocabulary
        ):
          raise ValueError(
              "Features across tasks in a mixture must use the same "
              f"vocabulary. Got {task_feature.vocabulary} for feature "
              f"'{name}' in task '{task}', expected {feature.vocabulary}."
          )
        if task_feature.dtype != feature.dtype:
          raise ValueError(
              "Features across tasks in a mixture must use the same dtype. "
              f"Got {task_feature.dtype} for feature '{name}' in task "
              f"'{task}', expected {feature.dtype}."
          )

  def get_task_dataset(
      self,
      task: Task,
      output_feature_keys: Set[str],
      sequence_length: Optional[Mapping[str, int]] = None,
      split: str = tfds.Split.TRAIN,
      use_cached: bool = False,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      num_epochs: Optional[int] = None,
      trim_output_features: bool = True,
      try_in_mem_cache: bool = True,
  ) -> tf.data.Dataset:
    """Returns `task`'s dataset with only `output_feature_keys` kept.

    All other arguments are forwarded unchanged to `Task.get_dataset`; each
    example is then filtered to the keys in `output_feature_keys`.
    """

    def filter_features(ex):
      return {k: v for k, v in ex.items() if k in output_feature_keys}

    return task.get_dataset(
        sequence_length=sequence_length,
        split=split,
        use_cached=use_cached,
        shuffle=shuffle,
        seed=seed,
        shard_info=shard_info,
        num_epochs=num_epochs,
        trim_output_features=trim_output_features,
        try_in_mem_cache=try_in_mem_cache,
    ).map(filter_features, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  def _get_all_mixing_rates(self, tasks):
    return [self.get_rate(task) for task in tasks]

  def get_dataset(  # pytype: disable=signature-mismatch  # overriding-parameter-type-checks
      self,
      sequence_length: Optional[Mapping[str, int]] = None,
      split: str = tfds.Split.TRAIN,
      use_cached: bool = False,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      num_epochs: Optional[int] = None,  # Unique default for Mixture
      copy_pretokenized: bool = False,
      # Unique (and all below) to Mixture
      compute_stats_empirically: bool = False,
      log_mixing_proportions: bool = True,
      passthrough_features: Optional[Sequence[str]] = None,
      trim_output_features: bool = True,
      try_in_mem_cache: bool = True,
  ) -> tf.data.Dataset:
    """Returns the dataset of mixed tasks using the object-specified rates.

    Args:
      sequence_length: dict mapping feature key to maximum int length for that
        feature. If longer after preprocessing, the feature will be truncated.
        May be set to None to avoid truncation.
      split: string, the split to return for all tasks.
      use_cached: bool, whether to use the cached dataset instead of processing
        it on the fly. This will be passed to the underlying Tasks in the
        Mixture. Defaults to False.
      shuffle: bool, whether to shuffle the dataset.  Only used when generating
        on the fly (use_cached=False).
      seed: tf.int64 scalar tf.Tensor (or None) for shuffling tf.data.
      shard_info: optional specification for loading a shard of the split.
      num_epochs: the number of times to iterate through the dataset, or `None`
        to repeat indefinitely. Note that the repeat occurs in the pipeline
        after offline caching, but before applying potentially stochastic
        post-cache preprocessors and is therefore typically preferred to calling
        `repeat()` on the returned dataset. Defaults to `None`.
      copy_pretokenized: bool, whether to pass through copies of pretokenized
        features with a "_pretokenized" suffix added to the key.
      compute_stats_empirically: a boolean - does not work on TPU
      log_mixing_proportions: whether to log the mixing proportions of the tasks
      passthrough_features: a list of additional features that will be kept
        after the feature filtering. If set to be None, then only the
        output_features defined for the mixture will be kept.
      trim_output_features: If True, it trims output features to be less than
        the length given by `sequence_length`.
      try_in_mem_cache: If True, caches sufficiently small datasets in memory
        for efficiency.

    Returns:
      The dataset produced by this mixture's `sample_fn` over the per-task
      datasets.

    Raises:
      ValueError: if no task in the mixture has `split`.
    """
    self._check_compatible_features()
    tasks = []
    for task in self.tasks:
      if split not in task.splits:
        logging.warning(
            "Task %s has no '%s' split, skipping.", task.name, split
        )
        continue
      tasks.append(task)
    if not tasks:
      raise ValueError("No datasets have a '{}' split".format(split))

    output_feature_keys = set(self.output_features.keys())
    if copy_pretokenized:
      output_feature_keys.update(
          {f + "_pretokenized" for f in output_feature_keys}
      )

    if passthrough_features:
      output_feature_keys.update(passthrough_features)

    datasets: List[tf.data.Dataset] = []
    for task in tasks:
      try:
        ds = self.get_task_dataset(
            task,
            output_feature_keys,
            sequence_length,
            split,
            use_cached,
            shuffle,
            seed,
            shard_info,
            num_epochs,
            trim_output_features,
            try_in_mem_cache,
        )
        datasets.append(ds)
      except Exception:
        logging.error(
            "Failed to load task '%s' as part of mixture '%s'",
            task.name,
            self.name,
        )
        # Re-raise the same exception, same stack-trace.
        raise

    rates = self._get_all_mixing_rates(tasks)
    # Sample from the datasets with the computed rates. An explicit seed wins;
    # otherwise shuffled runs are unseeded and unshuffled runs use a fixed seed
    # so their mixing order is reproducible.
    if seed is not None:
      sample_seed = seed
    elif shuffle:
      sample_seed = None
    else:
      sample_seed = 42
    dataset = self._sample_fn(datasets, rates, sample_seed)
    if (
        log_mixing_proportions
        and not isinstance(rates, tf.data.Dataset)
        and split == "train"
        and use_cached
        and all(t.supports_caching for t in tasks)
    ):
      _log_mixing_proportions(
          tasks,
          datasets,
          rates,
          dataset,
          sequence_length,
          compute_stats_empirically,
      )
    return dataset
class PyGloveTunableMixture(Mixture):
  """Mixture whose task rates can be tuned by PyGlove."""

  def __init__(
      self,
      name: str,
      tasks: Union[
          Sequence[SubtaskOrName], Sequence[Tuple[SubtaskOrName, MixtureRate]]
      ],
      default_rate: Optional[MixtureRate] = None,
      sample_fn: SampleFn = functools.partial(
          tf.data.Dataset.sample_from_datasets, stop_on_empty_dataset=True
      ),
      source_info: Optional[SourceInfo] = None,
  ):
    """Creates the mixture, turning PyGlove hyper-primitive rates into rate fns.

    Arguments are the same as for `Mixture`. Any `(subtask, rate)` pair whose
    rate is a `pg.hyper.HyperPrimitive` is replaced by a callable that rebuilds
    that hyper primitive, naming it after the subtask when it has no name.
    """

    def hyper_ratio(task_name: str, hyper):
      """Function for converting PyGlove hyper primitive as ratio fn."""

      def ratio_fn(unused_task):
        hyper_kwargs = dict(hyper.sym_init_args)
        if "name" not in hyper_kwargs or hyper_kwargs["name"] is None:
          hyper_kwargs["name"] = task_name
        return hyper.__class__(**hyper_kwargs)

      return ratio_fn

    converted_tasks = []
    for t in tasks:
      if isinstance(t, (list, tuple)) and isinstance(
          t[1], pg.hyper.HyperPrimitive
      ):
        subtask_or_name = t[0]
        # Hyper primitives are named by string, so resolve Task/Mixture
        # objects to their names rather than using the object itself.
        subtask_name = (
            subtask_or_name
            if isinstance(subtask_or_name, str)
            else subtask_or_name.name
        )
        t = (subtask_or_name, hyper_ratio(subtask_name, t[1]))
      converted_tasks.append(t)

    super().__init__(
        name=name,
        tasks=converted_tasks,
        default_rate=default_rate,
        sample_fn=sample_fn,
        source_info=source_info,
    )

  def _get_submixture_rate(self, mix: "Mixture") -> float:
    """Overrides this method to make submixture ratio tunable."""
    rate = self.rate_per_task_name[mix.name]
    if callable(rate):
      rate = rate(mix)
    return float(rate)
def _log_padding_fractions(dataset, sequence_length, num_examples=100):
  """Empirically compute the fraction of padding and log the results.

  For each key of `sequence_length`, an example's padding fraction is
  `1 - len(ex[key]) / sequence_length[key]`: the share of the target length
  not covered by the feature. The logged value is the mean over the examples
  seen among the first `num_examples` of `dataset`. This assumes features are
  not already padded to `sequence_length` (TODO confirm for all callers).
  Features longer than the target length yield negative fractions.

  Args:
    dataset: a tf.data.Dataset
    sequence_length: dict from string to int (packed lengths). If None or
      empty, nothing is computed. Keys with a falsy length are skipped.
    num_examples: an integer, the maximum number of examples to inspect.
  """
  logging.info("computing padding fractions")
  if not sequence_length:
    return
  # A zero/None target length cannot be used as a denominator.
  keys = [k for k, length in sequence_length.items() if length]
  padding_sum = {k: 0.0 for k in keys}
  counts = {k: 0 for k in keys}
  for ex in tfds.as_numpy(dataset.take(num_examples)):
    for k in keys:
      # Some tasks, like LMs, don't have inputs.
      if k not in ex:
        continue
      padding_sum[k] += 1 - len(ex[k]) / sequence_length[k]
      counts[k] += 1
  for k in keys:
    if counts[k]:
      logging.info("%s padding fraction = %g", k, padding_sum[k] / counts[k])
    else:
      logging.info("%s padding fraction = n/a (no examples)", k)


def _log_mixing_proportions(
    tasks,
    datasets,
    rates,
    mixed_dataset,
    sequence_length,
    compute_stats_empirically,
):
  """Log information about the mixing proportions.

  Called from Mixture.get_dataset. For each task, logs its rate and its
  normalized share of examples, inputs tokens and targets tokens, along with
  the mean inputs/targets lengths used to derive those shares.

  Args:
    tasks: a list of Task
    datasets: a list of tf.data.Dataset, parallel to `tasks`
    rates: a list of floats, parallel to `tasks`
    mixed_dataset: a tf.data.Dataset
    sequence_length: dict from string to int (packed lengths), or None
    compute_stats_empirically: a boolean - does not work on TPU. If True, mean
      lengths are measured on up to 100 examples per dataset, and padding
      fractions of `mixed_dataset` are logged. Otherwise, lengths are estimated
      from each task's cached "train" stats.
  """

  def _normalize(values):
    """Scales `values` to sum to 1; returns them unchanged if the sum is 0."""
    denom = sum(values)
    if not denom:
      return values
    return [x / denom for x in values]

  # compute some stats about the mixture
  examples_fraction = _normalize(rates)
  if compute_stats_empirically:
    stats_examples = 100
    mean_inputs_length = []
    mean_targets_length = []
    for dataset in datasets:
      inputs_sum = 0
      targets_sum = 0
      num_seen = 0
      for ex in tfds.as_numpy(dataset.take(stats_examples)):
        num_seen += 1
        # Some tasks, like LMs, don't have inputs.
        if "inputs" in ex:
          inputs_sum += ex["inputs"].size
        targets_sum += ex["targets"].size
      # Average over the examples actually seen: a dataset may contain fewer
      # than `stats_examples` examples.
      denom = float(max(num_seen, 1))
      mean_inputs_length.append(inputs_sum / denom)
      mean_targets_length.append(targets_sum / denom)
  else:

    def _estimated_mean_length(task, key):
      """Estimates the mean length of `key` for `task` from cached stats."""
      if sequence_length is None or key not in sequence_length:
        return 0
      if (
          task.supports_caching
          and task._cache_step_idx < len(task.preprocessors) - 1  # pylint:disable=protected-access
      ):
        # There is processing after caching, so we can't rely on the stats.
        return sequence_length[key]
      stats = task.get_cached_stats("train")
      # Some tasks, like LMs, don't have inputs.
      if key + "_tokens" not in stats:
        return 0
      return min(
          sequence_length[key], stats[key + "_tokens"] / stats["examples"]
      )

    mean_inputs_length = [
        _estimated_mean_length(task, "inputs") for task in tasks
    ]
    mean_targets_length = [
        _estimated_mean_length(task, "targets") for task in tasks
    ]
  inputs_fraction = _normalize(
      [l * r for l, r in zip(mean_inputs_length, rates)]
  )
  targets_fraction = _normalize(
      [l * r for l, r in zip(mean_targets_length, rates)]
  )
  logging.info(
      "%12s %12s %12s %12s %12s %12s %s",
      "rate",
      "ex.frac.",
      "inp.frac.",
      "tgt.frac.",
      "inp.len.",
      "tgt.len",
      "task",
  )
  for rate, ex_frac, inp_frac, tgt_frac, inp_len, tgt_len, task in zip(
      rates,
      examples_fraction,
      inputs_fraction,
      targets_fraction,
      mean_inputs_length,
      mean_targets_length,
      tasks,
  ):
    logging.info(
        "%12g %12g %12g %12g %12g %12g %s",
        rate,
        ex_frac,
        inp_frac,
        tgt_frac,
        inp_len,
        tgt_len,
        task.name,
    )
  if compute_stats_empirically:
    _log_padding_fractions(mixed_dataset, sequence_length)
class MixtureRegistry(DatasetProviderRegistry):
  """Registry of Mixtures.

  A thin specialization of `DatasetProviderRegistry` whose provider type is
  `Mixture`. Registration and lookup are delegated to the parent class.
  """

  # Class-level store; presumably the name -> provider map used by the parent
  # registry (TODO confirm against DatasetProviderRegistry).
  _REGISTRY = {}
  _PROVIDER_TYPE = Mixture

  # pylint: disable=arguments-renamed

  @classmethod
  def add(
      cls,
      name,
      tasks,
      default_rate=None,
      mixture_cls: Type[Mixture] = Mixture,
      source_info: Optional[SourceInfo] = None,
      **kwargs,
  ) -> Mixture:
    """See `Mixture` constructor for docstring."""
    # Constructor arguments forwarded to `mixture_cls`; extra keyword
    # arguments are appended after the named ones.
    provider_kwargs = dict(
        name=name,
        tasks=tasks,
        default_rate=default_rate,
        source_info=source_info,
    )
    provider_kwargs.update(kwargs)
    return super().add(
        name, provider_cls=mixture_cls, provider_kwargs=provider_kwargs
    )

  @classmethod
  def get(cls, name) -> Mixture:
    """Looks up `name` via the parent registry; typed to return a Mixture."""
    return super().get(name)
# pylint: enable=arguments-renamed def _get_closest_names( candidate_names: Iterable[str], target_name: str ) -> List[str]: """Order candidate names by distance to target. Args: candidate_names: a list of candidate names to be ordered target_name: target name for distance computation Returns: candidate names ordered by increasing distance to target_name. """ name_to_dist = {} for candidate_name in candidate_names: name_to_dist[candidate_name] = editdistance.eval( candidate_name, target_name ) sorted_d = sorted(name_to_dist.items(), key=operator.itemgetter(1)) return [k for (k, v) in sorted_d]
def get_mixture_or_task(task_or_mixture_name: str):
  """Return the Task or Mixture registered under the given name.

  Mixtures take precedence. If the name is registered both as a Task and as a
  Mixture, a warning is logged and the Mixture is returned.

  Args:
    task_or_mixture_name: the registered name to look up.

  Returns:
    The registered Mixture or Task.

  Raises:
    AssertionError: if `task_or_mixture_name` is not a str.
    ValueError: if no Task or Mixture has that name. Before raising, the known
      tasks (closest name first) and mixtures (alphabetically) are logged.
  """
  assert isinstance(task_or_mixture_name, str), f"Got: {task_or_mixture_name!r}"
  mixtures = MixtureRegistry.names()
  tasks = TaskRegistry.names()
  is_mixture = task_or_mixture_name in mixtures
  is_task = task_or_mixture_name in tasks

  if is_mixture:
    if is_task:
      logging.warning(
          "%s is both a Task and a Mixture, returning Mixture",
          task_or_mixture_name,
      )
    return MixtureRegistry.get(task_or_mixture_name)
  if is_task:
    return TaskRegistry.get(task_or_mixture_name)

  # Unknown name: log what is registered to help spot typos.
  logging.info("TaskRegistry has %s tasks", len(tasks))
  for candidate in _get_closest_names(tasks, task_or_mixture_name):
    logging.info(
        "Available task (starting from least distance to %s): %s",
        task_or_mixture_name,
        candidate,
    )
  for candidate in sorted(mixtures):
    logging.info("Available mixture: %s", candidate)
  raise ValueError(
      "No Task or Mixture found with name '%s'." % task_or_mixture_name
  )
def maybe_get_mixture_or_task(
    task: Union[str, Task, Mixture],
) -> Union[Task, Mixture]:
  """Given a task name, Task, or Mixture object, return an object.

  Args:
    task: a registered Task/Mixture name, or a Task or Mixture instance.

  Returns:
    The looked-up provider when `task` is a str (via `get_mixture_or_task`);
    otherwise `task` itself.

  Raises:
    ValueError: if `task` is neither a str, a Task, nor a Mixture.
  """
  if isinstance(task, str):
    return get_mixture_or_task(task)
  if isinstance(task, (Task, Mixture)):
    return task
  # Note the trailing space: the two literals are concatenated.
  raise ValueError(
      "User passed in a task that was not a string, Task, or Mixture. "
      f"Got type: {type(task)}"
  )
def get_subtasks(task_or_mixture):
  """Return the Tasks that make up `task_or_mixture`.

  Args:
    task_or_mixture: a Task, or any object exposing a `tasks` attribute (e.g.
      a Mixture).

  Returns:
    `[task_or_mixture]` for a Task; otherwise its `tasks` attribute.
  """
  is_single_task = isinstance(task_or_mixture, Task)
  return [task_or_mixture] if is_single_task else task_or_mixture.tasks
def get_dataset(
    mixture_or_task_name: Union[str, Task, Mixture],
    task_feature_lengths: Mapping[str, int],
    feature_converter: FeatureConverter,
    dataset_split: str = "train",
    use_cached: bool = False,
    shuffle: bool = False,
    num_epochs: Optional[int] = 1,
    shard_info: Optional[ShardInfo] = None,
    verbose: bool = True,
    seed: Optional[int] = None,
    batch_size: Optional[int] = None,
    trim_output_features: bool = True,
) -> tf.data.Dataset:
  """Get processed dataset with the model features.

  To use options specific to a feature converter, e.g., packing, instantiate
  `feature_converter` with those options before passing it to this function.

  Getting sharded datasets is supported. To use this feature, pass in
  `shard_info` with shard_index and num_shards information. Sharding is done
  before the feature converter stage. Therefore, if packing is used, it is
  done on the sharded dataset.

  Args:
    mixture_or_task_name: mixture or task name for the Task API, or a
      DatasetProviderBase instance (e.g. a Task or Mixture), which is used
      as-is.
    task_feature_lengths: dict mapping task feature key to its sequence length.
      This specifies the sequence length of the dataset from the Task API.
    feature_converter: a feature converter object to use to convert the task
      features to model features. Must be a subclass of FeatureConverter.
    dataset_split: the split to use.
    use_cached: whether to use the cached dataset instead of processing it on
      the fly.
    shuffle: whether to shuffle the dataset.
    num_epochs: the number of times to iterate through the dataset, or `None`
      to repeat indefinitely. Note that the repeat occurs in the pipeline after
      offline caching, but before applying potentially stochastic post-cache
      preprocessors. It is therefore typically preferred to calling `repeat()`
      on the returned dataset. Defaults to `1`.
    shard_info: number of shards and shard index information.
    verbose: if true, log the feature shapes.
    seed: a random seed for shuffling tf.data.
    batch_size: Optional batch size. If set, the converted dataset is batched
      with `drop_remainder=True`, so a final partial batch is dropped.
    trim_output_features: If True, it trims output features to be less than
      the length given by `sequence_length`.

  Returns:
    ds: the processed dataset.

  Raises:
    TypeError: if `feature_converter` is not a FeatureConverter instance.
  """
  if not isinstance(feature_converter, FeatureConverter):
    raise TypeError(
        "feature_converter should be an instance of FeatureConverter."
    )

  if isinstance(mixture_or_task_name, DatasetProviderBase):
    mixture_or_task = mixture_or_task_name
  else:
    mixture_or_task = get_mixture_or_task(mixture_or_task_name)

  ds = mixture_or_task.get_dataset(
      task_feature_lengths,
      split=dataset_split,
      use_cached=use_cached,
      shuffle=shuffle,
      seed=seed,
      shard_info=shard_info,
      num_epochs=num_epochs,
      trim_output_features=trim_output_features,
  )
  ds = feature_converter(ds, task_feature_lengths=task_feature_lengths)

  if batch_size is not None:
    ds = ds.batch(batch_size, drop_remainder=True)

  if verbose:
    logging.info(
        "The output dataset from seqio.get_dataset has the following features"
    )
    element_spec = utils.flatten_dict(ds.element_spec, delimiter=".")
    for feature_name, tensor_spec in element_spec.items():
      # Ragged and sparse specs also expose `shape` and `dtype`.
      if isinstance(
          tensor_spec,
          (tf.TensorSpec, tf.RaggedTensorSpec, tf.SparseTensorSpec),
      ):
        shape = tensor_spec.shape
        logging.info(
            "feature: %s \t shape: %s \t dtype: %s",
            feature_name,
            # `as_list()` raises on a shape of unknown rank.
            shape.as_list() if shape.rank is not None else None,
            tensor_spec.dtype.name,
        )
      else:
        logging.error(
            "Unknown tensor_spec type %s for feature %s.",
            type(tensor_spec),
            feature_name,
        )
  return ds