# Source code for seqio.beam_utils

# Copyright 2026 The SeqIO Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""SeqIO Beam utilities."""

import functools
import hashlib
import importlib
import json
import operator
import os
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Tuple

from absl import logging
import apache_beam as beam
from apache_beam import metrics
import numpy as np
import seqio
import tensorflow.compat.v2 as tf

from array_record.python import array_record_module

# Every provenance feature name shares this prefix, which is what lets
# consumers recognize (and optionally drop) provenance features by name.
PROVENANCE_PREFIX = "provenance/"

# Individual provenance feature names.
TASK_PROVENANCE_KEY = f"{PROVENANCE_PREFIX}task"
SOURCE_SHARD_PROVENANCE_KEY = f"{PROVENANCE_PREFIX}source_shard"
SOURCE_SHARD_ID_PROVENANCE_KEY = f"{PROVENANCE_PREFIX}source_shard_index"
ID_WITHIN_SHARD_PROVENANCE_KEY = f"{PROVENANCE_PREFIX}index_within_shard"
PREPROCESSORS_SEED_PROVENANCE_KEY = f"{PROVENANCE_PREFIX}preprocessors_seed"

# All provenance feature names, in a fixed order.
PROVENANCE_KEYS = [
    TASK_PROVENANCE_KEY,
    SOURCE_SHARD_PROVENANCE_KEY,
    SOURCE_SHARD_ID_PROVENANCE_KEY,
    ID_WITHIN_SHARD_PROVENANCE_KEY,
    PREPROCESSORS_SEED_PROVENANCE_KEY,
]


def _import_modules(modules):
  for module in modules:
    if module:
      importlib.import_module(module)


class PreprocessTask(beam.PTransform):
  """Beam transform that preprocesses a Task, one source shard at a time.

  Returns a PCollection of example dicts containing Tensors.
  """

  def __init__(
      self,
      task: seqio.Task,
      split: str,
      *,
      preprocessors_seed: Optional[int] = None,
      setup_fn: Callable[[], None] = lambda: None,
      modules_to_import: Sequence[str] = (),
      add_provenance: bool = False,
      tfds_data_dir: Optional[str] = None,
  ):
    """Stores the configuration and enumerates the shards of `split`.

    Args:
      task: Task, the task to process.
      split: string, the split to process.
      preprocessors_seed: (Optional) int, a seed for stateless random ops in
        task preprocessing.
      setup_fn: (Optional) callable, a function called before loading the
        task.
      modules_to_import: (Optional) list, modules to import.
      add_provenance: If True, provenance is added to each example.
      tfds_data_dir: (Optional) str, directory where the TFDS datasets are
        stored.

    Raises:
      FileNotFoundError: raised when no shards are found for the task.
    """
    self._task_name = task.name
    self._split = split
    self._preprocessors_seed = preprocessors_seed
    self._setup_fn = setup_fn
    self._modules_to_import = modules_to_import
    self._add_provenance = add_provenance
    self._tfds_data_dir = tfds_data_dir
    self._int64_max = 2**63 - 1

    # Each entry is an (index, shard name) pair.
    self.shards = list(enumerate(task.source.list_shards(split)))
    if not self.shards:
      raise FileNotFoundError(f"No shards found for {task.name} {split}")

    shard_names = ", ".join("%s" % name for _, name in self.shards)
    logging.info(
        "%s %s %s shards: %d %s",
        task.name,
        split,
        tfds_data_dir,
        len(self.shards),
        shard_names,
    )

  def _increment_counter(self, name):
    """Increments the Beam counter `name` in this task/split's namespace."""
    namespace = str("%s_%s" % (self._task_name, self._split))
    metrics.Metrics.counter(namespace, name).inc()

  def _shard_preprocessors_seed(self, shard_index: int) -> int:
    """Returns a deterministic seed for the given shard of this task.

    The seed is the MD5 digest of the task name and shard index (read as a
    little-endian integer) plus the user-provided seed, if any.
    """
    user_seed = self._preprocessors_seed or 0
    digest = hashlib.md5(
        (self._task_name + f"shard{shard_index}").encode()
    ).digest()
    seed = int.from_bytes(digest, "little") + user_seed
    if seed > self._int64_max:
      # The user-provided seed is very likely to be much smaller than 2**62,
      # so it is safe to truncate the digest to 7 bytes and add the seed.
      seed = int.from_bytes(digest[:7], "little") + user_seed
      # Truncate if still a large number.
      seed %= self._int64_max
    return seed

  def _emit_examples(self, shard: Tuple[int, str]):
    """Yields the preprocessed examples of a single shard.

    Args:
      shard: an (index, shard name) pair taken from `self.shards`.

    Yields:
      Example dicts, with provenance features added when enabled.
    """
    self._setup_fn()
    _import_modules(self._modules_to_import)
    task = seqio.TaskRegistry.get(self._task_name)
    if self._tfds_data_dir:
      seqio.set_tfds_data_dir_override(self._tfds_data_dir)

    shard_index, shard_name = shard
    logging.info("Processing shard: %s", shard_name)
    self._increment_counter("input-shards")

    # Create a unique, deterministic preprocessors seed for each task and
    # shard.
    seed = self._shard_preprocessors_seed(shard_index)

    shard_info = seqio.ShardInfo(
        index=shard_index, num_shards=len(self.shards)
    )
    ds = task.source.get_dataset(
        split=self._split,
        shard_info=shard_info,
        shuffle=False,
        seed=seed,
    )
    ds = task.preprocess_precache(ds, seed=seed)
    ds = ds.prefetch(tf.data.AUTOTUNE)

    for index_within_shard, example in enumerate(ds):
      if self._add_provenance:
        example[TASK_PROVENANCE_KEY] = self._task_name
        example[SOURCE_SHARD_PROVENANCE_KEY] = shard_name
        example[SOURCE_SHARD_ID_PROVENANCE_KEY] = shard_index
        example[ID_WITHIN_SHARD_PROVENANCE_KEY] = index_within_shard
        # The seed is only recorded when it is truthy (not None and not 0).
        if self._preprocessors_seed:
          example[PREPROCESSORS_SEED_PROVENANCE_KEY] = (
              self._preprocessors_seed
          )
      self._increment_counter("examples")
      # Log every power of two (and index 0).
      if index_within_shard & (index_within_shard - 1) == 0:
        logging.info("Example [%d] = %s", index_within_shard, example)
      yield example

  def expand(self, pipeline):
    """Creates one element per shard and fans each out into its examples."""
    shards = pipeline | "create_shards" >> beam.Create(self.shards)
    return shards | "emit_examples" >> beam.FlatMap(self._emit_examples)
