28 changes: 28 additions & 0 deletions datatools/io_utils.py
@@ -1,6 +1,7 @@
import json
import os
import io
import random

from typing import Any, Dict, Union, List, Optional
from collections.abc import Sequence
@@ -62,6 +63,11 @@ def shard(cls, dataset: Array, shard_id: int, num_shards: int):
shard_indices = np.linspace(0, N, num_shards + 1)

return cls(dataset, range(int(shard_indices[shard_id]), int(shard_indices[shard_id + 1])))

@classmethod
def sample(cls, dataset: Array, num_samples: int, seed: int = 0):
rng = random.Random(seed)
return cls(dataset, rng.sample(range(len(dataset)), num_samples))

def __len__(self) -> int:
return len(self.indices)
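For reference, a minimal, self-contained sketch of how the new sample classmethod behaves; the class it is defined on is not visible in this hunk, so the _View stand-in below (and its constructor) are purely illustrative:

import random

# Illustrative stand-in for the index-view class that defines shard()/sample().
class _View:
    def __init__(self, dataset, indices):
        self.dataset, self.indices = dataset, indices

    @classmethod
    def sample(cls, dataset, num_samples, seed=0):
        rng = random.Random(seed)  # seeded RNG -> same subset for the same seed
        return cls(dataset, rng.sample(range(len(dataset)), num_samples))

data = list(range(10_000))
# The same seed reproduces the same sampled indices.
assert _View.sample(data, 100, seed=7).indices == _View.sample(data, 100, seed=7).indices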
@@ -190,6 +196,28 @@ def get_item(self, idx: int) -> Dict[str, Any]:
return json.loads(self.lines[idx])


class CSVDataset(Array):
def __init__(self, paths: List[Union[UPath, str]], is_tsv: bool = False):
self.paths = paths
self.is_tsv = is_tsv

dfs = []
for path in paths:
dfs.append(pd.read_csv(path, sep='\t' if is_tsv else ','))

self.df = pd.concat(dfs)

def __len__(self) -> int:
return len(self.df)

@property
def size(self) -> int:
return len(self.df)

def get_item(self, idx: int) -> Dict[str, Any]:
return self.df.iloc[idx].to_dict()
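A short usage sketch for the new CSVDataset; the file names below are placeholders, and it assumes pandas is imported as pd elsewhere in io_utils.py (the import is not part of this diff):

# Hypothetical files, shown only to illustrate the constructor arguments.
ds = CSVDataset(["train_part1.csv", "train_part2.csv"])  # comma-separated
tsv = CSVDataset(["labels.tsv"], is_tsv=True)             # tab-separated

len(ds)          # total rows across the concatenated DataFrames
ds.get_item(0)   # first row as a plain dict (via DataFrame.iloc)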


class PyArrowDataset(Array):
"""PyArrow-based dataset that supports parquet and arrow files with local and S3 paths."""

10 changes: 9 additions & 1 deletion datatools/load.py
@@ -7,7 +7,7 @@
from upath import UPath
import glob

from datatools.io_utils import LocalDatasets, JsonlDataset, is_remote_path, has_compressed_mds_files, RemoteDatasets, PyArrowDataset
from datatools.io_utils import LocalDatasets, JsonlDataset, is_remote_path, has_compressed_mds_files, RemoteDatasets, PyArrowDataset, CSVDataset

def _expand_glob_patterns(input_paths: List[Union[UPath, str]]) -> List[UPath]:
"""Expand glob patterns in input paths for both local and remote paths."""
@@ -103,6 +103,10 @@ def load(*input_paths: List[Union[UPath, str]], options: Optional[LoadOptions] =
if suffix in [".arrow", ".parquet", ".npy", ".jsonl"]:
input_type = suffix[1:]
break
# Attempt to load json as jsonl
if suffix == ".json":
input_type = "jsonl"
break

if input_type == "mosaic":
if any(is_remote_path(path) or has_compressed_mds_files(path) for path in input_paths):
@@ -111,6 +115,10 @@ def load(*input_paths: List[Union[UPath, str]], options: Optional[LoadOptions] =
return LocalDatasets(input_paths)
elif input_type == "jsonl":
return JsonlDataset(input_paths)
elif input_type == "csv":
return CSVDataset(input_paths)
elif input_type == "tsv":
return CSVDataset(input_paths, is_tsv=True)
elif input_type == "npy":
return np.concatenate([np.load(str(path)) for path in input_paths])
elif input_type in {"parquet", "arrow"}:
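A hedged sketch of how the new dispatch might be exercised through load(); it assumes the .csv and .tsv suffixes set input_type in the same way as the other suffixes (that part of the check is not visible in this hunk), and the paths are placeholders:

from datatools.load import load

csv_data = load("data/metrics.csv")     # -> CSVDataset(paths)
tsv_data = load("data/metrics.tsv")     # -> CSVDataset(paths, is_tsv=True)
json_data = load("data/records.json")   # .json now falls back to JsonlDataset

row = csv_data.get_item(0)              # dict of column -> value for the first row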
2 changes: 1 addition & 1 deletion datatools/process.py
@@ -182,7 +182,7 @@ def load_indices(options):

if options.index_range is not None:
logger.info(f"Using indices from {options.index_range[0]} to {options.index_range[1]}")
indices = range(*options.index_range)
indices = np.arange(*options.index_range)

return indices
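The switch from range to np.arange matters because the returned indices may later be indexed with an array (for example the permutation applied in pack.py's shuffle path), which a plain range does not support. A minimal illustration with made-up values:

import numpy as np

old = range(100, 200)        # previous behaviour
new = np.arange(100, 200)    # current behaviour

perm = np.random.permutation(len(new))
new[perm]                    # fine: ndarray supports fancy indexing
# old[perm]                  # TypeError: range indices must be integers or slices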

22 changes: 12 additions & 10 deletions datatools/scripts/pack.py
@@ -210,7 +210,7 @@ def pack_fn(data: Array,
field: add_special_tokens(np.array(item[field], dtype=np.uint32), options, bos=True, eos=True)
for field in options.other_fields
}

if options.split_by_lengths:
while len(input_ids) >= sorted_lengths[-1]:
# From longest to shortest
@@ -219,42 +219,42 @@
if len(input_ids) >= target_len)

target_subset = subset / f"{target_len}-{options.pack_length}"

other_iterators = {
field: iter(list(other_buffers[field][target_len].process(seq[:target_len])))
for field, seq in other_seqs.items()
}
for item in buffers[target_len].process(input_ids[:target_len]):
if options.domain_field:
item.update({options.domain_field: str(target_subset)})

for field, iterator in other_iterators.items():
item[field] = next(iterator)[field]
assert len(item[field]) == len(item[options.token_field])

yield target_subset, item

if options.intact:
break

input_ids = add_special_tokens(input_ids[target_len - options.overlap:], options, mos=True)

for field, iterator in other_iterators.items():
other_seqs[field] = add_special_tokens(other_seqs[field][target_len - options.overlap:], options, mos=True)
else:
other_iterators = {
field: iter(list(other_buffers[field][subset].process(seq)))
for field, seq in other_seqs.items()
}

for item in buffers[subset].process(input_ids):
if options.domain_field:
item.update({options.domain_field: str(subset)})

for field, iterator in other_iterators.items():
item[field] = next(iterator)[field]
assert len(item[field]) == len(item[options.token_field])

yield subset, item


@@ -267,7 +267,7 @@ def main():
parser.add_arguments(PackOptions, dest="pack_options")
parser.add_arguments(LoadOptions, dest="load_options")
parser.add_arguments(ProcessOptions, dest="process_options")

parser.add_argument("-x", "--shuffle", action="store_true", help="Shuffle the dataset")
parser.add_argument("--seed", type=int, default=42, help="Shuffle seed")

@@ -277,13 +277,15 @@
dataset = load(*args.inputs, options=args.load_options)
N = len(dataset)
print(f"Loaded dataset with {N} samples")

if args.shuffle:
indices = load_indices(args.process_options)
if indices is None:
indices = np.arange(N)
np.random.seed(args.seed)
args.process_options.indices = indices[np.random.permutation(len(indices))]
args.process_options.index_path = None
args.process_options.index_range = None

process(dataset,
partial(pack_fn, options=args.pack_options),
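For context, the --shuffle branch above boils down to: build (or load) the index array, apply a seeded permutation, and clear index_path/index_range so the permuted indices are what process() actually uses. A standalone sketch with illustrative values:

import numpy as np

N = 1_000                     # stand-in for len(dataset)
indices = np.arange(N)        # or whatever load_indices() returned
np.random.seed(42)            # --seed
shuffled = indices[np.random.permutation(len(indices))]
# `shuffled` then replaces process_options.indices; the same seed reproduces it.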
13 changes: 9 additions & 4 deletions datatools/scripts/tokenize.py
@@ -54,7 +54,7 @@ def load_tokenizer_encoder(options: TokenizeOptions):
from datatools.scripts.tokenizers.llama3_tokenizer import Tokenizer
tokenizer = Tokenizer(str(Path(__file__).parent / "tokenizers" / "llama3_tokenizer.model"))
from datatools.scripts.tokenizers.llama3_tokenizer import ChatFormat

if options.chat_template:
chat_format = ChatFormat(tokenizer)
def encode_fn(item):
@@ -84,7 +84,7 @@ def encode_fn(item):
return tokens
return encode_fn




def tokenize_fn(data: Array,
@@ -96,7 +96,7 @@
for i in tqdm(range(len(data)), desc=f"Process {process_id}"):
item = data[i]
domain = item[options.domain_by] if options.domain_by is not None else options.domain

if options.chat_template and options.chat_assistant_masking:
tokens, masks = encode_fn(item)

@@ -109,7 +109,12 @@
output_item = {
options.token_field: np.array(tokens, dtype=np.uint32),
}


if len(tokens) == 0:
            # Writing a zero-length array makes MDS raise an error, which would cause the rest of the data to be dropped.
            # Instead, write a single dummy token (0) and let downstream consumers filter it out.
output_item[options.token_field] = np.array([0], dtype=np.uint32)

if options.length_field:
output_item[options.length_field] = len(tokens)
if options.domain_field:
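Since empty tokenizations are now written as a single dummy token (0), downstream consumers are expected to drop them. A hedged sketch of such a filter; the token field name is illustrative, and it assumes no legitimate sample consists of exactly one token with id 0:

import numpy as np

def is_empty_placeholder(item, token_field="input_ids"):
    # Matches the dummy written above: exactly one token with id 0.
    tokens = np.asarray(item[token_field])
    return tokens.size == 1 and tokens[0] == 0

items = [
    {"input_ids": np.array([0], dtype=np.uint32)},        # placeholder row
    {"input_ids": np.array([5, 7, 9], dtype=np.uint32)},  # real sample
]
kept = [it for it in items if not is_empty_placeholder(it)]  # drops the placeholder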