# Copyright 2026 The SeqIO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental utilities for SeqIO."""
import functools
import inspect
from typing import Callable, Iterable, Mapping, Optional, Sequence
from absl import logging
from seqio import dataset_providers
from seqio import preprocessors as seqio_preprocessors
from seqio import utils
import tensorflow as tf
import tensorflow_datasets as tfds
# Short local aliases for the `dataset_providers` names used in this module.
CacheDatasetPlaceholder = dataset_providers.CacheDatasetPlaceholder
Mixture = dataset_providers.Mixture
MixtureRegistry = dataset_providers.MixtureRegistry
ShardInfo = dataset_providers.ShardInfo
Task = dataset_providers.Task
TaskRegistry = dataset_providers.TaskRegistry
def _enfore_empty_registries():
  """Raises if any Task or Mixture has already been registered.

  Every registered name is logged at error level before raising so that the
  offending registrations can be identified.

  Raises:
    ValueError: if `TaskRegistry` or `MixtureRegistry` reports any names.
  """
  offenders = []

  task_names = TaskRegistry.names()
  if task_names:
    logging.error('TaskRegistry has %s tasks', len(task_names))
    for registered_task in task_names:
      logging.error('Registered Task: %s', registered_task)
    offenders.append('TaskRegistry')

  mixture_names = MixtureRegistry.names()
  if mixture_names:
    logging.error('MixtureRegistry has %s mixtures', len(mixture_names))
    for registered_mixture in mixture_names:
      logging.error('Registered Mixture: %s', registered_mixture)
    offenders.append('MixtureRegistry')

  if not offenders:
    return

  non_empty_registries = ', '.join(offenders)
  raise ValueError(
      f'The {non_empty_registries} is non-empty. Please invoke'
      ' `disable_registry()` before any Task/Mixtures are registered.'
  )
def _get_argument(name, position, args, kwargs, default_val=''):
if name in kwargs:
return kwargs.get(name)
if len(args) > position:
return args[position]
return default_val
def _no_op_task_registry_add(*args, **kwargs):
  """Logs the name of a Task registration attempt and does nothing else.

  The name is read from the `name` keyword or the first positional argument.
  """
  task_name = _get_argument('name', 0, args, kwargs)
  logging.info('No-op Task register call: %s', task_name)
def _no_op_mixture_registry_add(*args, **kwargs):
  """Logs the name of a Mixture registration attempt and does nothing else.

  The name is read from the `name` keyword or the first positional argument.
  """
  mixture_name = _get_argument('name', 0, args, kwargs)
  logging.info('No-op Mixture register call: %s', mixture_name)
def _no_op_task_registry_get(*args, **kwargs):
  """Rejects a Task lookup.

  Raises:
    ValueError: always. The message carries the requested name, read from the
      `name` keyword or the first positional argument.
  """
  name = _get_argument('name', 0, args, kwargs)
  message = f'Disabled TaskRegistry.get call: {name}'
  raise ValueError(message)
def _no_op_mixture_registry_get(*args, **kwargs):
  """Rejects a Mixture lookup.

  Raises:
    ValueError: always. The message carries the requested name, read from the
      `name` keyword or the first positional argument.
  """
  name = _get_argument('name', 0, args, kwargs)
  message = f'Disabled MixtureRegistry.get call: {name}'
  raise ValueError(message)
def disable_registry():
  """Disables the seqio TaskRegistry and MixtureRegistry.

  After checking that both registries are still empty, `add` and
  `add_provider` on each registry are replaced with functions that only log
  the call, and `get` is replaced with a function that raises `ValueError`.

  Raises:
    ValueError: if a Task or Mixture has already been registered.
  """
  _enfore_empty_registries()

  replacements = (
      (
          dataset_providers.TaskRegistry,
          _no_op_task_registry_add,
          _no_op_task_registry_get,
      ),
      (
          dataset_providers.MixtureRegistry,
          _no_op_mixture_registry_add,
          _no_op_mixture_registry_get,
      ),
  )
  for registry, add_stub, get_stub in replacements:
    registry.add = add_stub
    registry.add_provider = add_stub
    registry.get = get_stub
