Pipeline with custom dataset tokenizer: when to save/load manually

I am trying my hand at the datasets library and I am not sure that I understand the flow.

Let’s assume that I have a single file that is a pickled dict. In that dict, I have two keys that each contain a list of datapoints. One of them is text and the other one is a sentence embedding (yeah, working on a strange project…).
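For context, the dict looks roughly like this (the values are made up; the real embeddings are fixed-size float vectors):

data = {
    "text": ["first sentence", "second sentence"],
    "sembedding": [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],  # one embedding per sentence
}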

I know that I can create a dataset from this file as follows:

import torch
from datasets import Dataset
from transformers import AutoTokenizer

dataset = Dataset.from_dict(torch.load("data.pt"))
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
keys_to_retain = {"input_ids", "sembedding"}
dataset = dataset.map(lambda example: tokenizer(example["text"], padding="max_length", truncation=True), batched=True)
# drop everything except input_ids and sembedding
dataset = dataset.remove_columns(list(set(dataset.column_names) - keys_to_retain))
dataset.set_format(type="torch", columns=["input_ids", "sembedding"])
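As a sanity check, I would expect something like this to work at that point (untested):

print(dataset)  # shows the remaining columns and the number of rows
print(dataset[0]["input_ids"].shape)  # a torch tensor, thanks to set_format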

My question is: what comes next, especially considering how caching works? The first thing that should happen is splitting the dataset into train, dev, and test sets, so that I eventually end up with a dictionary with train, dev, and test keys. I can then use those in dataloaders and I am ready to go.
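Concretely, I imagine the splitting step looking something like this (untested; the seed, batch size, and 60/20/20 ratios are placeholders I made up):

splits = dataset.train_test_split(test_size=0.2, seed=42)
# split the remaining 80% again: 0.25 of 80% gives 60/20/20 overall
train_dev = splits["train"].train_test_split(test_size=0.25, seed=42)
datasets = {"train": train_dev["train"], "dev": train_dev["test"], "test": splits["test"]}

from torch.utils.data import DataLoader
train_loader = DataLoader(datasets["train"], batch_size=16, shuffle=True)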

The question is: what about subsequent runs (i.e. new Python sessions)? Will all of that code need to be run again? Should I call dataset.save_to_disk, and in a later session skip the whole dataset creation? In other words, do I have to manually check for the saved files? Something like this (untested):

from pathlib import Path
from datasets import load_from_disk

def create_datasets(dataset_path):
    dataset_path = Path(dataset_path)
    if dataset_path.exists():
        # reload the previously saved splits
        datasets = {partition: load_from_disk(dataset_path / partition) for partition in ["train", "dev", "test"]}
    else:
        # the snippet that I posted above
        # assuming we end up with train, dev, test in `datasets`
        for key, dataset in datasets.items():
            dataset.save_to_disk(dataset_path / key)

    return datasets

Or is the dataset cached somewhere, so that whenever the first snippet is run again, none of those steps is repeated and the cached dataset is loaded instead?

In short, it is not clear to me when I can rely on the hidden cache (probably somewhere in the user directory), and when I should save with save_to_disk and load the dataset back manually.
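For reference, this is the hidden cache I mean, if I understand the docs correctly (I have not verified this in my exact setup):

from datasets import config
print(config.HF_DATASETS_CACHE)  # default cache directory, e.g. ~/.cache/huggingface/datasets
print(dataset.cache_files)  # the Arrow cache files (if any) backing this dataset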

Thanks!