# Copyright 2026 The SeqIO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for data loading and processing."""
import collections
import contextlib
import dataclasses
import functools
import glob
import inspect
import os
import re
import threading
import types
from typing import Any, Callable, Dict, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union
from absl import logging
import numpy as np
from seqio.vocabularies import Vocabulary
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
# Filename templates for per-split cache artifacts. The `{split}` placeholder
# is filled in by the `get_cached_*` helpers defined later in this module.
_INFO_FILENAME = "info.{split}.json"
_STATS_FILENAME = "stats.{split}.json"
_TFRECORD_PREFIX = "{split}.tfrecord"
# Process-wide TFDS overrides; unset by default and assigned through
# `set_tfds_data_dir_override()` / `set_tfds_read_config_override()`.
_TFDS_DATA_DIR_OVERRIDE = None
_TFDS_DATA_READ_CONFIG_OVERRIDE = None
# Cache directories registered through `set_global_cache_dirs()` and
# `add_global_cache_dirs()`.
_GLOBAL_CACHE_DIRECTORIES = []
# NOTE(review): both aliases are plain `object` and are not referenced in the
# visible part of this module -- presumably placeholders; confirm usage.
_MapTransform = object
_RandomMapTransform = object
[docs]
@dataclasses.dataclass(frozen=True)
class Feature:
  """A container for attributes of output features of data providers.

  Instances are immutable (`frozen=True`).
  """

  # Vocabulary associated with the feature.
  vocabulary: Vocabulary
  # NOTE(review): presumably whether an EOS token is appended -- confirm
  # against the consumers of this field.
  add_eos: bool = True
  # NOTE(review): presumably whether the feature must be present -- confirm.
  required: bool = True
  # Declared dtype of the feature; defaults to int32.
  dtype: tf.DType = tf.int32
  # Declared rank of the feature; defaults to 1.
  rank: int = 1
def set_tfds_data_dir_override(tfds_data_dir):
  """Sets the module-wide TFDS data directory override.

  Args:
    tfds_data_dir: value stored in `_TFDS_DATA_DIR_OVERRIDE`. A falsy value
      leaves the override inactive (see `_get_data_dir_override`).
  """
  global _TFDS_DATA_DIR_OVERRIDE
  _TFDS_DATA_DIR_OVERRIDE = tfds_data_dir
def set_tfds_read_config_override(tfds_read_config):
  """Sets the module-wide TFDS read config override.

  Args:
    tfds_read_config: value stored in `_TFDS_DATA_READ_CONFIG_OVERRIDE`. When
      truthy it is returned by `LazyTfdsLoader.read_config` instead of a fresh
      `tfds.ReadConfig()`.
  """
  global _TFDS_DATA_READ_CONFIG_OVERRIDE
  _TFDS_DATA_READ_CONFIG_OVERRIDE = tfds_read_config
def get_global_cache_dirs():
  """Returns the module-level list of global cache directories.

  The list object itself is returned (not a copy).
  """
  return _GLOBAL_CACHE_DIRECTORIES
def set_global_cache_dirs(global_cache_dirs):
  """Replaces the registered global cache directories.

  Args:
    global_cache_dirs: iterable of cache directories to register. Duplicates
      are dropped by `add_global_cache_dirs`.
  """
  global _GLOBAL_CACHE_DIRECTORIES
  # Rebind to a fresh list (instead of clearing in place) so that a list
  # previously handed out by `get_global_cache_dirs()` is left untouched.
  _GLOBAL_CACHE_DIRECTORIES = []
  add_global_cache_dirs(global_cache_dirs)
def add_global_cache_dirs(global_cache_dirs):
  """Registers additional global cache directories.

  Directories that are already registered are skipped, so the registry keeps
  first-seen order without duplicates.

  Args:
    global_cache_dirs: iterable of cache directories to append.
  """
  for candidate in global_cache_dirs:
    if candidate in _GLOBAL_CACHE_DIRECTORIES:
      continue
    _GLOBAL_CACHE_DIRECTORIES.append(candidate)
def _validate_tfds_name(name: str) -> None:
  """Checks that a non-empty TFDS dataset name carries a version.

  Falsy names (e.g. `None` or the empty string) are accepted as-is.

  Args:
    name: the TFDS dataset name to check.

  Raises:
    ValueError: if `name` is non-empty and does not contain a ":".
  """
  if not name or ":" in name:
    return
  raise ValueError(f"TFDS name must contain a version number, got: {name}")
def _get_data_dir_override(tfds_name: Optional[str]) -> Optional[str]:
  """Returns the overriding data dir, or `None` when no override is active.

  Args:
    tfds_name: name of the TFDS dataset; not consulted by this implementation.

  Returns:
    `_TFDS_DATA_DIR_OVERRIDE` when it is truthy, otherwise `None`.
  """
  return _TFDS_DATA_DIR_OVERRIDE or None
[docs]
@dataclasses.dataclass(frozen=True)
class TfdsSplit:
  """Points to a specific TFDS split.

  Instances are immutable (`frozen=True`) and the dataset name is validated
  on construction.

  Attributes:
    dataset: dataset name.
    split: TFDS split (e.g. 'train'), or slice (e.g. 'train[":1%"]').
    data_dir: directory to read/write TFDS data.
  """

  dataset: str
  split: Optional[str]
  data_dir: Optional[str] = None

  def __post_init__(self):
    """Validates `dataset`; `_validate_tfds_name` raises on a bad name."""
    _validate_tfds_name(self.dataset)