def _get_fully_cached_name(
original_name: str, sequence_length: Mapping[str, int]
) -> str:
"""Generates name for fully-cached task or mixture."""
new_name = f'{original_name}_'
# Find shortest unique prefix.
prefix_len = 0
while len(set(feat[:prefix_len] for feat in sequence_length)) != len(
sequence_length
):
prefix_len += 1
new_name += '_'.join(
f'{feat[:prefix_len]}{sequence_length[feat]}' for feat in sequence_length
)
return new_name
def add_fully_cached_task(
    task_name: str,
    sequence_length: Mapping[str, int],
    disallow_shuffling: bool = False,
) -> Task:
  """Registers a variant of a task that is cached after all preprocessing.

  The new task reuses the source, output features, metric functions and
  postprocessor of `task_name`. Its preprocessors are the original ones (minus
  any `CacheDatasetPlaceholder`) with `sequence_length` pinned to the given
  value, followed by a required cache placeholder and a check that rejects any
  other runtime sequence length. If a task with the derived name is already
  registered, that task is returned as-is.

  Args:
    task_name: name of the registered task to derive from.
    sequence_length: mapping from feature name to length, baked into the new
      task's preprocessing.
    disallow_shuffling: if True, the new task is registered with
      `shuffle_buffer_size=None`; otherwise with
      `dataset_providers.SHUFFLE_BUFFER_SIZE`.

  Returns:
    The already-registered or newly registered fully-cached task.
  """
  task = TaskRegistry.get(task_name)
  new_name = _get_fully_cached_name(task_name, sequence_length)

  try:
    return TaskRegistry.get(new_name)
  except ValueError:
    # Nothing registered under the derived name yet; build it below.
    pass

  # Separate name so the closures below do not confuse the pinned lengths with
  # the `sequence_length` keyword that preprocessors receive at runtime.
  fixed_sequence_length = sequence_length

  def validate_sequence_length(ds, sequence_length):
    # Runs after the cache: only the pinned lengths (or None) are acceptable.
    mismatched = sequence_length is not None and dict(
        sequence_length
    ) != dict(fixed_sequence_length)
    if mismatched:
      raise ValueError(
          f"Fully-cached task '{new_name}' can only be loaded with "
          f'`sequence_length={fixed_sequence_length}` or `None`. '
          f'Given sequence_length={sequence_length}.'
      )
    return ds

  new_preprocessors = []
  for prep in task.preprocessors:
    if not isinstance(prep, CacheDatasetPlaceholder):

      def wrapped_prep(ds, output_features, prep=prep):
        # Forward only the keyword arguments the wrapped preprocessor accepts.
        accepted = inspect.signature(prep).parameters
        available = {
            'sequence_length': fixed_sequence_length,
            'output_features': output_features,
        }
        extra_kwargs = {k: v for k, v in available.items() if k in accepted}
        return prep(ds, **extra_kwargs)

      new_preprocessors.append(wrapped_prep)

  # Cache at the very end of the pipeline, then guard the runtime lengths.
  new_preprocessors += [
      CacheDatasetPlaceholder(required=True),
      validate_sequence_length,
  ]

  logging.info(
      "Registering fully cached Task '%s' with sequence lengths %s.",
      new_name,
      sequence_length,
  )

  if disallow_shuffling:
    shuffle_buffer_size = None
  else:
    shuffle_buffer_size = dataset_providers.SHUFFLE_BUFFER_SIZE

  return TaskRegistry.add(
      new_name,
      source=task.source,
      preprocessors=new_preprocessors,
      output_features=task.output_features,
      metric_fns=task.metric_fns,
      postprocess_fn=task.postprocessor,
      shuffle_buffer_size=shuffle_buffer_size,
  )