class WriteExampleTfRecord(beam.PTransform):
  """Writes example dicts to TFRecord files as tf.Example protos."""

  def __init__(self, output_path: str, num_shards: Optional[int] = None):
    """WriteExampleTfRecord constructor.

    Args:
      output_path: string, path to the output TFRecord file (w/o shard
        suffix).
      num_shards: (optional) int, number of shards to output or None to use
        liquid sharding.
    """
    self._output_path = output_path
    self._num_shards = num_shards

  def expand(self, pcoll):
    """Converts each dict to a tf.Example, reshuffles, and writes it out."""
    writer = beam.io.tfrecordio.WriteToTFRecord(
        self._output_path,
        num_shards=self._num_shards,
        coder=beam.coders.ProtoCoder(tf.train.Example),
    )
    tf_examples = pcoll | beam.Map(seqio.dict_to_tfexample)
    reshuffled = tf_examples | beam.Reshuffle()
    return reshuffled | writer
class _ArrayRecordSink(beam.io.filebasedsink.FileBasedSink):
  """File-based sink that writes records with an `ArrayRecordWriter`."""

  def __init__(
      self,
      file_path_prefix,
      file_name_suffix=None,
      num_shards=0,
      shard_name_template=None,
      coder=beam.coders.coders.ToBytesCoder(),
      compression_type=beam.io.filesystem.CompressionTypes.AUTO,
      preserve_random_access: bool = False,
  ):
    """Forwards the file options to `FileBasedSink` and stores the flag.

    Args:
      file_path_prefix: forwarded to `FileBasedSink`.
      file_name_suffix: forwarded to `FileBasedSink`.
      num_shards: forwarded to `FileBasedSink`.
      shard_name_template: forwarded to `FileBasedSink`.
      coder: coder used to encode records, forwarded to `FileBasedSink`.
      compression_type: forwarded to `FileBasedSink`.
      preserve_random_access: if True, files are written with group size 1;
        otherwise the group size is set to `self.num_shards`.
    """
    super().__init__(
        file_path_prefix,
        file_name_suffix=file_name_suffix,
        num_shards=num_shards,
        shard_name_template=shard_name_template,
        coder=coder,
        mime_type="application/octet-stream",
        compression_type=compression_type,
    )
    self._preserve_random_access = preserve_random_access

  def open(self, temp_path):
    """Returns an `ArrayRecordWriter` for `temp_path`."""
    if self._preserve_random_access:
      group_size = 1
    else:
      # NOTE(review): presumably `self.num_shards` mirrors the constructor
      # argument, so this may format as "group_size:0" or "group_size:None" --
      # confirm that the writer accepts those option strings.
      group_size = self.num_shards
    return array_record_module.ArrayRecordWriter(
        temp_path, f"group_size:{group_size}"
    )

  def close(self, file_handle):
    """Closes the writer returned by `open`."""
    file_handle.close()

  def write_encoded_record(self, file_handle, value):
    """Writes one already-encoded record through the writer."""
    file_handle.write(value)
