Improve support for loading local datasets (#33)

* Handle loading local datasets

* Reorder branches to avoid chain of negatives
This commit is contained in:
Spiky Moth
2025-11-23 06:45:34 +01:00
committed by GitHub
parent 83cbf0612a
commit b79b8b1475
+36 -2
View File
@@ -2,8 +2,10 @@
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import gc
import os
from dataclasses import asdict
from importlib.metadata import version
from pathlib import Path
from typing import TypeVar
import torch
@@ -13,7 +15,10 @@ from accelerate.utils import (
is_sdaa_available,
is_xpu_available,
)
from datasets import load_dataset
from datasets import ReadInstruction, load_dataset, load_from_disk
from datasets.config import DATASET_STATE_JSON_FILENAME
from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode
from optuna import Trial
from rich.console import Console
@@ -36,7 +41,36 @@ def format_duration(seconds: float) -> str:
def load_prompts(specification: DatasetSpecification) -> list[str]:
dataset = load_dataset(specification.dataset, split=specification.split)
path = specification.dataset
split_str = specification.split
if os.path.isdir(path):
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split.
dataset = load_from_disk(path)
# Parse the split instructions.
ri = ReadInstruction.from_spec(split_str)
# Associate the split with its number of examples (lines).
split_name = str(dataset.split)
name2len = {split_name: len(dataset)}
# Convert the instructions to absolute indices and select the first one.
abs_i = ri.to_absolute(name2len)[0]
# Get the dataset by applying the indices.
dataset = dataset[abs_i.from_ : abs_i.to]
else:
# Path is a local directory.
dataset = load_dataset(
path,
split=split_str,
# Don't require the number of examples (lines) per split to be pre-defined.
verification_mode=VerificationMode.NO_CHECKS,
# But also don't use cached data, as the dataset may have changed on disk.
download_mode=DownloadMode.FORCE_REDOWNLOAD,
)
else:
# Probably a repository path; let load_dataset figure it out.
dataset = load_dataset(path, split=split_str)
return list(dataset[specification.column])