def add_fully_cached_mixture(
    mixture_name: str,
    sequence_length: Mapping[str, int],
    disallow_shuffling: bool = False,
) -> Mixture:
  """Registers a mixture made of fully-cached versions of another's tasks.

  Each task of `mixture_name` gets a fully-cached counterpart via
  `add_fully_cached_task`, and the new mixture pairs every counterpart with
  the rate recorded for the corresponding original task.

  Args:
    mixture_name: name of the registered mixture to derive from.
    sequence_length: mapping from feature name to length, used for every task.
    disallow_shuffling: forwarded to `add_fully_cached_task`.

  Returns:
    The value returned by `MixtureRegistry.add` for the new mixture.
  """
  mixture = MixtureRegistry.get(mixture_name)
  new_name = _get_fully_cached_name(mixture_name, sequence_length)

  # Make sure every member task has its fully-cached counterpart registered.
  cached_tasks = []
  for member in mixture.tasks:
    cached_tasks.append(
        add_fully_cached_task(member.name, sequence_length, disallow_shuffling)
    )

  logging.info(
      "Registering fully cached Mixture '%s' with sequence lengths %s.",
      new_name,
      sequence_length,
  )

  tasks_and_rates = []
  for original, cached in zip(mixture.tasks, cached_tasks):
    rate = mixture._task_to_rate[original.name]  # pylint:disable=protected-access
    tasks_and_rates.append((cached.name, rate))

  return MixtureRegistry.add(new_name, tasks_and_rates)