class WriteToArrayRecord(beam.PTransform):
  """PTransform that writes a PCollection to ArrayRecord files on disk."""

  def __init__(
      self,
      file_path_prefix,
      file_name_suffix="",
      num_shards=0,
      shard_name_template=None,
      coder=beam.coders.coders.ToBytesCoder(),
      compression_type=beam.io.filesystem.CompressionTypes.AUTO,
      preserve_random_access: bool = False,
  ):
    """Builds the underlying `_ArrayRecordSink` from the given options.

    All arguments are passed through unchanged to `_ArrayRecordSink`.
    """
    self._sink = _ArrayRecordSink(
        file_path_prefix,
        file_name_suffix=file_name_suffix,
        num_shards=num_shards,
        shard_name_template=shard_name_template,
        coder=coder,
        compression_type=compression_type,
        preserve_random_access=preserve_random_access,
    )

  def expand(self, pcoll):
    """Writes `pcoll` through the configured sink."""
    write_transform = beam.io.iobase.Write(self._sink)
    return pcoll | write_transform
class WriteExampleArrayRecord(beam.PTransform):
  """Writes example dicts to ArrayRecord files as tf.Example protos."""

  def __init__(
      self,
      output_path: str,
      num_shards: Optional[int] = None,
      preserve_random_access: bool = False,
  ):
    """WriteExampleArrayRecord constructor.

    Args:
      output_path: string, path to the output ArrayRecord file (w/o shard
        suffix).
      num_shards: (optional) int, number of shards to output or None to use
        liquid sharding.
      preserve_random_access: Whether to preserve the random access of the
        written ArrayRecord. If true, group_size is set to 1; otherwise it is
        set to the number of shards.
    """
    self._output_path = output_path
    self._num_shards = num_shards
    self._preserve_random_access = preserve_random_access

  def expand(self, pcoll):
    """Converts each dict to a tf.Example, reshuffles, and writes it out."""
    writer = WriteToArrayRecord(
        self._output_path,
        num_shards=self._num_shards,
        coder=beam.coders.ProtoCoder(tf.train.Example),
        preserve_random_access=self._preserve_random_access,
    )
    tf_examples = pcoll | beam.Map(seqio.dict_to_tfexample)
    reshuffled = tf_examples | beam.Reshuffle()
    return reshuffled | writer
class WriteJson(beam.PTransform):
  """Serializes each element to JSON and writes it to a single text file."""

  def __init__(self, output_path: str, prettify: Optional[bool] = True):
    """WriteJson constructor.

    Args:
      output_path: string, path to the output JSON(L) file.
      prettify: bool, whether to write the outputs with sorted keys and
        indentation. Note this should not be used if there are multiple
        records being written to the file (JSONL).
    """
    self._output_path = output_path
    self._prettify = prettify

  def _jsonify(self, el):
    """Returns `el` as a JSON string, pretty-printed when configured."""
    if not self._prettify:
      return json.dumps(el)
    return json.dumps(el, sort_keys=True, indent=2)

  def expand(self, pcoll):
    """Writes the JSON form of every element to one unsharded file."""
    json_lines = pcoll | beam.Map(self._jsonify)
    # One shard with an empty template, so the output is written exactly at
    # `output_path` with no shard suffix.
    text_sink = beam.io.WriteToText(
        self._output_path, num_shards=1, shard_name_template=""
    )
    return json_lines | "write_info" >> text_sink