[docs]
class LazyTfdsLoader(object):
  """Wrapper for TFDS datasets with memoization and additional functionality.

  Lazily loads info from TFDS and provides memoization to avoid expensive hidden
  file operations. Also provides additional utility methods.
  """

  # Builders shared by all loader instances, keyed by `_get_builder_key()`.
  _MEMOIZED_BUILDERS = {}

  def __init__(
      self,
      name: Optional[str] = None,
      data_dir: Optional[str] = None,
      split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
      decoders=None,
      builder_kwargs: Optional[dict[str, Any]] = None,
      read_only: bool = False,
  ):
    """LazyTfdsLoader constructor.

    Args:
      name: str (optional), the name of the TFDS dataset. If `name` is not
        specified then `split_map` values must be instances of `TfdsSplit`.
      data_dir: str (optional), directory to read/write TFDS data.
      split_map: dict (optional), mapping from canonical splits (e.g.,
        'validation') to TFDS splits (e.g. 'train'), or slices (e.g.,
        'train[':1%']), or `TfdsSplit` (e.g. `TfdsSplit(dataset='mnist',
        split='train')`). If `TfdsSplit` are used then `name` must be empty.
      decoders: dict (optional), mapping from features to tfds.decode.Decoders,
        such as tfds.decode.SkipDecoding() for skipping image byte decoding.
      builder_kwargs: `dict` (optional), keyword arguments to be passed to the
        `tfds.core.DatasetBuilder` constructor through `tfds.load()` and
        `tfds.builder()`.
      read_only: whether `get_dataset` can trigger the generation of a dataset.

    Raises:
      ValueError: if `name` is non-empty but has no version, or if `split_map`
        holds `TfdsSplit` values while `name` or `data_dir` is also given.
    """
    _validate_tfds_name(name)
    self._name = name
    self._data_dir = data_dir
    self._data_dir_override = None
    self._split_map = split_map
    self._decoders = decoders
    self._builder_kwargs = builder_kwargs
    self._read_only = read_only

    self._is_custom_split_map = False
    if split_map:
      # Only the first value is inspected to decide whether the map is made of
      # `TfdsSplit` instances.
      random_split_value = next(iter(split_map.values()))
      if isinstance(random_split_value, TfdsSplit):
        self._is_custom_split_map = True
        if self._name or self._data_dir:
          raise ValueError(
              "If split values are instances of `TfdsSplit`, `name` and"
              " `data_dir` must be `None`."
          )

  @property
  def name(self) -> Optional[str]:
    """The TFDS dataset name passed to the constructor, if any."""
    return self._name

  def set_decoders(self, decoders) -> None:
    """Override the decoders for this dataset."""
    self._decoders = decoders

  @property
  def tfds_splits(self) -> Optional[Mapping[str, TfdsSplit]]:
    """The split map if it is made of `TfdsSplit` values, else `None`."""
    return self._split_map if self._is_custom_split_map else None

  def resolved_tfds_name(self, split: Optional[str] = None) -> Optional[str]:
    """Returns the resolved TFDS dataset name.

    When the specified TFDS name doesn't specify everything, e.g. the version
    has a wildcard or the config is not specified, then this function returns
    the complete TFDS name if the dataset has already been loaded.

    Args:
      split: optional split name.

    Returns:
      complete TFDS name.
    """
    if self.is_memoized(split):
      return (
          self._get_builder(split)
          .get_reference()
          .tfds_name(include_version=True)
      )
    else:
      dataset, _ = self.get_split_params(split)
      return dataset

  def __str__(self):
    return (
        f"{self.__class__.__name__}(name={self.name}, data_dir={self.data_dir})"
    )

  def __repr__(self):
    return (
        f"{self.__class__.__name__}("
        f"name={self.name},"
        f" data_dir={self.data_dir},"
        f" split_map={self._split_map},"
        f" decoders={self._decoders})"
    )

  def get_split_params(
      self, split: Optional[str] = None
  ) -> Tuple[Optional[str], Optional[str]]:
    """Returns a tuple of (dataset, data_dir) for the given canonical split.

    Args:
      split: canonical split name. Required when the split map is made of
        `TfdsSplit` values; ignored otherwise.

    Raises:
      ValueError: if the split map is made of `TfdsSplit` values and `split`
        does not map to one of them.
    """
    if self._is_custom_split_map:
      if mapped_split := self._split_map.get(split):
        dataset = mapped_split.dataset
        data_dir = mapped_split.data_dir
      else:
        raise ValueError(
            "`LazyTfdsLoader` refers to multiple datasets, pass `split` value "
            "corresponding to one of them to `get_split_params()`."
        )
    else:
      dataset = self.name
      data_dir = self.data_dir
    return dataset, data_dir

  @property
  def data_dir(self) -> Optional[str]:
    """Returns the data directory for this TFDS dataset."""
    if self._is_custom_split_map:
      logging.warning(
          "`LazyTfdsLoader` refers to multiple datasets, `data_dir` is unknown."
      )
      return None

    # Resolve the module-wide override lazily and remember it on the instance.
    if self._data_dir_override is None:
      if data_dir_override := _get_data_dir_override(tfds_name=self.name):
        self._data_dir_override = data_dir_override

    if self._data_dir_override and self._data_dir:
      logging.warning(
          "Overriding TFDS data directory '%s' with '%s' for dataset '%s'.",
          self._data_dir,
          self._data_dir_override,
          self.name,
      )
    return self._data_dir_override or self._data_dir

  def override_data_dir(self, data_dir: str) -> None:
    """Forces `data_dir` to take precedence over the constructor value."""
    self._data_dir_override = data_dir

  @property
  def read_config(self):
    """The module-wide read config override, or a fresh `tfds.ReadConfig()`."""
    if _TFDS_DATA_READ_CONFIG_OVERRIDE:
      return _TFDS_DATA_READ_CONFIG_OVERRIDE
    return tfds.ReadConfig()

  def _get_builder_key(
      self, dataset: Optional[str], data_dir: Optional[str]
  ) -> Tuple[Optional[str], Optional[str]]:
    """Returns the memoization key for a (dataset, data_dir) pair."""
    return (dataset, data_dir)

  def is_memoized(self, split: Optional[str] = None) -> bool:
    """Returns true if the dataset is memoized."""
    dataset, data_dir = self.get_split_params(split)
    return (
        self._get_builder_key(dataset, data_dir)
        in LazyTfdsLoader._MEMOIZED_BUILDERS
    )

  def builder_cls(
      self, split: Optional[str] = None
  ) -> Optional[Type[tfds.core.DatasetBuilder]]:
    """Returns the DatasetBuilder class for this TFDS dataset.

    If no builder class can be found (e.g. the class has not been imported),
    then `None` will be returned.

    Args:
      split: optional split for which to return the builder class. This can be
        used in case there's a custom split map with different datasets per
        split.

    Returns:
      the builder class of the dataset in case it can be found and `None` if the
      builder class cannot be found.
    """
    dataset, _ = self.get_split_params(split)
    # `builder_cls` only accepts the dataset name, not the config and version.
    ds_name, _ = tfds.core.naming.parse_builder_name_kwargs(dataset)
    try:
      return tfds.builder_cls(str(ds_name))
    except ValueError:
      return None

  @property
  def builder(self):
    """The builder for the default (split-less) dataset of this loader."""
    return self._get_builder()

  def _get_builder(self, split: Optional[str] = None):
    """Returns the DatasetBuilder for this TFDS dataset.

    Builders are created on first use and memoized at class level.

    Args:
      split: optional canonical split, needed to pick the dataset when the
        split map is made of `TfdsSplit` values.

    Raises:
      ValueError: if no dataset name is available but `builder_kwargs` were
        provided.
    """
    dataset, data_dir = self.get_split_params(split)
    builder_key = self._get_builder_key(dataset, data_dir)
    if builder_key not in LazyTfdsLoader._MEMOIZED_BUILDERS:
      if dataset:
        builder_kwargs = self._builder_kwargs if self._builder_kwargs else {}
        builder = tfds.builder(dataset, data_dir=data_dir, **builder_kwargs)
      else:
        if self._builder_kwargs:
          raise ValueError(
              "`builder_kwargs` should be empty when `dataset` value is not"
              " present."
          )
        builder = tfds.builder_from_directory(data_dir)
      LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = builder
    return LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key]

  @property
  def info(self):
    """The `info` of the default builder (see `builder`)."""
    return self.builder.info

  def _map_split(self, split: str) -> Optional[str]:
    """Maps the given split to a dataset split."""
    if self._is_custom_split_map:
      self._split_map: Mapping[str, TfdsSplit]
      return self._split_map[split].split
    elif self._split_map:
      self._split_map: Mapping[str, str]
      return self._split_map[split]
    else:
      return split

  def files(self, split: str):
    """Returns set of instructions for reading TFDS files for the dataset.

    Args:
      split: canonical split name.

    Raises:
      ValueError: if the dataset name carries no config ("/") although the
        builder defines `BUILDER_CONFIGS`.
    """
    dataset_split = self._map_split(split)
    builder = self._get_builder(split)

    if (
        self.name is not None
        and "/" not in self.name
        and builder.BUILDER_CONFIGS
    ):
      # If builder has multiple configs, and no particular config was
      # requested, raise an error.
      raise ValueError("Dataset '%s' has multiple configs." % self.name)
    split_info = builder.info.splits[dataset_split]
    files = split_info.file_instructions

    if not files:
      logging.fatal("No TFRecord files found for dataset: %s", self.name)
    return files

  def load(
      self,
      split: Optional[str],
      shuffle_files: bool,
      seed: Optional[int] = None,
      shard_info=None,
  ):
    """Returns a tf.data.Dataset for the given split.

    Args:
      split: canonical split name, translated through the split map.
      shuffle_files: whether TFDS should shuffle the input files.
      seed: optional shuffle seed, stored on the read config.
      shard_info: optional object exposing `num_shards` and `index`; when given,
        it is converted into a `tf.distribute.InputContext`.
    """
    dataset_split = self._map_split(split)
    dataset, data_dir = self.get_split_params(split)
    # NOTE(review): when a module-wide read config override is set, the object
    # mutated below is that shared override -- confirm this is intended.
    read_config = self.read_config
    read_config.input_context = (
        tf.distribute.InputContext(  # pylint: disable=g-long-ternary
            num_input_pipelines=shard_info.num_shards,
            input_pipeline_id=shard_info.index,
        )
        if shard_info
        else None
    )
    read_config.shuffle_seed = seed
    read_config.skip_prefetch = True
    if self._read_only:
      # Resolve the builder for the requested split. The split-less `builder`
      # property cannot be used here: with a `TfdsSplit`-based split map it
      # raises because the dataset depends on the split.
      return self._get_builder(split).as_dataset(
          split=dataset_split,
          shuffle_files=shuffle_files,
          read_config=read_config,
          decoders=self._decoders,
      )
    else:
      return tfds.load(
          dataset,
          split=dataset_split,
          data_dir=data_dir,
          shuffle_files=shuffle_files,
          download=True,
          try_gcs=True,
          read_config=read_config,
          decoders=self._decoders,
          builder_kwargs=self._builder_kwargs,
      )

  def load_shard(
      self,
      file_instruction,
      shuffle_files: bool = False,
      seed: Optional[int] = None,
  ):
    """Returns a dataset for a single shard of the TFDS TFRecord files.

    Args:
      file_instruction: a single file instruction, e.g. one element of the
        value returned by `files()`.
      shuffle_files: forwarded to the TFDS reader.
      seed: optional shuffle seed for the read config.
    """
    # pytype:disable=attribute-error
    ds = self.builder._tfrecords_reader.read_files(  # pylint:disable=protected-access
        [file_instruction],
        read_config=tfds.ReadConfig(shuffle_seed=seed),
        shuffle_files=shuffle_files,
    )
    # pytype:enable=attribute-error
    return ds

  def size(self, split: str) -> Optional[int]:
    """Returns the number of examples in the split.

    Returns `np.inf` when TFDS reports a non-positive number of examples.
    """
    dataset_split = self._map_split(split)
    ds_splits = self._get_builder(split).info.splits
    dataset_size = ds_splits[dataset_split].num_examples
    # Very large datasets have num_examples = 0; default instead to np.inf
    dataset_size = dataset_size if dataset_size > 0 else np.inf
    return dataset_size