class FewshotDataSource(dataset_providers.DataSource):
  """Combines two splits of another `DataSource` to provide fewshot examples.

  Output examples are a dictionary containing a single eval example and a batch
  of train examples. For example, with `num_shots=2`:

    {
      'train': {
        'inputs': [
          'How many Beatles are there?', 'How many Beatles are alive in 2020?'
        ],
        'targets': ['4', '2']
      },
      'eval': {
        'inputs': 'What city were the Beatles from?',
        'targets': 'Liverpool'
      }
    }

  Note that if `num_shots` is 0, the 'train' entry will not be included in the
  resulting examples.
  """

  def __init__(
      self,
      original_source: dataset_providers.DataSource,
      num_shots: int,
      train_preprocessors: Iterable[
          Callable[[tf.data.Dataset], tf.data.Dataset]
      ] = (),
      eval_preprocessors: Iterable[
          Callable[[tf.data.Dataset], tf.data.Dataset]
      ] = (),
      train_split: str = 'train',
      train_feature_keys: Iterable[str] = ('inputs', 'targets'),
      shuffle_buffer_size: int = dataset_providers.SHUFFLE_BUFFER_SIZE,
      eval_on_fixed_exemplars: bool = False,
  ):
    """Initializes FewshotDataSource.

    Args:
      original_source: a DataSource to produce fewshot examples from.
      num_shots: A non-negative integer specifying how many training examples to
        include in the inputs.
      train_preprocessors: an iterable of preprocessors to run on the train
        split before zipping with the eval split.
      eval_preprocessors: an iterable of preprocessors to run on the eval split
        before zipping with the train split.
      train_split: the split to use as training examples.
      train_feature_keys: the features to retain in the train split after
        preprocessing but before batching and zipping with the eval split. This
        is necessary to remove variable-length sequences, which cannot be
        batched.
      shuffle_buffer_size: size of the shuffle buffer used when calling
        `get_dataset` with shuffle=True. Note that separate shuffles are applied
        to the `train` and `eval` splits before they are combined.
      eval_on_fixed_exemplars: If True, uses a fixed set of exemplars at
        evaluation time. Only effective during evaluation when `split` not
        equals `self._train_split`.
    """
    self._original_source = original_source
    self._num_shots = num_shots
    # These are re-iterated on every `get_dataset` call, so materialize them:
    # a one-shot iterator (e.g. a generator) would otherwise be empty from the
    # second call onwards.
    self._train_preprocessors = tuple(train_preprocessors)
    self._eval_preprocessors = tuple(eval_preprocessors)
    self._train_split = train_split
    self._train_feature_keys = tuple(train_feature_keys)
    self._shuffle_buffer_size = shuffle_buffer_size
    self._eval_on_fixed_exemplars = eval_on_fixed_exemplars

    # Override split in property since it may need to be loaded lazily (e.g.,
    # for TfdsSource)
    super().__init__(splits=())

  @property
  def splits(self) -> Sequence[str]:
    """The splits reported by the wrapped source."""
    return self._original_source.splits

  @property
  def supports_arbitrary_sharding(self) -> bool:
    """Always False for this source."""
    return False

  @functools.lru_cache()
  def list_shards(self, split: str) -> Sequence[str]:
    """Returns the wrapped source's shards for `split` (memoized per call args)."""
    return self._original_source.list_shards(split)

  def get_dataset(
      self,
      split: str = tfds.Split.TRAIN,
      shuffle: bool = True,
      seed: Optional[int] = None,
      shard_info: Optional[ShardInfo] = None,
      *,  # remaining args are out of order from parent
      sequence_length: Optional[Mapping[str, int]] = None,  # Unused
      use_cached: bool = False,  # Unused
      num_epochs: Optional[int] = 1,  # Unused
  ) -> tf.data.Dataset:
    """Builds the zipped fewshot dataset for `split`.

    Args:
      split: the split that supplies the 'eval' examples.
      shuffle: whether to shuffle the eval split. The train split is always
        shuffled; when `shuffle` is False it uses a fixed seed of 0.
      seed: optional shuffle seed. When given, the train split uses `seed` and
        the eval split uses `seed + 1`.
      shard_info: optional sharding spec; defaults to `ShardInfo(0, 1)`.
      sequence_length: unused.
      use_cached: unused.
      num_epochs: unused.

    Returns:
      A dataset of dictionaries with an 'eval' entry and, when `num_shots` is
      non-zero, a 'train' entry holding a batch of `num_shots` train examples.

    Raises:
      ValueError: if the configured train split is not one of the wrapped
        source's splits.
    """
    shard_info: ShardInfo = shard_info or ShardInfo(0, 1)
    if self._train_split not in self._original_source.splits:
      raise ValueError(
          f"Train split '{self._train_split}' is not one of the original "
          f'source splits: {self._original_source.splits}'
      )

    if not self._num_shots:
      logging.warning(
          'Train examples will not be included in the provided dataset since '
          '`num_shots` is 0.'
      )

    def _apply_preprocessors(ds, preprocessors):
      for prep_fn in preprocessors:
        ds = prep_fn(ds)
      return ds

    def _get_maybe_sharded_dataset(
        split_: str, shuffle_: bool, seed_: Optional[int]
    ) -> tf.data.Dataset:
      """Shard at source if possible, but fall back to examples if not."""
      num_shards = len(self._original_source.list_shards(split_))
      if num_shards >= shard_info.num_shards:
        # Shard at the source.
        ds = self._original_source.get_dataset(
            split=split_, shuffle=shuffle_, seed=seed_, shard_info=shard_info
        )
      else:
        # Shard the examples.
        ds = self._original_source.get_dataset(
            split=split_, shuffle=shuffle_, seed=seed_
        ).shard(shard_info.num_shards, shard_info.index)

      if shuffle_:
        # Do our own shuffling here, because original_source.get_dataset does
        # not necessarily return an adequately shuffled dataset even when we
        # request shuffle=True. For example, TfdsDataSource only shuffles at the
        # file shard level, not the individual example level (this amounts to no
        # shuffling if there is only one file shard).
        ds = ds.shuffle(
            buffer_size=self._shuffle_buffer_size,
            seed=seed_,
            reshuffle_each_iteration=True,
        )
      return ds

    if seed is None:
      train_seed = None
      eval_seed = None
    else:
      # If fixing a seed, train and eval seeds need to be different, otherwise
      # in the num_shots=1 case, identical examples would be zipped together.
      train_seed = seed
      eval_seed = seed + 1

    datasets = {}
    if self._num_shots:
      # Note that we ALWAYS shuffle the train split, even if the user passes
      # shuffle=False. This is to prevent the degenerate situation where train
      # and eval examples are identical. In the case of shuffle=False, we still
      # guarantee determinism by using a fixed seed of 0.
      train_ds = _get_maybe_sharded_dataset(
          split_=self._train_split,
          shuffle_=True,
          seed_=train_seed if shuffle else 0,
      )
      train_ds = _apply_preprocessors(train_ds, self._train_preprocessors)
      train_ds = train_ds.map(
          lambda x: {k: x[k] for k in self._train_feature_keys},
          num_parallel_calls=tf.data.experimental.AUTOTUNE,
      )
      train_ds = train_ds.repeat().batch(self._num_shots)
      if self._eval_on_fixed_exemplars and split != self._train_split:
        # Freeze the first batch of exemplars and reuse it for every example.
        train_ds = train_ds.take(1).cache().repeat()
      datasets['train'] = train_ds

    eval_ds = _get_maybe_sharded_dataset(
        split_=split, shuffle_=shuffle, seed_=eval_seed
    )
    eval_ds = _apply_preprocessors(eval_ds, self._eval_preprocessors)
    datasets['eval'] = eval_ds

    return tf.data.Dataset.zip(datasets)