class GetInfo(beam.PTransform):
  """Computes info for dataset examples.

  Expects a single PCollection of examples and samples one of them. Returns a
  dictionary with the information needed to read the data (number of shards,
  feature shapes and types).
  """

  def __init__(self, num_shards: int, exclude_provenance: bool = True):
    """GetInfo constructor.

    Args:
      num_shards: int, reported verbatim under the "num_shards" key.
      exclude_provenance: if True, features whose name starts with
        `PROVENANCE_PREFIX` are left out of the info.
    """
    self._num_shards = num_shards
    self._exclude_provenance = exclude_provenance

  def _info_dict(self, ex: List[Dict[str, Any]]):
    """Builds the info dict from a sample of at most one example.

    Args:
      ex: list holding the sampled example; an empty list yields `{}`.

    Returns:
      A dict with "num_shards", "features" (name -> shape and dtype) and
      "seqio_version".
    """
    if not ex:
      return {}
    assert len(ex) == 1
    example = ex[0]

    features = {}
    for name, value in example.items():
      if self._exclude_provenance and name.startswith(PROVENANCE_PREFIX):
        continue
      # Ragged tensors are used as-is; anything else is wrapped in a constant
      # so that dtype and shape can be read off uniformly.
      tensor = value if isinstance(value, tf.RaggedTensor) else tf.constant(value)
      dtype_name = tensor.dtype.name
      dims = tensor.shape.as_list()
      # Keep all the dimensions but the first if the tensor is not a scalar.
      if dims:
        dims = [None] + dims[1:]
      features[name] = {"shape": dims, "dtype": dtype_name}

    return {
        "num_shards": self._num_shards,
        "features": features,
        "seqio_version": seqio.__version__,
    }

  def expand(self, pcoll):
    """Samples one example globally and maps it to the info dict."""
    sampled = pcoll | beam.combiners.Sample.FixedSizeGlobally(1)
    return sampled | beam.Map(self._info_dict)
class _CountTokens(beam.DoFn):
  """Emits a token count for each integer feature of an example."""

  def __init__(self, output_features: Mapping[str, seqio.Feature]):
    self._output_features = output_features

  def setup(self):
    # Certain vocabularies are lazy loaded. Reading the special ids here, in
    # the setup phase, aims to trigger that loading only once under Beam.
    for feature in self._output_features.values():
      vocabulary = feature.vocabulary
      _ = (vocabulary.eos_id, vocabulary.unk_id, vocabulary.pad_id)

  def process(self, ex: Mapping[str, Any]) -> Iterable[Tuple[str, int]]:
    """Yields ("<name>_tokens", count) for each countable feature in `ex`.

    A feature is counted when it is present in `ex` as an int32/int64
    `np.ndarray` or `tf.Tensor`. EOS and PAD ids (when the vocabulary defines
    them) are not counted as tokens.
    """
    for name, feature in self._output_features.items():
      if name not in ex:
        continue
      raw = ex[name]
      if not isinstance(raw, (np.ndarray, tf.Tensor)):
        continue
      if raw.dtype not in (np.int32, np.int64):
        continue

      values = raw.numpy() if isinstance(raw, tf.Tensor) else raw

      vocabulary = feature.vocabulary
      masks = []
      for special_id in (vocabulary.eos_id, vocabulary.pad_id):
        if special_id is not None:
          masks.append(values != special_id)

      if masks:
        valid_tokens = functools.reduce(operator.and_, masks)
      else:
        # Without special ids, every value is assumed to be a valid token.
        valid_tokens = np.ones_like(values, dtype=bool)

      yield (f"{name}_tokens", int(np.sum(valid_tokens)))


class _CountCharacters(beam.DoFn):
  """Emits a character count for each feature of an example.

  This works with both tokenized (integer array) datasets and string datasets.
  For the former, each feature is detokenized and the length of the resulting
  string is computed. For the latter, the bytes feature is decoded and the
  length of the decoded string is returned.

  Example 1 (tokenized dataset):

  ```python
  Assume that these examples are generated with a vocab that decodes "ea" as
  [4, 5], etc.

  input_examples = [{
      # Decoded as "ea", i.e., length 2 string
      "inputs": np.array([4, 5]),
      # Decoded as "ea test", i.e., length 7 string
      "targets": np.array([4, 5, 10]),
  }, {
      # Decoded as "e", i.e., length 1 string
      "inputs": np.array([4]),
      # Decoded as "asoil", i.e., length 5 string. "1" is an EOS id.
      "targets": np.array([5, 6, 7, 8, 9, 1])
  }]

  This `DoFn` returns (yields each of the 4 elements in sequence):
  [("inputs_chars", 2), ("targets_chars", 7),
   ("inputs_chars", 1), ("targets_chars", 5)]
  ```

  Example 2 (string dataset):

  ```python
  input_examples = [{
      "text": b"this is a string of length 29"
  }, {
      "text": b"this is another string of length 35"
  }]

  This `DoFn` returns (yields each of the 2 elements in sequence):
  [("text_chars", 29), ("text_chars", 35)]
  ```
  """

  def __init__(self, output_features: Mapping[str, seqio.Feature]):
    self._output_features = output_features

  def setup(self):
    # Certain vocabularies are lazy loaded. Reading the special ids here, in
    # the setup phase, aims to trigger that loading only once under Beam.
    for feature in self._output_features.values():
      vocabulary = feature.vocabulary
      _ = (vocabulary.eos_id, vocabulary.unk_id, vocabulary.pad_id)

  def process(self, ex: Mapping[str, Any]) -> Iterable[Tuple[str, int]]:
    """Yields ("<name>_chars", length) for each supported feature in `ex`."""
    for name, feature in self._output_features.items():
      if name not in ex:
        continue
      raw = ex[name]

      # Token ids are only detokenized for rank-1 int32/int64 arrays whose
      # feature uses a `seqio.SentencePieceVocabulary`.
      is_sentencepiece_ids = (
          isinstance(raw, np.ndarray)
          and raw.dtype in (np.int32, np.int64)
          and feature.rank == 1
          and isinstance(feature.vocabulary, seqio.SentencePieceVocabulary)
      )
      if is_sentencepiece_ids:
        # Ids are cast to int32 and replaced by their absolute value before
        # being decoded.
        token_ids = np.abs(raw.astype(np.int32))
        text = feature.vocabulary.decode_tf(token_ids).numpy().decode("utf-8")
      elif isinstance(raw, bytes):
        # tf.string features become `bytes` inside `ds.as_numpy_iterator()`;
        # this `DoFn` is assumed to run on examples from such an iterator.
        text = raw.decode("utf-8")
      else:
        continue

      yield (f"{name}_chars", len(text))