# ============================== TFExamples ====================================
# Type alias for "dictionary of tensors"
TFDict = Dict[
    str, Union[np.ndarray, tf.Tensor, tf.RaggedTensor, tf.SparseTensor]
]
# NOTE: We use short prefixes to minimize feature key overhead.
# Marks features that store the shape of another tensor.
_TFEXAMPLE_SHAPE_PREFIX = "_sh:"
# Marks features that store ragged row lengths; see
# `tfexample_ragged_length_key`.
_TFEXAMPLE_RAGGED_PREFIX = "_rl:"
# Marks features that store sparse indices; see
# `tfexample_sparse_indices_key`.
_TFEXAMPLE_SPARSE_PREFIX = "_sp:"
[docs]
def tfexample_ragged_length_key(key: str, dim: int) -> str:
  """Builds the feature key holding the `dim`-th ragged lengths of `key`.

  Use it when parsing data written by `dict_to_tfexample` with a ragged
  feature spec such as:

  >>> KEY = "ragged_tensor"
  >>> tf.io.parse_single_example(example, {
  >>>     KEY: tf.io.RaggedFeature(
  >>>         dtype,
  >>>         value_key=KEY,
  >>>         partitions=(
  >>>             tf.io.RaggedFeature.RowLengths(tfexample_ragged_length_key(KEY, 0)),
  >>>             tf.io.RaggedFeature.RowLengths(tfexample_ragged_length_key(KEY, 1)),
  >>>             ... (repeat for however many ragged dimensions the data has)
  >>>         ),)
  >>> })

  Args:
    key: The key storing the values of the ragged tensor.
    dim: The ragged dimension to generate a key for.

  Returns:
    The key of the feature storing the `dim`-th ragged lengths.
  """
  return "{}{}:{}".format(_TFEXAMPLE_RAGGED_PREFIX, dim, key)
[docs]
def tfexample_sparse_indices_key(key: str, dim: int) -> str:
  """Builds the feature key holding the `dim`-th sparse indices of `key`.

  Use it when parsing data written by `dict_to_tfexample` with a sparse
  feature spec such as:

  >>> KEY = "sparse_tensor"
  >>> tf.io.parse_single_example(example, {
  >>>     KEY: tf.io.SparseFeature(
  >>>         value_key=KEY,
  >>>         index_key=[
  >>>             tfexample_sparse_indices_key(KEY, 0),
  >>>             tfexample_sparse_indices_key(KEY, 1),
  >>>             ... (repeat for however many sparse dimensions the data has)
  >>>         ],
  >>>         size=[shape0, shape1, ...],
  >>>         dtype=dtype,
  >>>     ),
  >>> })

  Args:
    key: The key storing the values of the sparse tensor.
    dim: The sparse dimension to generate a key for.

  Returns:
    The key of the feature storing the `dim`-th sparse indices.
  """
  return "{}{}:{}".format(_TFEXAMPLE_SPARSE_PREFIX, dim, key)
def _to_tffeature(tensor: tf.Tensor) -> tf.train.Feature:
  """Wraps the flattened values of `tensor` in a typed `tf.train.Feature`.

  Only the values are kept; the shape is discarded, so callers that need it
  must store it separately.

  Args:
    tensor: A tensor to convert.

  Returns:
    A `tf.train.Feature` holding an int64, float or bytes list.

  Raises:
    ValueError: if the dtype is not bool, integer, floating or string.
  """
  dtype = tensor.dtype
  # Flatten first: `tf.train.Feature` lists are one-dimensional.
  flat = tf.reshape(tensor, [-1]).numpy().tolist()

  if dtype.is_bool or dtype.is_integer:
    return tf.train.Feature(int64_list=tf.train.Int64List(value=flat))
  if dtype.is_floating:
    return tf.train.Feature(float_list=tf.train.FloatList(value=flat))
  if dtype is tf.string:
    as_bytes = [tf.compat.as_bytes(item) for item in flat]
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=as_bytes))
  raise ValueError(f"Unsupported type {dtype}")
[docs]
def dict_to_tfexample(
    dct: TFDict, store_shapes: bool = False
) -> tf.train.Example:
  """Serializes a dictionary of tensors into a `tf.train.Example` proto.

  NOTE: tensorflow.Example can only store keyed lists of int64, float and
  bytes, with no place for the metadata of multi-dimensional, ragged or sparse
  tensors. That metadata is therefore written as extra features (ragged row
  lengths, sparse indices and, optionally, shapes); `tfexample_to_dict` uses
  them to rebuild the original dictionary, with one caveat: bools, ints and
  floats are stored as int64 / float32, so the original dtype is lost.

  Args:
    dct: A dictionary mapping feature keys to tensors (dense, ragged or
      sparse).
    store_shapes: If true, add an additional feature to store the shapes of any
      2+d or sparse tensors. This feature is only required when using
      `tfexample_to_dict`, as `tf.io.parse_single_example` requires the shape be
      provided as an argument.

  Returns:
    A `tf.train.Example` for `dct`
  """
  features = {}

  def _put_shape(name, tensor):
    features[_TFEXAMPLE_SHAPE_PREFIX + name] = _to_tffeature(
        tf.constant(tensor.shape)
    )

  for name, tensor in dct.items():
    if isinstance(tensor, tf.RaggedTensor):
      features[name] = _to_tffeature(tensor.flat_values)
      # One extra feature per ragged dimension holds its row lengths.
      for dim, row_lengths in enumerate(tensor.nested_row_lengths()):
        features[tfexample_ragged_length_key(name, dim)] = _to_tffeature(
            row_lengths
        )
      continue

    if isinstance(tensor, tf.SparseTensor):
      features[name] = _to_tffeature(tensor.values)
      # Indices are stored column-wise: one feature per sparse dimension.
      for dim, dim_indices in enumerate(tf.transpose(tensor.indices)):
        features[tfexample_sparse_indices_key(name, dim)] = _to_tffeature(
            dim_indices
        )
      if store_shapes:
        _put_shape(name, tensor)
      continue

    # Everything else is treated as a dense tensor.
    dense = tf.constant(tensor)
    if store_shapes and len(dense.shape) > 1:
      _put_shape(name, dense)
    features[name] = _to_tffeature(dense)

  return tf.train.Example(features=tf.train.Features(feature=features))