def fewshot_preprocessor(
    ds,
    inputs_prefix='',
    targets_prefix='',
    example_separator='\n\n',
    prompt='',
    reverse=False,
):
  """Create 'inputs' and 'targets' strings for (zero/few)-shot evaluation.

  Every train example is rendered as its prefixed input followed by its
  prefixed target and the example separator. The eval example comes last and
  contributes only its prefixed input followed by the targets prefix.

  NOTE: The final targets prefix is right-stripped so that the input does not
  end with whitespace.

  For example, a 2-shot output might look like:

    output: {
      'inputs':
        '0 How many states in the US? X 1 50 X 0 How many cents in a dollar? X '
        '1 100 X 0 Who was in the Beatles? X 1',
      'targets': 'John',
      'answers': ['John', 'Paul', 'George', 'Ringo']
    }

  Args:
    ds: A dictionary of zipped eval and train tf.data.Datasets, each
      preprocessed with at least the fields 'inputs' and 'targets'. Note that
      the train dataset will not exist in the 0-shot case.
    inputs_prefix: Prefix string for inputs.
    targets_prefix: Prefix string for targets.
    example_separator: The string separator to delimit different examples.
    prompt: Optional prefix for the entire few-shot input. Typically consists of
      a natural language description of the task or task instructions.
    reverse: If True, the list of few shot examples is reversed. If used with
      eval_on_fixed_exemplars = True and a fixed train_seed, the last N shots
      will be the same when num_shots is N or N+M. In other words, additional
      shots are prepended instead of appended.

  Returns:
    A tf.data.Dataset containing 'inputs', 'targets', and any other features
    from the evaluation dataset.
  """

  def _render_shots(train):
    # One row per shot: [prefixed input, prefixed target + separator].
    pairs = tf.stack(
        [
            inputs_prefix + train['inputs'],
            targets_prefix + train['targets'] + example_separator,
        ],
        axis=1,
    )
    if reverse:
      pairs = tf.reverse(pairs, [0])
    # Flatten the rows and glue everything into a single string.
    return tf.strings.reduce_join(tf.reshape(pairs, [-1]))

  @utils.map_over_dataset
  def fewshot_map(ex):
    context = _render_shots(ex['train']) if 'train' in ex else ''
    if prompt:
      context = tf.strings.join([prompt, context], separator=example_separator)

    eval_ex = ex['eval']
    formatted = {
        'inputs': (
            context
            + inputs_prefix
            + eval_ex['inputs']
            + targets_prefix.rstrip()
        ),
        'targets': eval_ex['targets'],
    }
    # Carry the remaining eval features across untouched.
    for key, value in eval_ex.items():
      if key not in ('inputs', 'targets'):
        formatted[key] = value
    return formatted

  mapped = fewshot_map(ds)
  if not mapped.element_spec['inputs'].shape.rank:
    return mapped
  # Non-scalar inputs: unbatch. This is useful for fewshot eval.
  return mapped.unbatch()