class GetStats(beam.PTransform):
  """Computes statistics for dataset examples.

  The `expand` method expects a PCollection of examples where each example is
  a dictionary of string identifiers (e.g. "inputs" and "targets") mapped to
  numpy arrays. Returns a dictionary with statistics (number of examples,
  total and maximum number of tokens, and optionally character counts and
  task ids) keyed by names derived from those identifiers.
  """

  def __init__(
      self,
      output_features: Mapping[str, seqio.Feature],
      task_ids: Optional[Mapping[str, Any]] = None,
      enable_char_counts: bool = False,
  ):
    """GetStats constructor.

    Args:
      output_features: mapping from feature name to `seqio.Feature`; handed to
        the token and character counters.
      task_ids: (Optional) mapping whose values are cast to int and reported
        under the "task_ids" key; omitted from the stats when empty.
      enable_char_counts: if True, per-feature character totals are added.
    """
    self._output_features = output_features
    self._task_ids = task_ids or {}
    self._enable_char_counts = enable_char_counts
    logging.info("Getting stats for output features: %s", str(output_features))

  def expand(self, pcoll):
    """Builds the partial stat dicts and merges them into a single dict."""

    def _merge_dicts(dicts):
      # The partial dicts must not share any key.
      combined = {}
      for partial in dicts:
        assert combined.keys().isdisjoint(partial)
        combined.update(partial)
      return combined

    example_counts = (
        pcoll
        | "count_examples" >> beam.combiners.Count.Globally()
        | "key_example_counts" >> beam.Map(lambda count: ("examples", count))
        | "example_count_dict" >> beam.combiners.ToDict()
    )

    token_counter = _CountTokens(self._output_features)
    token_counts = pcoll | "count_tokens" >> beam.ParDo(token_counter)

    total_tokens = (
        token_counts
        | "sum_tokens" >> beam.CombinePerKey(sum)
        | "token_count_dict" >> beam.combiners.ToDict()
    )
    max_tokens = (
        token_counts
        | "max_tokens" >> beam.CombinePerKey(max)
        | "rename_max_stat"
        >> beam.Map(lambda kv: (kv[0].replace("tokens", "max_tokens"), kv[1]))
        | "token_max_dict" >> beam.combiners.ToDict()
    )

    stats = [example_counts, total_tokens, max_tokens]

    if self._enable_char_counts:
      stats.append(
          pcoll
          | beam.ParDo(_CountCharacters(self._output_features))
          | "sum_characters" >> beam.CombinePerKey(sum)
          | "character_length_dict" >> beam.combiners.ToDict()
      )

    if self._task_ids:
      # ids could be Tensors, cast to int.
      self._task_ids = {
          key: int(value) for key, value in self._task_ids.items()
      }
      task_ids_dict = {"task_ids": self._task_ids}
      # A one-element sample is used only to emit the task ids dict once.
      stats.append(
          pcoll
          | "sample_for_task_ids" >> beam.combiners.Sample.FixedSizeGlobally(1)
          | "create_task_ids" >> beam.Map(lambda _: task_ids_dict)
      )

    flattened = stats | "flatten_counts" >> beam.Flatten()
    return flattened | "merge_stats" >> beam.CombineGlobally(_merge_dicts)