[docs]
def tfexample_to_dict(example: tf.train.Example) -> TFDict:
  """Helper function to create a dictionary of tensors from a TFExample.

  NOTE: this function is less efficient than `tf.io.parse_single_example`, but
  allows parsing of TFExample without knowing its features ahead of time. See
  the documentation of `tfexample_ragged_length_key` for an example of how to
  parse ragged tensors efficiently using `tf.io.parse_single_example`.

  Empty features carry no type information. Metadata features (ragged lengths,
  sparse indices, shapes) are always integers, so empty ones are decoded as
  int64; any other empty feature defaults to `tf.string`.

  Args:
    example: An instance of tf.train.Example we will convert into a dict.

  Returns:
    A dict with the keys of `example` and tensors as values.

  Raises:
    ValueError: if ragged lengths or sparse indices are missing for some
      dimension, or if a sparse tensor has no stored dense shape.
  """
  metadata_prefixes = (
      _TFEXAMPLE_RAGGED_PREFIX,
      _TFEXAMPLE_SPARSE_PREFIX,
      _TFEXAMPLE_SHAPE_PREFIX,
  )
  dct = {}
  shapes = {}
  ragged_lengths = collections.defaultdict(dict)
  sparse_indices = collections.defaultdict(dict)

  for key, feature in example.features.feature.items():
    if feature.int64_list.value:
      value = tf.constant(list(feature.int64_list.value), dtype=tf.int64)
    elif feature.float_list.value:
      value = tf.constant(list(feature.float_list.value), dtype=tf.float32)
    elif feature.bytes_list.value:
      value = tf.constant(list(feature.bytes_list.value), dtype=tf.string)
    elif key.startswith(metadata_prefixes):
      # Empty metadata (e.g. the indices of a sparse tensor without entries or
      # the row lengths of a ragged tensor without rows) must stay integral,
      # otherwise the tensors cannot be reassembled below.
      value = tf.constant([], dtype=tf.int64)
    else:
      # This is an empty list. We don't know what the type is, so we default
      # to tf.string
      value = tf.constant(list(feature.bytes_list.value), dtype=tf.string)

    if key.startswith(_TFEXAMPLE_RAGGED_PREFIX):
      _, dim, key = key.split(":", 2)
      ragged_lengths[key][int(dim)] = value
    elif key.startswith(_TFEXAMPLE_SPARSE_PREFIX):
      _, dim, key = key.split(":", 2)
      sparse_indices[key][int(dim)] = value
    elif key.startswith(_TFEXAMPLE_SHAPE_PREFIX):
      key = key[len(_TFEXAMPLE_SHAPE_PREFIX) :]
      shapes[key] = value
    else:
      dct[key] = value

  # Assemble RaggedTensors
  for key, length_map in ragged_lengths.items():
    flat_values = dct[key]
    nested_row_lengths = []
    # Dimensions must be contiguous starting at 0; a gap means a lengths
    # feature is missing.
    for dim in range(len(length_map)):
      if dim not in length_map:
        raise ValueError(f"Couldn't find {dim}-th ragged length for {key}")
      nested_row_lengths.append(length_map[dim])
    dct[key] = tf.RaggedTensor.from_nested_row_lengths(
        flat_values, nested_row_lengths
    )

  # Assemble SparseTensors
  for key, indices_map in sparse_indices.items():
    if key not in shapes:
      raise ValueError(f"Couldn't find dense shape for sparse tensor {key}")
    # Pop the shape so the dense reshape pass below skips this key.
    dense_shape = shapes.pop(key)
    values = dct[key]
    indices_per_dim = []
    for dim in range(len(indices_map)):
      if dim not in indices_map:
        raise ValueError(f"Couldn't find {dim}-th sparse indices for {key}")
      indices_per_dim.append(indices_map[dim])
    # Indices were stored one feature per dimension; stack and transpose back
    # to one row per entry.
    indices = tf.transpose(tf.stack(indices_per_dim))
    dct[key] = tf.SparseTensor(indices, values, dense_shape)

  # Reshape multi-dimensional tensors
  for key, shape in shapes.items():
    dct[key] = tf.reshape(dct[key], shape)

  return dct
# Type alias that supports nested TFDicts.
NestedTFDict = Dict[
    str, Union[tf.Tensor, tf.RaggedTensor, tf.SparseTensor, "NestedTFDict"]
]
# Default delimiter joining keys and subkeys when nested TFDicts are
# linearized (see `flatten_dict` and `unflatten_dict`).
_TFEXAMPLE_NESTED_DELIMITER = "/"
[docs]
def unflatten_dict(
    dct: TFDict, delimiter=_TFEXAMPLE_NESTED_DELIMITER
) -> NestedTFDict:
  """Rebuilds a nested dictionary from a flat one with delimited keys.

  A "flat" TFDict with nested keys like:

  >>> {
  >>>     "key1/subkey1": ...,
  >>>     "key1/subkey2": ...,
  >>>     "key2/subkey1": ...,
  >>> }

  becomes the nested dictionary:

  >>> {
  >>>     "key1": {
  >>>         "subkey1": ...,
  >>>         "subkey2": ...,
  >>>     },
  >>>     "key2": {
  >>>         "subkey1": ...
  >>>     },
  >>> }

  Args:
    dct: A dictionary of tensors.
    delimiter: A delimiter used to separate keys from subkeys.

  Returns:
    A nested version of `dct`.
  """
  result: NestedTFDict = {}
  for flat_key, tensor in dct.items():
    # Every component but the last names an intermediate dictionary; the last
    # one is the key the value is stored under.
    *parents, leaf = flat_key.split(delimiter)
    node = result
    for parent in parents:
      node = node.setdefault(parent, {})
    node[leaf] = tensor
  return result
[docs]
def flatten_dict(
    nested_dct: NestedTFDict, delimiter=_TFEXAMPLE_NESTED_DELIMITER
) -> TFDict:
  """Linearizes a nested dictionary into one with delimited keys.

  A nested TFDict like:

  >>> {
  >>>     "key1": {
  >>>         "subkey1": ...,
  >>>         "subkey2": ...,
  >>>     },
  >>>     "key2": {
  >>>         "subkey1": ...
  >>>     },
  >>> }

  becomes the flattened dictionary:

  >>> {
  >>>     "key1/subkey1": ...,
  >>>     "key1/subkey2": ...,
  >>>     "key2/subkey1": ...,
  >>> }

  Args:
    nested_dct: A nested dictionary of tensors.
    delimiter: A delimiter used to separate keys from subkeys.

  Returns:
    A flattened version of `nested_dct`.
  """

  def _walk(node: NestedTFDict, prefix: str):
    """Yields (flat_key, value) pairs depth-first, in insertion order."""
    for key, value in node.items():
      if isinstance(value, dict):
        yield from _walk(value, prefix + key + delimiter)
      else:
        yield prefix + key, value

  return dict(_walk(nested_dct, ""))
# ================================ Tasks =======================================
def get_cached_info_path(data_dir, split):
  """Returns the path of the cached info file for `split` in `data_dir`."""
  filename = _INFO_FILENAME.format(split=split)
  return os.path.join(data_dir, filename)
def get_cached_tfrecord_prefix(data_dir, split):
  """Returns the TFRecord path prefix for `split` in `data_dir`."""
  prefix = _TFRECORD_PREFIX.format(split=split)
  return os.path.join(data_dir, prefix)
def get_cached_stats_path(data_dir, split):
  """Returns the path of the cached stats file for `split` in `data_dir`."""
  filename = _STATS_FILENAME.format(split=split)
  return os.path.join(data_dir, filename)
def get_task_dir_from_name(task_name):
  """Turns a task name into a relative directory, one level per ":" part."""
  components = task_name.split(":")
  return os.path.join(*components)