def add_task_with_sentinels(task_name: str, num_sentinels: Optional[int] = 1):
  """Registers a copy of a task with sentinels added to inputs and targets.

  Adds `num_sentinels` sentinels to the end of 'inputs' and at the beginning
  of 'targets'. This is known to help fine-tuning span corruption models,
  especially on smaller datasets.

  The new task is registered under the original name with
  "_{num_sentinels}_sentinel" added. The markers '_train', '_dev', '_test',
  '_eval' and '.' are checked in that order; the addition is inserted right
  before the first occurrence of the first marker found in the name, or
  appended at the end if none is found.

  Example before:
    'inputs': What is the capital of illinois?
    'targets': Springfield.
  Example after:
    'inputs': What is the capital of illinois? <extra_id_0>
    'targets': <extra_id_0> Springfield.

  Args:
    task_name: a str, which is the name of the task you want to have sentinels
      added to. Note this will not override the current task, but will create a
      new one.
    num_sentinels: integer, number of sentinels to add to the end of inputs and
      the beginning of targets. It is passed to `range()`, so it must be an
      integer.
  """

  def _append_eos_after_trim_and_preserve(
      dataset: tf.data.Dataset,
      output_features: Mapping[str, dataset_providers.Feature],
      sequence_length: Optional[Mapping[str, int]] = None,
      preserve_final_n_tokens_when_trimming: Optional[int] = None,
  ) -> tf.data.Dataset:
    """Version of append_eos_after_trim with option to preserve last n tokens."""

    def _maybe_add_eos_and_trim(key: str, value: tf.Tensor) -> tf.Tensor:
      if key not in output_features or not output_features[key].add_eos:
        return value

      eos_id = output_features[key].vocabulary.eos_id
      if (
          sequence_length is not None
          and sequence_length.get(key, None) is not None
      ):
        max_length = sequence_length[key]
        if (
            preserve_final_n_tokens_when_trimming is not None
            and preserve_final_n_tokens_when_trimming > 0
        ):
          value_length = tf.shape(value)[0]
          # Length of the output sequence, including the EOS token.
          trimmed_length = tf.minimum(max_length, value_length + 1)
          # Can't preserve more tokens than fit next to the EOS token.
          n_tokens_to_preserve = tf.minimum(
              preserve_final_n_tokens_when_trimming, trimmed_length - 1
          )
          return tf.concat(
              [
                  value[: trimmed_length - (n_tokens_to_preserve + 1)],
                  # Slice from an explicit start index rather than
                  # `value[-n:]`: when n is 0, `-n` is 0 and the slice would be
                  # the entire sequence instead of an empty tail.
                  value[value_length - n_tokens_to_preserve :],
                  [eos_id],
              ],
              axis=0,
          )
        else:
          return tf.concat([value[: max_length - 1], [eos_id]], axis=0)
      else:
        return tf.concat([value, [eos_id]], axis=0)

    return dataset.map(
        lambda ex: {k: _maybe_add_eos_and_trim(k, v) for k, v in ex.items()},
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
    )

  def _create_new_task_name(task_name):
    """Creates the new task name with sentinels added."""
    sentinel_name = '_{}_sentinel'.format(num_sentinels)
    # Avoid messing up evaluation suffixes, so insert the sentinel name right
    # before these keywords.
    for suffix in ['_train', '_dev', '_test', '_eval', '.']:
      idx = task_name.find(suffix)
      if idx >= 0:
        return task_name[:idx] + sentinel_name + task_name[idx:]
    return task_name + sentinel_name

  def _sentinel_id(vocabulary, sentinel_num=0):
    """Token ID to use as a sentinel.

    Args:
      vocabulary: a vocabulary object exposing `vocab_size`.
      sentinel_num: an optional integer, what sentinel should be returned. By
        default it returns the first sentinel.

    Returns:
      an integer: `vocab_size - 1 - sentinel_num`.
    """
    return vocabulary.vocab_size - 1 - sentinel_num

  def _add_sentinels(dataset, sequence_length, output_features):
    """Adds sentinels to end of inputs and beginning of targets."""
    del sequence_length
    input_vocab = output_features['inputs'].vocabulary
    target_vocab = output_features['targets'].vocabulary

    @utils.map_over_dataset
    def _my_fn(x):
      sentinels_input = [
          _sentinel_id(input_vocab, idx) for idx in range(num_sentinels)
      ]
      sentinels_output = [
          _sentinel_id(target_vocab, idx) for idx in range(num_sentinels)
      ]
      x['inputs'] = tf.concat([x['inputs'], sentinels_input], 0)
      x['targets'] = tf.concat([sentinels_output, x['targets']], 0)
      return x

    return _my_fn(dataset)

  def _postprocess_fn_remove_sentinel(string_label, *args, **kwargs):
    """If sentinels are appended to the task, then remove them before eval."""
    del args
    del kwargs
    # `task` is bound further down in the enclosing function, before this
    # closure can ever be invoked.
    vocab = task.output_features['targets'].vocabulary
    sentinel_str = vocab.decode(
        [_sentinel_id(vocab, idx) for idx in range(num_sentinels)]
    )
    if string_label.startswith(sentinel_str):
      string_label = string_label[len(sentinel_str) :].strip()
    return string_label

  def _wrap_postprocess_fn_remove_sentinel(postprocess_fn):
    """Wrap around another postprocess_fn to remove sentinels first."""

    def new_fn(string_label, *args, **kwargs):
      string_label = _postprocess_fn_remove_sentinel(
          string_label, *args, **kwargs
      )
      return postprocess_fn(string_label, *args, **kwargs)

    return new_fn

  # Create the new task name.
  task = TaskRegistry.get(task_name)
  sentinel_task_name = _create_new_task_name(task_name)

  # Make the new preprocessors that will insert sentinels and make sure
  # sentinels are preserved if the sequences are trimmed.
  new_preprocessors = list(task.preprocessors)
  # The emptiness check keeps a task without preprocessors from raising
  # IndexError on `[-1]`; such a task simply gets `_add_sentinels` appended.
  if (
      new_preprocessors
      and new_preprocessors[-1] is seqio_preprocessors.append_eos_after_trim
  ):
    new_eos_funtion = functools.partial(
        _append_eos_after_trim_and_preserve,
        preserve_final_n_tokens_when_trimming=num_sentinels,
    )
    new_preprocessors[-1] = new_eos_funtion
    # Sentinels go in just before the EOS/trim step so trimming can keep them.
    new_preprocessors.insert(-1, _add_sentinels)
  else:
    new_preprocessors.append(_add_sentinels)

  # Remove the inserted sentinels in the postprocessor.
  postprocess_fn = task.postprocessor
  if postprocess_fn is not None:
    new_postprocess_fn = _wrap_postprocess_fn_remove_sentinel(postprocess_fn)
  else:
    new_postprocess_fn = _postprocess_fn_remove_sentinel

  TaskRegistry.add(
      sentinel_task_name,
      source=task.source,
      preprocessors=new_preprocessors,
      output_features=task.output_features,
      postprocess_fn=new_postprocess_fn,
      metric_fns=task.metric_fns,
  )