seqio.dataset_providers package
Interfaces
- class seqio.dataset_providers.DataSourceInterface(*args, **kwargs)
Interface for DataSource.
- class seqio.dataset_providers.DatasetProviderBase
Abstract base for classes that provide a tf.data.Dataset.
- class seqio.dataset_providers.DatasetProviderRegistry
Base class for registries of data providers.
Subclasses must wrap the get method to override its return type for pytype. TODO(adarob): Remove the need to override get.
- classmethod add(name, provider_cls, provider_kwargs)
Instantiates a provider and adds it to the registry.
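The add classmethod follows a familiar registry pattern: build the provider from its class and keyword arguments, then store it under a name. The sketch below is pure Python and illustrative only; `ToyRegistry` and `ToyProvider` are hypothetical names, the duplicate-name check is a choice of this sketch, and it skips the return-type wrapping of get that seqio subclasses perform for pytype.

```python
from typing import Any, Dict, Mapping, Type


class ToyProvider:
    """Hypothetical stand-in for a DatasetProviderBase subclass."""

    def __init__(self, rate: float = 1.0):
        self.rate = rate


class ToyRegistry:
    """Minimal name-to-provider registry (a sketch, not seqio's code)."""

    _REGISTRY: Dict[str, Any] = {}

    @classmethod
    def add(cls, name: str, provider_cls: Type, provider_kwargs: Mapping[str, Any]):
        # Refuse to silently replace an existing provider.
        if name in cls._REGISTRY:
            raise ValueError(f"Attempting to register duplicate provider: {name}")
        # Instantiate the provider from its class and kwargs, then store it.
        provider = provider_cls(**provider_kwargs)
        cls._REGISTRY[name] = provider
        return provider

    @classmethod
    def get(cls, name: str):
        if name not in cls._REGISTRY:
            raise ValueError(f"Provider {name!r} not found in registry")
        return cls._REGISTRY[name]


ToyRegistry.add("my_task", ToyProvider, {"rate": 2.0})
print(ToyRegistry.get("my_task").rate)  # 2.0
```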
Data Sources
- class seqio.dataset_providers.DataSource(splits, num_input_examples=None, caching_permitted=True, performs_internal_shuffling=False)
A DatasetProvider that provides raw data from an input source.
Inherits all abstract methods and properties of DatasetProviderBase except those overridden below.
- property caching_permitted
Indicates whether this data source may be cached.
Caching may be prohibited for the sake of data versioning rigor or as a matter of policy for certain datasets.
- abstract get_dataset(split=Split('train'), shuffle=True, seed=None, shard_info=None, *, sequence_length=None, use_cached=False, num_epochs=1)
Overrides the base class to add a shard identifier (shard_info); use_cached is accepted but unused.
- Parameters:
split – string, the split to return.
shuffle – bool, whether to shuffle the input source.
seed – tf.int64 scalar tf.Tensor (or None) for shuffling input source.
shard_info – optional specification for loading a shard of the split.
sequence_length – Unused
use_cached – Unused
num_epochs – Unused
- property output_features
Overrides an unused property of DatasetProviderBase.
- property performs_internal_shuffling
Indicates whether this data source performs internal shuffling.
Some datasets may provide internal shuffling mechanisms that could allow the dataset to be shuffled without calling ds.shuffle().
- abstract property supports_arbitrary_sharding
Whether the source supports sharding beyond the shards available in list_shards.
- class seqio.dataset_providers.TfdsDataSource(tfds_name=None, tfds_data_dir=None, splits=None, caching_permitted=True, decoders=None, tfds_builder_kwargs=None, read_only=False)
A DataSource that uses TensorFlow Datasets to provide the input data.
- get_dataset(split=None, shuffle=True, seed=None, shard_info=None, *, sequence_length=None, use_cached=False, num_epochs=1)
Overrides the base class to add a shard identifier (shard_info); use_cached is accepted but unused.
- Parameters:
split – string, the split to return.
shuffle – bool, whether to shuffle the input source.
seed – tf.int64 scalar tf.Tensor (or None) for shuffling input source.
shard_info – optional specification for loading a shard of the split.
sequence_length – Unused
use_cached – Unused
num_epochs – Unused
- property splits
Overridden because info.splits cannot be accessed until after initialization.
- property supports_arbitrary_sharding
Whether the source supports sharding beyond the shards available in list_shards.
- class seqio.dataset_providers.FileDataSource(read_file_fn, split_to_filepattern, num_input_examples=None, caching_permitted=True, file_shuffle_buffer_size=None, cycle_length=16, block_length=16, performs_internal_shuffling=False)
A DataSource that reads a file to provide the input dataset.
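FileDataSource's cycle_length and block_length share their names with the arguments of tf.data.Dataset.interleave, which suggests (an assumption, not stated in this reference) that matched files are read by interleaving. To make the two knobs concrete, here is a pure-Python model of interleave's ordering: at most cycle_length inputs are open at once, block_length consecutive elements are taken from each before moving on, and an exhausted input's slot is refilled from the next input. The `interleave` function is a hypothetical illustration, not a seqio or TensorFlow API.

```python
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def interleave(inputs: Iterable[Iterable[T]], cycle_length: int,
               block_length: int) -> Iterator[T]:
    """Models the element order produced by tf.data.Dataset.interleave."""
    pending = iter(inputs)
    slots: List[Optional[Iterator[T]]] = [None] * cycle_length
    cycle = 0   # slot currently being read
    taken = 0   # elements taken from that slot in the current block
    end_of_input = False
    num_open = 0

    def advance() -> None:
        nonlocal cycle, taken
        cycle = (cycle + 1) % cycle_length
        taken = 0

    while not end_of_input or num_open > 0:
        current = slots[cycle]
        if current is not None:
            try:
                value = next(current)
            except StopIteration:
                # This input is exhausted: free its slot and move to the next.
                slots[cycle] = None
                num_open -= 1
                advance()
                continue
            taken += 1
            if taken == block_length:
                advance()
            yield value
        elif not end_of_input:
            # Lazily open the next input in the empty slot.
            try:
                slots[cycle] = iter(next(pending))
                num_open += 1
            except StopIteration:
                end_of_input = True
        else:
            advance()


# Three "files" of four records each, read two at a time in blocks of two.
files = [[f"{name}{i}" for i in range(4)] for name in "abc"]
print(list(interleave(files, cycle_length=2, block_length=2)))
```

Larger cycle_length mixes records from more files at once; larger block_length keeps longer runs from the same file together.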
- get_dataset(split=Split('train'), shuffle=True, seed=None, shard_info=None, *, sequence_length=None, use_cached=False, num_epochs=1)
Overrides the base class to add a shard identifier (shard_info); use_cached is accepted but unused.
- Parameters:
split – string, the split to return.
shuffle – bool, whether to shuffle the input source.
seed – tf.int64 scalar tf.Tensor (or None) for shuffling input source.
shard_info – optional specification for loading a shard of the split.
sequence_length – Unused
use_cached – Unused
num_epochs – Unused
- property supports_arbitrary_sharding
Whether the source supports sharding beyond the shards available in list_shards.
- class seqio.dataset_providers.TFExampleDataSource(split_to_filepattern, feature_description, reader_cls=<class 'tensorflow.python.data.ops.readers.TFRecordDatasetV2'>, num_input_examples=None, caching_permitted=True, file_shuffle_buffer_size=None, cycle_length=16, block_length=16, performs_internal_shuffling=False)
A FileDataSource that reads files of tf.train.Example protos as input.
- class seqio.dataset_providers.TextLineDataSource(split_to_filepattern, skip_header_lines=0, num_input_examples=None, caching_permitted=True, file_shuffle_buffer_size=None, cycle_length=16, block_length=16)
A FileDataSource that reads lines of text from a file as input.
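skip_header_lines is meant for text files that begin with header rows, such as TSV files with a column-name line. The sketch below assumes the count applies to every matched file (plausible, since each file can carry its own header, but not stated here); `read_text_lines` is a hypothetical pure-Python helper, not the seqio implementation.

```python
import os
import tempfile
from typing import Iterator, List


def read_text_lines(paths: List[str], skip_header_lines: int = 0) -> Iterator[str]:
    """Yields lines from each file, dropping the first skip_header_lines of each."""
    for path in paths:
        with open(path, encoding="utf-8") as f:
            for line_number, line in enumerate(f):
                if line_number < skip_header_lines:
                    continue  # header row(s) of this particular file
                yield line.rstrip("\n")


with tempfile.TemporaryDirectory() as tmp:
    paths = []
    for shard in range(2):
        path = os.path.join(tmp, f"data-{shard}.tsv")
        with open(path, "w", encoding="utf-8") as f:
            f.write("inputs\ttargets\n")  # the header is repeated in every file
            f.write(f"question {shard}\tanswer {shard}\n")
        paths.append(path)
    # Consume the generator before the temporary directory is deleted.
    rows = list(read_text_lines(paths, skip_header_lines=1))

print(rows)
```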
- class seqio.dataset_providers.ProtoDataSource(split_to_filepattern, decode_proto_fn, reader_cls=<class 'tensorflow.python.data.ops.readers.TFRecordDatasetV2'>, num_input_examples=None, caching_permitted=True, file_shuffle_buffer_size=None, cycle_length=16, block_length=16)
A FileDataSource that reads files of arbitrary protos as input.
- class seqio.dataset_providers.FunctionDataSource(dataset_fn, splits, num_input_examples=None, caching_permitted=True)
A DataSource that uses a function to provide the input data.
This source is not recommended when shuffling is required unless it is cached/materialized in advance. Using this source without caching for training will result in insufficient shuffling and lead to repeated data on restarts.
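Why buffer-only shuffling is "insufficient": tf.data.Dataset.shuffle fills a fixed-size buffer and emits random elements from it, so when the source itself is not shuffled, every element stays close to its original position, and the first outputs always come from the first few inputs, whatever the seed. A pure-Python model of that buffer makes the bound visible; `buffered_shuffle` is a hypothetical name, not a TensorFlow or seqio function.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def buffered_shuffle(elements: Iterable[T], buffer_size: int,
                     seed: int) -> Iterator[T]:
    """Shuffles a stream with a fixed-size buffer, in the style of tf.data."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for element in elements:
        buffer.append(element)
        if len(buffer) == buffer_size:
            # Emit one random buffered element; the next input refills its slot.
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Drain what is left once the input is exhausted.
    rng.shuffle(buffer)
    yield from buffer


out = list(buffered_shuffle(range(10_000), buffer_size=100, seed=0))
# An element can only be emitted after it enters the buffer, so position p
# never holds an element whose index exceeds p + buffer_size - 1.
print(max(x - p for p, x in enumerate(out)))  # at most 99
```

A restarted job that reads the same unshuffled source therefore sees nearly the same leading examples again, even with a new seed.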
- get_dataset(split=Split('train'), shuffle=True, seed=None, shard_info=None, *, sequence_length=None, use_cached=False, num_epochs=1)
Overrides the base class to add a shard identifier (shard_info); use_cached is accepted but unused.
- Parameters:
split – string, the split to return.
shuffle – bool, whether to shuffle the input source.
seed – tf.int64 scalar tf.Tensor (or None) for shuffling input source.
shard_info – optional specification for loading a shard of the split.
sequence_length – Unused
use_cached – Unused
num_epochs – Unused
- property supports_arbitrary_sharding
Whether the source supports sharding beyond the shards available in list_shards.
Task
- class seqio.dataset_providers.Task(name, source, output_features, preprocessors=None, postprocess_fn=None, metric_fns=None, metric_objs=None, shuffle_buffer_size=1000, source_info=None)
A class to manage a dataset and its related metrics.
- property cache_dir
Returns the cache directory (or None), initializing if needed.
- get_dataset(sequence_length=None, split=Split('train'), use_cached=False, shuffle=True, shuffle_buffer_size=None, seed=None, shard_info=None, num_epochs=1, trim_output_features=True, try_in_mem_cache=True)
Returns a tf.data.Dataset from cache or generated on the fly.
- Parameters:
sequence_length – dict mapping feature key to maximum int length for that feature. If longer after preprocessing, the feature will be truncated. May be set to None to avoid truncation.
split – string, the split to return.
use_cached – bool, whether to use the cached dataset instead of processing it on the fly. Defaults to False.
shuffle – bool, whether to shuffle the dataset. Only used when generating on the fly (use_cached=False).
shuffle_buffer_size – an integer or None to use task-specific buffer size.
seed – tf.int64 scalar tf.Tensor (or None) for shuffling tf.data.
shard_info – optional specification for loading a shard of the split. If the Task’s DataSource contains at least the number of shards in the specification, it will be passed the shard info to avoid loading the full source dataset. Otherwise, the full source dataset will be loaded and sharded at the individual examples.
num_epochs – the number of times to iterate through the dataset, or None to repeat indefinitely. Note that the repeat occurs in the pipeline after offline caching, but before applying potentially stochastic post-cache preprocessors and is therefore typically preferred to calling repeat() on the returned dataset. Defaults to 1.
trim_output_features – If True, trims output features so that they are no longer than the length given by sequence_length.
try_in_mem_cache – If True, caches sufficiently small datasets in memory for efficiency.
- Returns:
A tf.data.Dataset.
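The shard_info behavior described for this method (pass the shard to the source when it has at least num_shards shards, otherwise load the full source and shard individual examples) can be modeled in pure Python. `shard_split` and the local `ShardInfo` dataclass are hypothetical; the dataclass only mirrors the (index, num_shards) fields of seqio.dataset_providers.ShardInfo, and the round-robin assignment of source shards is an assumption of this sketch. The example-level fallback uses the same rule as tf.data.Dataset.shard (keep element i when i % num_shards == index).

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class ShardInfo:
    index: int
    num_shards: int


def shard_split(source_shards: Sequence[Sequence[str]],
                shard_info: ShardInfo) -> Tuple[str, List[str]]:
    """Returns (strategy, examples) for one shard of a split.

    source_shards is a hypothetical list of source shards (e.g. files),
    each given as a list of examples.
    """
    if len(source_shards) >= shard_info.num_shards:
        # Enough source shards: load only the ones assigned to this shard.
        picked = source_shards[shard_info.index::shard_info.num_shards]
        return "source", [ex for shard in picked for ex in shard]
    # Too few source shards: load everything and keep every
    # num_shards-th example, starting at index.
    everything = [ex for shard in source_shards for ex in shard]
    return "examples", everything[shard_info.index::shard_info.num_shards]


files = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"]]
print(shard_split(files, ShardInfo(index=0, num_shards=2)))
# ('source', ['a0', 'a1', 'c0', 'c1'])
print(shard_split(files, ShardInfo(index=0, num_shards=4)))
# ('examples', ['a0', 'c0'])
```

Sharding at the source avoids reading data that is then discarded, which is why it is preferred when the source has enough shards.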
- property metric_fns
List of all metric functions.
- property metric_objs
List of all metric objects.
- postprocess_fn(decoded_model_output, **postprocess_kwargs)
Returns the model output after applying the postprocess function.
- property predict_metric_fns
List of metric functions that use model predictions.
- property predict_with_aux_metric_fns
List of metric functions that use model predictions with aux values.
- preprocess_postcache(dataset, sequence_length, seed=None)
Runs preprocessing steps after the optional CacheDatasetPlaceholder.
- Parameters:
dataset – a tf.data.Dataset
sequence_length – dict mapping feature key to int length for that feature. If None, the features will not be truncated.
seed – an optional random seed for deterministic preprocessing.
- Returns:
a tf.data.Dataset
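The sequence_length semantics (truncate a feature that is longer than its entry; pass None to skip truncation) can be shown on a plain dict of token ids. `trim_features` is a hypothetical helper, and leaving features that have no sequence_length entry untouched is an assumption of this sketch.

```python
from typing import Dict, List, Mapping, Optional


def trim_features(example: Mapping[str, List[int]],
                  sequence_length: Optional[Mapping[str, int]]) -> Dict[str, List[int]]:
    """Truncates each feature listed in sequence_length; None disables trimming."""
    if sequence_length is None:
        return dict(example)
    return {
        key: value[:sequence_length[key]] if key in sequence_length else value
        for key, value in example.items()
    }


example = {"inputs": [5, 6, 7, 8, 9], "targets": [1, 2, 3]}
print(trim_features(example, {"inputs": 3, "targets": 8}))
# {'inputs': [5, 6, 7], 'targets': [1, 2, 3]}
print(trim_features(example, None) == example)  # True
```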
- preprocess_precache(dataset, seed=None)
Runs preprocessing steps before the optional CacheDatasetPlaceholder.
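preprocess_precache and preprocess_postcache split one preprocessor list around the optional CacheDatasetPlaceholder. A pure-Python sketch of that split, with a hypothetical `CachePlaceholder` marker standing in for the real class: this sketch rejects more than one marker and, when there is none, runs every step post-cache (our understanding of seqio's behavior, but treat it as an assumption).

```python
from typing import Callable, List, Sequence, Tuple

Step = Callable[[List[str]], List[str]]


class CachePlaceholder:
    """Hypothetical marker standing in for seqio's CacheDatasetPlaceholder."""


def split_preprocessors(steps: Sequence[object]) -> Tuple[List[Step], List[Step]]:
    """Returns (precache_steps, postcache_steps) around the placeholder."""
    marks = [i for i, step in enumerate(steps) if isinstance(step, CachePlaceholder)]
    if len(marks) > 1:
        raise ValueError("The cache placeholder may appear at most once.")
    if not marks:
        # Nothing is cacheable, so every step runs after the (absent) cache.
        return [], list(steps)
    i = marks[0]
    return list(steps[:i]), list(steps[i + 1:])


def run(steps: Sequence[Step], data: List[str]) -> List[str]:
    for step in steps:
        data = step(data)
    return data


def lowercase(xs: List[str]) -> List[str]:
    return [x.lower() for x in xs]


def strip(xs: List[str]) -> List[str]:
    return [x.strip() for x in xs]


def add_eos(xs: List[str]) -> List[str]:
    return [x + " </s>" for x in xs]


pre, post = split_preprocessors([lowercase, strip, CachePlaceholder(), add_eos])
cached = run(pre, ["  Hello World "])  # deterministic work, done once and stored
print(run(post, cached))  # ['hello world </s>']
```

Putting only deterministic steps before the placeholder keeps the cached data valid across runs, while stochastic steps after it still vary per epoch.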
- property requires_caching
Whether or not this task requires offline caching.
- property score_metric_fns
List of metric functions that use log likelihood scores.
- property supports_caching
Whether or not this task supports offline caching.
- class seqio.dataset_providers.ShardInfo(index, num_shards)
A container for specifying sharding info.
- class seqio.dataset_providers.SourceInfo(file_path=None, line_number=None)
Information about the source location of a class or function.
- file_path
Where on disk the source code is located.
- Type:
str | None
- line_number
The line number in the file where the class, function, etc. is defined.
- Type:
int | None
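Both SourceInfo fields can be read from a Python function's code object. This sketch uses a local dataclass with the same two fields and a hypothetical `source_info_for` helper; how seqio itself populates source_info is not described in this reference.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass(frozen=True)
class SourceInfo:
    """Local mirror of the two documented SourceInfo fields."""
    file_path: Optional[str] = None
    line_number: Optional[int] = None


def source_info_for(fn: Callable) -> SourceInfo:
    """Hypothetical helper: where a Python function was defined."""
    code = getattr(fn, "__code__", None)
    if code is None:  # builtins and C functions have no code object
        return SourceInfo()
    return SourceInfo(file_path=code.co_filename, line_number=code.co_firstlineno)


def my_preprocessor(x):
    return x


info = source_info_for(my_preprocessor)
print(info.file_path, info.line_number)
print(source_info_for(len))  # SourceInfo(file_path=None, line_number=None)
```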
Mixture
- class seqio.dataset_providers.Mixture(name, tasks, default_rate=None, sample_fn=functools.partial(<function DatasetV2.sample_from_datasets>, stop_on_empty_dataset=True), source_info=None)
Class for mixing multiple tasks.
- get_dataset(sequence_length=None, split=Split('train'), use_cached=False, shuffle=True, seed=None, shard_info=None, num_epochs=None, copy_pretokenized=False, compute_stats_empirically=False, log_mixing_proportions=True, passthrough_features=None, trim_output_features=True, try_in_mem_cache=True)
Returns the dataset of mixed tasks using the object-specified rates.
- Parameters:
sequence_length – dict mapping feature key to maximum int length for that feature. If longer after preprocessing, the feature will be truncated. May be set to None to avoid truncation.
split – string, the split to return for all tasks.
use_cached – bool, whether to use the cached dataset instead of processing it on the fly. This will be passed to the underlying Tasks in the Mixture. Defaults to False.
shuffle – bool, whether to shuffle the dataset. Only used when generating on the fly (use_cached=False).
seed – tf.int64 scalar tf.Tensor (or None) for shuffling tf.data.
shard_info – optional specification for loading a shard of the split.
num_epochs – the number of times to iterate through the dataset, or None to repeat indefinitely. Note that the repeat occurs in the pipeline after offline caching, but before applying potentially stochastic post-cache preprocessors and is therefore typically preferred to calling repeat() on the returned dataset. Defaults to None.
copy_pretokenized – bool, whether to pass through copies of pretokenized features with a “_pretokenized” suffix added to the key.
compute_stats_empirically – bool; does not work on TPU.
log_mixing_proportions – whether to log the mixing proportions of the tasks.
passthrough_features – a list of additional features to keep after feature filtering. If None, only the output_features defined for the mixture are kept.
trim_output_features – If True, trims output features so that they are no longer than the length given by sequence_length.
try_in_mem_cache – If True, caches sufficiently small datasets in memory for efficiency.
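The Mixture's default sample_fn is tf.data's sample_from_datasets with stop_on_empty_dataset=True. Per the TensorFlow documentation, that flag stops sampling at the first empty input rather than skipping it. A pure-Python model of both behaviors follows; `sample_from`, the task names, and the rates are hypothetical. In this model, a finite task ends the whole mixture the first time it is drawn after running out, which is one way to see why repeating tasks indefinitely (num_epochs=None, this method's default) suits training mixtures.

```python
import random
from typing import Dict, Iterator, Sequence


def sample_from(datasets: Dict[str, Sequence[str]], rates: Dict[str, float],
                seed: int, stop_on_empty_dataset: bool = True) -> Iterator[str]:
    """Models weighted sampling across tasks, one example at a time."""
    rng = random.Random(seed)
    iterators = {name: iter(ds) for name, ds in datasets.items()}
    names = list(datasets)
    while names:
        name = rng.choices(names, weights=[rates[n] for n in names])[0]
        try:
            yield next(iterators[name])
        except StopIteration:
            if stop_on_empty_dataset:
                return  # the whole mixture ends once a drawn task is empty
            names.remove(name)  # otherwise drop the empty task and continue


tasks = {"task_a": [f"a{i}" for i in range(100)], "task_b": ["b0", "b1"]}
rates = {"task_a": 1.0, "task_b": 1.0}
stopped = list(sample_from(tasks, rates, seed=0))
skipped = list(sample_from(tasks, rates, seed=0, stop_on_empty_dataset=False))
print(len(stopped), len(skipped))
```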
- get_task_dataset(task, output_feature_keys, sequence_length=None, split=Split('train'), use_cached=False, shuffle=True, seed=None, shard_info=None, num_epochs=None, trim_output_features=True, try_in_mem_cache=True)
- property rate_per_task_name
Returns the rate for each task.
Note that sub-mixtures are included as tasks, and the tasks that belong to these sub-mixtures are not in the mapping.
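The Mixture constructor takes tasks plus a default_rate. A minimal sketch of turning those into a name-to-rate mapping like rate_per_task_name, assuming (not stated in this reference) that each tasks entry is either a bare name, which receives default_rate, or a (name, rate) pair; seqio may accept other forms as well. As noted above, a sub-mixture counts as a single entry. `resolve_rates` is a hypothetical helper.

```python
from typing import Dict, Optional, Sequence, Tuple, Union

RateSpec = Union[str, Tuple[str, float]]


def resolve_rates(tasks: Sequence[RateSpec],
                  default_rate: Optional[float] = None) -> Dict[str, float]:
    """Maps each task (or sub-mixture) name to its rate."""
    rates: Dict[str, float] = {}
    for entry in tasks:
        if isinstance(entry, str):
            if default_rate is None:
                raise ValueError(f"Need a rate or default_rate for {entry!r}")
            name, rate = entry, default_rate
        else:
            name, rate = entry
        rates[name] = float(rate)
    return rates


rates = resolve_rates([("wiki", 3.0), "books", "my_sub_mixture"], default_rate=1.0)
total = sum(rates.values())
# Rates are relative; dividing by their sum gives sampling proportions.
print({name: rate / total for name, rate in rates.items()})
# {'wiki': 0.6, 'books': 0.2, 'my_sub_mixture': 0.2}
```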
Registry
- class seqio.dataset_providers.TaskRegistry
Registry of Tasks.
APIs
- seqio.dataset_providers.get_dataset(mixture_or_task_name, task_feature_lengths, feature_converter, dataset_split='train', use_cached=False, shuffle=False, num_epochs=1, shard_info=None, verbose=True, seed=None, batch_size=None, trim_output_features=True)
Get processed dataset with the model features.
To use options specific to a feature converter (e.g., packing), instantiate the feature_converter with those options before passing it to this function.
Sharded datasets are supported: pass shard_info with the shard index and number of shards. Sharding happens before the feature converter stage, so if packing is used, it is applied to the sharded dataset.
- Parameters:
mixture_or_task_name – mixture or task name for the Task API.
task_feature_lengths – dict mapping task feature key to its sequence length. This specifies the sequence length of the dataset from the Task API.
feature_converter – a feature converter object to use to convert the task features to model features. Must be a subclass of FeatureConverter.
dataset_split – the split to use.
use_cached – whether to use the cached dataset instead of processing it on the fly.
shuffle – whether to shuffle the dataset.
num_epochs – the number of times to iterate through the dataset, or None to repeat indefinitely. Note that the repeat occurs in the pipeline after offline caching, but before applying potentially stochastic post-cache preprocessors and is therefore typically preferred to calling repeat() on the returned dataset. Defaults to 1.
shard_info – number of shards and shard index information.
verbose – if True, log the feature shapes.
seed – a random seed for shuffling tf.data.
batch_size – Optional batch size.
trim_output_features – If True, trims output features so that they are no longer than the length given by sequence_length.
- Returns:
the processed dataset.
- Return type:
tf.data.Dataset
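Because sharding happens before the feature converter, packing only ever combines examples from the same shard. The toy model below makes the ordering visible with a naive in-order packer; `shard` and `pack_in_order` are hypothetical helpers, and real feature-converter packing is more involved, so only the ordering point carries over.

```python
from typing import List


def shard(examples: List[List[int]], index: int, num_shards: int) -> List[List[int]]:
    # Keep every num_shards-th example, starting at index.
    return examples[index::num_shards]


def pack_in_order(examples: List[List[int]], max_length: int) -> List[List[int]]:
    """Hypothetical packer: concatenate examples until the next one won't fit."""
    packed: List[List[int]] = []
    current: List[int] = []
    for ex in examples:
        if len(ex) > max_length:
            raise ValueError("example longer than max_length; trim it first")
        if current and len(current) + len(ex) > max_length:
            packed.append(current)
            current = []
        current = current + ex
    if current:
        packed.append(current)
    return packed


examples = [[1] * 3, [2] * 3, [3] * 3, [4] * 3]
# Shard first, then pack (the order described above).
print(pack_in_order(shard(examples, index=0, num_shards=2), max_length=6))
# [[1, 1, 1, 3, 3, 3]]
# Packing first would have produced different shard contents.
print(shard(pack_in_order(examples, max_length=6), index=0, num_shards=2))
# [[1, 1, 1, 2, 2, 2]]
```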
- seqio.dataset_providers.get_mixture_or_task(task_or_mixture_name)
Return the Task or Mixture from the appropriate registry.
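"The appropriate registry" implies separate registries for Tasks and Mixtures. A pure-Python sketch of a lookup across two name-keyed dicts; `lookup_mixture_or_task`, `TASKS`, and `MIXTURES` are hypothetical, and raising when a name is registered in both is a choice of this sketch, not necessarily what seqio does.

```python
from typing import Dict

TASKS: Dict[str, object] = {"squad": "<Task squad>"}
MIXTURES: Dict[str, object] = {"glue_mix": "<Mixture glue_mix>"}


def lookup_mixture_or_task(name: str) -> object:
    """Looks name up in both registries; ambiguity handling is this sketch's choice."""
    in_tasks, in_mixtures = name in TASKS, name in MIXTURES
    if in_tasks and in_mixtures:
        raise ValueError(f"{name!r} is registered as both a Task and a Mixture")
    if in_mixtures:
        return MIXTURES[name]
    if in_tasks:
        return TASKS[name]
    raise ValueError(f"No Task or Mixture found with name {name!r}")


print(lookup_mixture_or_task("squad"))     # <Task squad>
print(lookup_mixture_or_task("glue_mix"))  # <Mixture glue_mix>
```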