[docs]
def stateless_shuffle(value, seed):
  """Randomly permutes all elements of a tensor, statelessly.

  The tensor is flattened, permuted as a whole and reshaped back, so elements
  may move across every dimension.

  Args:
    value: the tensor to shuffle.
    seed: seed forwarded to `tf.random.stateless_uniform`.

  Returns:
    A tensor with the shape of `value` and its elements permuted.
  """
  original_shape = tf.shape(value)
  flattened = tf.reshape(value, [-1])
  # Sorting independent uniform draws yields a random permutation.
  random_keys = tf.random.stateless_uniform(tf.shape(flattened), seed=seed)
  permutation = tf.argsort(random_keys)
  return tf.reshape(tf.gather(flattened, permutation), original_shape)
[docs]
def trim_and_pad_dataset(
    dataset: tf.data.Dataset, feature_lengths: Mapping[str, int]
) -> tf.data.Dataset:
  """Trims and zero-pads features so they match `feature_lengths`.

  Args:
    dataset: tf.data.Dataset, the dataset to trim/pad examples in.
    feature_lengths: map from feature key to final length. An int applies to
      the first dimension only; a sequence of ints gives one length per
      dimension. Other features will be returned unchanged.

  Returns:
    Trimmed/padded tf.data.Dataset.
  """

  def _trim_and_pad(k: str, t: tf.Tensor) -> tf.Tensor:
    """Trims/pads `t` to the length configured for feature `k`, if any."""
    if k not in feature_lengths:
      return t
    if isinstance(t, tf.RaggedTensor):
      t = t.to_tensor()
    target = feature_lengths[k]

    if not isinstance(target, int):
      # One target length per dimension.
      trimmed = t[tuple(slice(0, limit) for limit in target)]
      deficit = target - tf.shape(trimmed)
      # Build [[0, deficit_d], ...] so padding is only added at the end.
      paddings = tf.pad(deficit[..., None], ((0, 0), (1, 0)))
      result = tf.pad(trimmed, paddings)
      result.set_shape(target)
      return result

    # Single length: only the first dimension is trimmed and padded.
    trimmed = t[:target]
    trailing_dims = len(trimmed.shape) - 1
    paddings = [(0, target - tf.shape(trimmed)[0])] + [(0, 0)] * trailing_dims
    result = tf.pad(trimmed, paddings)
    result.set_shape([target] + trimmed.shape.as_list()[1:])
    return result

  def _process_example(example):
    return {k: _trim_and_pad(k, t) for k, t in example.items()}

  return dataset.map(
      _process_example,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
[docs]
def trim_dataset(
    dataset: tf.data.Dataset, sequence_length, output_features
) -> tf.data.Dataset:
  """Trims output features to their configured sequence length.

  Args:
    dataset: the dataset whose examples are trimmed.
    sequence_length: optional map from feature key to a maximum length (an int
      or one int per dimension). Missing or `None` entries disable trimming.
    output_features: container of feature keys eligible for trimming.

  Returns:
    A dataset with the eligible features truncated.
  """

  def _trim(k: str, v: tf.Tensor) -> tf.Tensor:
    trimmable = (
        k in output_features
        and sequence_length
        and k in sequence_length
        and sequence_length[k] is not None
    )
    if not trimmable:
      return v
    # Normalize to one limit per dimension so a slice can be built for each,
    # even when a single int was given.
    limits = sequence_length[k]
    limits = [limits] if isinstance(limits, int) else limits
    return v[tuple(slice(0, limit) for limit in limits)]

  def _trim_example(ex):
    return {k: _trim(k, v) for k, v in ex.items()}

  return dataset.map(
      _trim_example,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
  )
def _strip_packed_feature_key(key: str) -> str:
  """Removes a trailing "_positions", then a trailing "_segment_ids"."""
  return key.removesuffix("_positions").removesuffix("_segment_ids")
[docs]
def trim_and_pack_dataset(
    dataset: tf.data.Dataset,
    feature_lengths: Mapping[str, int],
    use_custom_ops: bool = False,
) -> tf.data.Dataset:
  """Creates a 'packed' version of a dataset on-the-fly.

  Modified from the tensor2tensor library.

  Packing avoids having to materialize a separate "packed" copy of a dataset
  in order to train efficiently on TPU. Each example of the output dataset
  holds several examples of the input dataset.

  For each key in the input dataset that also exists in `feature_lengths`, two
  additional keys are created:
    <key>_segment_ids: an int32 tensor identifying the parts
       representing the original example.
    <key>_positions: an int32 tensor identifying the position within the
       original example.

  Features that are not in `feature_lengths` will be removed.

  Example:
    Two input examples get combined to form an output example.
    The input examples are:
    {"inputs": [8, 7, 1, 0], "targets":[4, 1, 0], "idx": 0}
    {"inputs": [2, 3, 4, 1], "targets":[5, 6, 1], "idx": 1}
    The output example is:
    {
                   "inputs": [8, 7, 1, 2, 3, 4, 1, 0, 0, 0]
       "inputs_segment_ids": [1, 1, 1, 2, 2, 2, 2, 0, 0, 0]
         "inputs_positions": [0, 1, 2, 0, 1, 2, 3, 0, 0, 0]
                  "targets": [4, 1, 5, 6, 1, 0, 0, 0, 0, 0]
      "targets_segment_ids": [1, 1, 2, 2, 2, 0, 0, 0, 0, 0]
        "targets_positions": [0, 1, 0, 1, 2, 0, 0, 0, 0, 0]
    }

    0 represents padding in both the inputs and the outputs.

    Sequences in the incoming examples are truncated to length in
    `feature_lengths`, and the sequences in the output examples all have this
    fixed (padded) length. Features not in `features_length` (i.e, "idx") are
    removed.

  Args:
    dataset: a tf.data.Dataset
    feature_lengths: map from feature key to final length. Other features will
      be discarded.
    use_custom_ops: a boolean - custom ops are faster but require a custom-built
      binary, which is not currently possible on cloud-tpu.

  Returns:
    a tf.data.Dataset

  Raises:
    ValueError: if a key of `feature_lengths` is missing from the dataset, or
      if a feature is not one-dimensional while `use_custom_ops` is false.
  """
  spec = dataset.element_spec
  rank_one = tf.TensorShape([None])

  # Every requested feature must exist and, without custom ops, be 1-D.
  for name in feature_lengths:
    if name not in spec:
      raise ValueError(
          f"Feature '{name}' not found in dataset. Available keys are "
          f"{list(spec.keys())}"
      )
    if not use_custom_ops and not spec[name].shape.is_compatible_with(
        rank_one
    ):
      raise ValueError(
          f"Features to be packed must be one-dimensional. '{name}' is not.' "
          "Consider setting use_custom_ops if you have higher-rank features."
      )

  # Let the user know which features are about to disappear.
  dropped = set(spec) - set(feature_lengths)
  if dropped:
    logging.warning(
        "Features not in `features_length` will be removed during packing: %s",
        dropped,
    )

  def _truncate(example):
    return {
        name: example[name][:limit, ...]
        for name, limit in feature_lengths.items()
    }

  ds = dataset.map(_truncate, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  # Setting batch_size=length ensures that the concatenated sequences (if they
  # have length >=1) are sufficient to fill at least one packed example.
  batch_size = max(feature_lengths.values())
  # First dimension is padded dynamically; trailing dimensions keep their
  # static sizes.
  padded_shapes = {
      name: [-1, *spec[name].shape[1:]] for name in feature_lengths
  }
  ds = ds.padded_batch(batch_size, padded_shapes=padded_shapes)

  pack_fn = _pack_with_custom_ops if use_custom_ops else _pack_with_tf_ops
  ds = pack_fn(ds, feature_lengths)

  def _set_shape(x):
    """Restores the static shapes that get lost during packing."""
    for k, v in x.items():
      packed_length = feature_lengths[_strip_packed_feature_key(k)]
      v.set_shape([packed_length, *v.get_shape()[1:]])
    return x

  return ds.map(_set_shape, num_parallel_calls=tf.data.experimental.AUTOTUNE)
def _pack_with_tf_ops(
    dataset: tf.data.Dataset, feature_lengths: Mapping[str, int]
) -> tf.data.Dataset:
  """Helper-function for packing a dataset which has already been batched.

  See trim_and_pack_dataset()

  Uses tf.while_loop. Slow.

  Each batch is packed greedily: examples are appended to a partial packed
  example until one of the features would exceed its length, at which point
  the partial example is flushed and a new one is started.

  Args:
    dataset: a dataset containing padded batches of examples.
    feature_lengths: mapping from feature key to packed length.

  Returns:
    a dataset.
  """
  # Template of an empty partial example: for every feature, an empty int32
  # vector for the values and another for the positions.
  empty_example = {}
  for k in feature_lengths:
    for suff in ("", "_positions"):
      empty_example[k + suff] = tf.zeros([0], dtype=tf.int32)
      empty_example[k + suff].set_shape([None])
  keys_etc = empty_example.keys()

  def _write_packed_example(partial, outputs):
    """Flushes `partial`, zero-padded to full length, into `outputs`.

    Returns a fresh empty partial example and the updated TensorArrays.
    """
    new_partial = empty_example.copy()
    new_outputs = {}
    for k in keys_etc:
      new_outputs[k] = outputs[k].write(
          outputs[k].size(),
          tf.pad(
              partial[k],
              [[
                  0,
                  feature_lengths[_strip_packed_feature_key(k)]
                  - tf.size(partial[k]),
              ]],
          ),
      )
    return new_partial, new_outputs

  def pack_batch(x: Mapping[str, tf.Tensor]) -> Mapping[str, tf.Tensor]:
    """Internal function to map over.

    Consumes a batch of input examples and produces a variable number of output
    examples.

    Args:
      x: a padded batch of examples (one tensor per feature key).

    Returns:
      a dict of stacked tensors holding, for every feature, the packed values,
      their positions and their segment ids.
    """
    keys = list(feature_lengths)
    partial = empty_example.copy()
    first_key, *_ = keys
    dynamic_batch_size = tf.shape(x[first_key])[0]
    # One growable TensorArray per output key; every element has the full
    # packed length of its feature.
    outputs = {}
    for k in keys:
      outputs[k] = tf.TensorArray(
          tf.int32,
          size=0,
          dynamic_size=True,
          element_shape=[feature_lengths[k]],
      )
      outputs[k + "_positions"] = tf.TensorArray(
          tf.int32,
          size=0,
          dynamic_size=True,
          element_shape=[feature_lengths[k]],
      )

    for i in tf.range(0, dynamic_batch_size):
      # `partial` grows and `outputs` changes across iterations, so their
      # shapes are declared as loop invariants with unknown sizes.
      tf.autograph.experimental.set_loop_options(
          shape_invariants=[
              (partial, {k: tf.TensorShape([None]) for k in keys_etc}),
              (outputs, {k: tf.TensorShape(None) for k in keys_etc}),
          ]
      )

      can_append = True
      one_example = {}
      for k in keys:
        val = tf.cast(x[k][i], tf.int32)
        # Keep as many leading entries as there are non-zero values, i.e. drop
        # the batch padding (assumes zeros only occur as trailing padding).
        val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))]
        one_example[k] = val
      # The example fits only if every feature stays within its length.
      for k in keys:
        can_append = tf.logical_and(
            can_append,
            tf.less_equal(
                tf.size(partial[k]) + tf.size(one_example[k]),
                feature_lengths[k],
            ),
        )

      if not can_append:
        partial, outputs = _write_packed_example(partial, outputs)

      # Append the (truncated) example to the current partial example.
      new_partial = {}
      for k in keys:
        new_seq = one_example[k][: feature_lengths[k]]
        new_seq_len = tf.size(new_seq)
        new_partial[k] = tf.concat([partial[k], new_seq], 0)
        new_partial[k + "_positions"] = tf.concat(
            [partial[k + "_positions"], tf.range(new_seq_len, dtype=tf.int32)],
            0,
        )
      partial = new_partial

    # Flush whatever is left after the last example of the batch.
    _, outputs = _write_packed_example(partial, outputs)
    packed = {k: outputs[k].stack() for k in keys_etc}
    # A new segment starts wherever the position is 0; the cumulative count of
    # such starts gives the segment id, which is masked to 0 on padding.
    for k in keys:
      packed[k + "_segment_ids"] = tf.cumsum(
          tf.cast(tf.equal(packed[k + "_positions"], 0), tf.int32), axis=1
      ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32)
    return packed

  dataset = dataset.map(
      pack_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  # Every mapped element holds a stack of packed examples; split them apart.
  return dataset.unbatch()
def _pack_with_custom_ops(
    dataset: tf.data.Dataset, feature_lengths: Mapping[str, int]
) -> tf.data.Dataset:
  """Packs an already-batched dataset using the tensor2tensor custom ops.

  See trim_and_pack_dataset()

  Relies on custom ops which require a custom compiled binary.
  Faster than _pack_with_tf_ops(), and denser packing.

  `pack_sequences2` is used when there are one or two features and every
  feature is an integer tensor whose spec is compatible with rank 2. In all
  other cases the generic `pack_sequences_k` op is used.

  Args:
    dataset: a dataset containing padded batches of examples.
    feature_lengths: mapping from feature key to packed length.

  Returns:
    a dataset of unbatched packed examples. For every feature key `k` it holds
    `k`, `k_segment_ids` and `k_positions`.
  """
  # TODO(adarob): Move ops into this library and fix int64 issue.
  from tensor2tensor.data_generators.ops import pack_sequences_ops  # pylint: disable=g-import-not-at-top

  keys = list(feature_lengths)
  num_keys = len(keys)
  use_generic_custom_ops = num_keys not in (1, 2)
  if use_generic_custom_ops:
    logging.info(
        "`pack_sequences_2` cannot pack more than 2 features. "
        "Using `pack_sequences_k` instead."
    )
  elif num_keys == 1:
    # `pack_sequences2` always takes two inputs, so feed the single feature
    # twice; the duplicate outputs are dropped below.
    k1 = k2 = keys[0]
  else:
    k1, k2 = keys

  element_spec = dataset.element_spec
  rank2_shape = tf.TensorShape([None, None])
  for name in feature_lengths:
    spec = element_spec[name]
    if not spec.dtype.is_integer:
      use_generic_custom_ops = True
      logging.info(
          (
              "`pack_sequences_2` cannot pack non-integer feature '%s'. "
              "Using `pack_sequences_k` instead."
          ),
          name,
      )
    if not spec.shape.is_compatible_with(rank2_shape):
      use_generic_custom_ops = True
      logging.info(
          (
              "`pack_sequences_2` cannot pack higher rank feature '%s'. "
              "Using `pack_sequences_k` instead."
          ),
          name,
      )

  def custom_pack_batch(x):
    """Map-function."""
    if use_generic_custom_ops:
      # The generic op receives (and returns) features in sorted-key order.
      sorted_keys = sorted(feature_lengths.keys())
      packed, segment_ids, positions = pack_sequences_ops.pack_sequences_k(
          inputs=[x[k] for k in sorted_keys],
          max_lengths=[feature_lengths[k] for k in sorted_keys],
      )
      y = {}
      for i, k in enumerate(sorted_keys):
        y[k] = packed[i]
        y[f"{k}_segment_ids"] = segment_ids[i]
        y[f"{k}_positions"] = positions[i]
      return y

    logging.info(
        "Features are compatible with `pack_sequences_2`. "
        "Not using `pack_sequences_k`."
    )
    (
        first_packed,
        first_segment_ids,
        first_positions,
        second_packed,
        second_segment_ids,
        second_positions,
    ) = pack_sequences_ops.pack_sequences2(
        # cast to int64 for compatibility with custom ops
        tf.cast(x[k1], tf.int64),
        tf.cast(x[k2], tf.int64),
        feature_lengths[k1],
        feature_lengths[k2],
    )
    packed = {
        k1: first_packed,
        k1 + "_segment_ids": first_segment_ids,
        k1 + "_positions": first_positions,
    }
    if num_keys == 2:
      packed[k2] = second_packed
      packed[k2 + "_segment_ids"] = second_segment_ids
      packed[k2 + "_positions"] = second_positions
    # cast back to int32
    return {k: tf.cast(v, tf.int32) for k, v in packed.items()}

  packed_batches = dataset.map(
      custom_pack_batch, num_parallel_calls=tf.data.experimental.AUTOTUNE
  )
  return packed_batches.unbatch()
def _shift_right_by_one(tensor: tf.Tensor, bos_id: int = 0) -> tf.Tensor:
  """Shifts `tensor` right by one along axis 0, filling the gap with `bos_id`.

  Args:
    tensor: an integer or floating point tensor.
    bos_id: value written into the first position along axis 0.

  Returns:
    The shifted tensor. The last entry along axis 0 is dropped, not wrapped
    around to the front.

  Raises:
    ValueError: if `tensor` is neither an integer nor a floating point tensor.
  """
  dtype = tensor.dtype
  if not dtype.is_integer and not dtype.is_floating:
    raise ValueError(f"Only numeric types are supported. Got: {dtype}")

  # tf.roll wraps the last element around to index 0; it is masked out below.
  rolled = tf.roll(tensor, shift=1, axis=0)

  # keep == [0, 1, 1, ..., 1]: zero at the first position, one elsewhere.
  keep = tf.one_hot(
      0, depth=tf.shape(tensor)[0], on_value=0, off_value=1, dtype=dtype
  )
  # Append one new axis per trailing dimension so `keep` broadcasts to `rolled`.
  num_trailing_dims = len(rolled.shape) - 1
  keep = keep[[slice(None, None)] + [None] * num_trailing_dims]

  return rolled * keep + (1 - keep) * bos_id
# ========================= Mixing Rate Functions ==============================
[docs]
def mixing_rate_num_examples(
    task,
    maximum: Optional[int] = None,
    scale: float = 1.0,
    temperature: float = 1.0,
    fallback_to_num_input_examples: bool = True,
    split: str = "train",
) -> float:
  """Computes a mixing rate from the number of examples in the task's split.

  It should be noted that SeqIO only injects the task, and all other parameters
  must be provided by the user when initializing the Mixture.

  The example count is multiplied by `scale`, then clipped at `maximum`, and
  finally raised to the power `1 / temperature`.

  Args:
    task: the seqio.Task to compute a rate for.
    maximum: an optional maximum value to clip at after constant scaling but
      before temperature scaling. A falsy value (`None` or 0) disables
      clipping.
    scale: a multiplicative scaling factor applied before temperature.
    temperature: a temperature (T) to scale rate (r) by as r^(1/T).
    fallback_to_num_input_examples: whether to fallback to using the number of
      input examples when the Task is not cached. Otherwise, the cached stats
      are read regardless of whether the Task is cached.
    split: the split to look at for cached stats.

  Returns:
    The mixing rate for this task.
  """
  if not task.cache_dir and fallback_to_num_input_examples:
    logging.warning(
        (
            "Task '%s' not cached so using number of input examples instead of "
            "preprocessed examples to compute rate."
        ),
        task.name,
    )
    rate = task.num_input_examples(split)
  else:
    rate = task.get_cached_stats(split)["examples"]

  rate *= scale
  if maximum:
    rate = min(rate, maximum)
  return rate if temperature == 1.0 else rate ** (1.0 / temperature)
[docs]
def mixing_rate_num_characters(
    task, temperature: float = 1.0, char_count_name: str = "text_chars"
) -> float:
  """Mixing rate based on the number of characters for the task's 'train' split.

  Args:
    task: the seqio.Task to compute a rate for.
    temperature: a temperature (T) to scale rate (r) by as r^(1/T).
    char_count_name: feature name of the character counts in the cached stats
      file.

  Returns:
    The mixing rate for this task.

  Raises:
    ValueError: if `task.cache_dir` is None, i.e. no cached stats are
      available to read the character count from.
  """
  if task.cache_dir is None:
    raise ValueError(
        "`mixing_rate_num_characters` requires that each task is cached "
        "with the character count stats."
    )
  ret = task.get_cached_stats("train")[char_count_name]
  # A temperature of exactly 1.0 leaves the count untouched (no float pow).
  if temperature != 1.0:
    ret = ret ** (1.0 / temperature)
  return ret
# ======================== Decorators =========================================
# Next integer seed to hand out to seeded map transforms. While it is None no
# explicit seeds are supplied; otherwise each seeded transform consumes
# `2 * num_seeds` consecutive values. Set and restored by `map_seed_manager`.
_NEXT_MAP_SEED = None
[docs]
@contextlib.contextmanager
def map_seed_manager(initial_seed=None):
  """Contextmanager to control the initial seed used by `map_over_dataset`.

  On entry the module-level `_NEXT_MAP_SEED` is set to `initial_seed`. On exit
  the value it had before entry is restored, including when the managed block
  raises.

  Args:
    initial_seed: value to install as the next map seed for the duration of
      the context. `None` means no explicit seed.

  Yields:
    None.
  """
  global _NEXT_MAP_SEED
  old_map_seed = _NEXT_MAP_SEED
  _NEXT_MAP_SEED = initial_seed
  try:
    yield
  finally:
    # Restore unconditionally so an exception inside the `with` body cannot
    # leak this context's seed state into later, unrelated calls.
    _NEXT_MAP_SEED = old_map_seed
[docs]
def set_preprocessor_seed(preprocessor_fn, seed=None):
  """Sets the internal map seed for the provided preprocessor.

  Args:
    preprocessor_fn: the preprocessor to attach the seed to.
    seed: value passed on as the `seqio_internal_map_seed` keyword argument.

  Returns:
    The result of `add_kwargs_to_transform` for `preprocessor_fn` with the
    seed keyword added.
  """
  seed_kwargs = {"seqio_internal_map_seed": seed}
  return add_kwargs_to_transform(preprocessor_fn, **seed_kwargs)
# Attribute names on the Grain transform classes below that, when not None, are
# added as keyword arguments to the wrapped map function through
# `add_kwargs_to_transform`.
_SPECIAL_KWARGS = ("sequence_length", "output_features")
@dataclasses.dataclass
class _GrainRandomMapFn(_RandomMapTransform):
  """Grain Transform to represent existing SeqIO random map preprocessors.

  The wrapped `map_fn` receives its randomness through a `seed` keyword (when
  `num_seeds == 1`) or a `seeds` keyword (otherwise).
  """

  map_fn: Callable[..., Any]
  num_seeds: int
  num_parallel_calls: int = tf.data.AUTOTUNE
  # SeqIO fills these in before applying the preprocessor or before passing it
  # to Grain.
  sequence_length: Optional[Mapping[str, int]] = None
  output_features: Optional[Mapping[str, Any]] = None

  def _map_fn_with_special_kwargs(self, *args, **kwargs):
    """Handle _SPECIAL_KWARGS before calling self.map_fn()."""
    bound_kwargs = {}
    for name in _SPECIAL_KWARGS:
      value = getattr(self, name)
      if value is not None:
        bound_kwargs[name] = value
    transform = add_kwargs_to_transform(self.map_fn, **bound_kwargs)
    return transform(*args, **kwargs)

  # Path for Grain. Uses seed provided by Grain.
  def random_map(self, element, rng: tf.Tensor):
    """Maps one element, deriving `num_seeds` seeds from `rng` if needed."""
    if self.num_seeds != 1:
      split_rngs = tf.random.experimental.stateless_split(rng, self.num_seeds)
      return self._map_fn_with_special_kwargs(
          element, seeds=tf.unstack(split_rngs)
      )
    return self._map_fn_with_special_kwargs(element, seed=rng)

  # Path for SeqIO; preserves legacy logic to manage seeds and differs in
  # seed-management behavior in Grain.
  def __call__(self, dataset: tf.data.Dataset, *args, **kwargs):
    """Maps `map_fn` over `dataset`, zipping in per-element random seeds."""
    global _NEXT_MAP_SEED
    if _NEXT_MAP_SEED is not None:
      # Consume 2 * num_seeds consecutive integers: one pair per seed.
      num_values = 2 * self.num_seeds
      seed_pairs = np.arange(
          _NEXT_MAP_SEED, _NEXT_MAP_SEED + num_values
      ).reshape(-1, 2)
      random_ds_seeds = tuple(tuple(pair) for pair in seed_pairs)
      _NEXT_MAP_SEED += num_values
    else:
      random_ds_seeds = ((None, None),) * self.num_seeds

    # Each scalar seed becomes its own random dataset; the (num_seeds, 2)
    # nesting is preserved so every zipped element carries seed pairs.
    seed_datasets = tf.nest.map_structure(
        tf.data.Dataset.random, random_ds_seeds
    )

    def map_fn(element, seeds):
      if self.num_seeds == 1:
        return self._map_fn_with_special_kwargs(
            element, *args, seed=seeds[0], **kwargs
        )
      return self._map_fn_with_special_kwargs(
          element, *args, seeds=seeds, **kwargs
      )

    zipped = tf.data.Dataset.zip((dataset, seed_datasets))
    return zipped.map(map_fn, num_parallel_calls=self.num_parallel_calls)
@dataclasses.dataclass
class _GrainMapFn(_MapTransform):
  """Grain Transform to represent existing SeqIO map preprocessors.

  `map()` applies the wrapped function to a single element, while calling the
  instance applies it over a whole `tf.data.Dataset`.
  """

  map_fn: Callable[..., Any]
  num_parallel_calls: int
  # SeqIO fills these in before applying the preprocessor or before passing it
  # to Grain.
  sequence_length: Optional[Mapping[str, int]] = None
  output_features: Optional[Mapping[str, Any]] = None

  def _map_fn_with_special_kwargs(self, *args, **kwargs):
    """Handle _SPECIAL_KWARGS before calling self.map_fn()."""
    bound_kwargs = {}
    for name in _SPECIAL_KWARGS:
      value = getattr(self, name)
      if value is not None:
        bound_kwargs[name] = value
    transform = add_kwargs_to_transform(self.map_fn, **bound_kwargs)
    return transform(*args, **kwargs)

  # Path for Grain
  def map(self, element):
    """Applies the wrapped map function to a single element."""
    return self._map_fn_with_special_kwargs(element)

  # Path for SeqIO
  def __call__(
      self, dataset: tf.data.Dataset, *args, **kwargs
  ) -> tf.data.Dataset:
    """Maps the wrapped function over `dataset`, forwarding extra arguments."""

    def apply_to_element(x):
      return self._map_fn_with_special_kwargs(x, *args, **kwargs)

    return dataset.map(
        apply_to_element, num_parallel_calls=self.num_parallel_calls
    )
[docs]
def map_over_dataset(
    fn=None, *, num_seeds=None, num_parallel_calls=tf.data.experimental.AUTOTUNE
):
  """Decorator to map decorated function over dataset.

  Many preprocessors map a function over a dataset. This decorator helps reduce
  boilerplate for this common pattern. It can be used bare
  (`@map_over_dataset`) or with keyword arguments
  (`@map_over_dataset(num_seeds=1)`).

  If `num_seeds` is set to 1, a unique random seed (pair of int32) will be
  passed to the mapping function with keyword 'seed'.
  If `num_seeds` is greater than 1, unique random seeds (pairs of int32) will be
  passed to the mapping function with keyword 'seeds'.
  These seeds can be generated deterministically by using the `map_seed_manager`
  to set the seed for the process that generates the individual seeds for each
  mapping function. These seeds will be set sequentially from the initial seed
  for each call to `map_over_dataset` where `num_seeds > 0`.

  Args:
    fn: map function
    num_seeds: optional number of random seeds (pairs of int32) to pass to the
      mapping function.
    num_parallel_calls: num_parallel_calls value to pass to Dataset.map

  Returns:
    Callable transform which takes dataset as first argument. When `fn` is
    None, a decorator producing such a transform is returned instead.
  """

  def map_without_seeds(fn):
    def wrapped_fn(ds, *args, **kwargs):
      transform = _GrainMapFn(fn, num_parallel_calls)
      return transform(ds, *args, **kwargs)

    return functools.wraps(fn)(wrapped_fn)

  def map_with_seeds(fn):
    def wrapped_fn(ds, *args, **kwargs):
      transform = _GrainRandomMapFn(fn, num_seeds, num_parallel_calls)
      return transform(ds, *args, **kwargs)

    return functools.wraps(fn)(wrapped_fn)

  if num_seeds is None:
    wrapper = map_without_seeds
  else:
    wrapper = map_with_seeds

  # Called as `@map_over_dataset(...)`: hand back the decorator itself.
  if fn is None:
    return wrapper
  return wrapper(fn)
[docs]
def fully_qualified_class_name(instance: Any) -> str:
  """Returns `<module>.<class name>` for the type of the given instance."""
  cls = type(instance)
  return "{}.{}".format(cls.__module__, cls.__name__)
[docs]
def function_name(function) -> str:
  """Returns the name of a (possibly partially applied) function.

  `functools.partial` wrappers are peeled off until the underlying callable is
  reached. For a class object, the name of `function.__class__` (its
  metaclass) is returned. Any error during lookup is swallowed.

  Args:
    function: the callable to name.

  Returns:
    The resolved name, or an empty string if it could not be determined.
  """
  try:
    # functools.partial can be applied multiple times.
    while isinstance(function, functools.partial):
      function = function.func

    if inspect.isclass(function):
      # function can be a protocol.
      # NOTE(review): this yields the metaclass name (e.g. "type"), not the
      # class's own name -- confirm that this is what callers expect.
      meta = function.__class__
      return meta.__name__ if hasattr(meta, "__name__") else str(meta)

    if hasattr(function, "__name__"):
      return function.__name__
  except Exception:  # pylint: disable=broad-exception-caught
    pass
  # Function name could not be determined.
  return ""
# Locks, keyed by file pattern, to prevent multiple threads from calling
# tf.io.gfile.glob on the same file pattern at the same time. Being a
# defaultdict, a new lock is created the first time a pattern is looked up.
_LIST_SHARD_LOCKS: dict[str, threading.Lock] = collections.defaultdict(
threading.Lock
)
[docs]
def list_files(file_patterns: Union[str, Iterable[str]]) -> list[str]:
  """Returns the files matching the given file pattern(s).

  Note that shard patterns like `foo@2` are not expanded. Only glob patterns
  are expanded, e.g., `foo*`. Patterns without glob characters are passed
  through unchanged.

  Args:
    file_patterns: A string or an iterable of strings, each of which is a file
      pattern to expand.

  Returns:
    A list of file names, in the order of the given patterns. The matches of
    each individual glob pattern are sorted; the combined list is not
    re-sorted.
  """
  patterns = [file_patterns] if isinstance(file_patterns, str) else file_patterns

  files = []
  for pattern in patterns:
    if glob.has_magic(pattern):
      # Make sure only one thread is calling tf.io.gfile.glob on the same file
      # pattern at the same time. The other threads will wait for the lock to
      # be released, and then use the cached result.
      with _LIST_SHARD_LOCKS[pattern]:
        files.extend(_list_files_for_glob(pattern))
    else:
      files.append(pattern)
  return files
@functools.lru_cache(maxsize=1024)
def _list_files_for_glob(file_pattern: str) -> list[str]:
  """Returns the sorted, de-duplicated matches of a glob file pattern.

  Results are memoized per pattern (up to 1024 patterns), so repeated calls
  return the same list object without globbing again.

  Args:
    file_pattern: the glob pattern to expand with `tf.io.gfile.glob`.
  """
  matches = tf.io.gfile.glob(file_pattern)
  return sorted(set(matches))