From b79b8b14756f9cdb4c922e834e661b6b4535e099 Mon Sep 17 00:00:00 2001
From: Spiky Moth
Date: Sun, 23 Nov 2025 06:45:34 +0100
Subject: [PATCH] Improve support for loading local datasets (#33)

* Handle loading local datasets

* Reorder branches to avoid chain of negatives
---
 src/heretic/utils.py | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 53 insertions(+), 2 deletions(-)

diff --git a/src/heretic/utils.py b/src/heretic/utils.py
index c0cff59..7d0f6b3 100644
--- a/src/heretic/utils.py
+++ b/src/heretic/utils.py
@@ -2,8 +2,10 @@
 # Copyright (C) 2025 Philipp Emanuel Weidmann
 
 import gc
+import os
 from dataclasses import asdict
 from importlib.metadata import version
+from pathlib import Path
 from typing import TypeVar
 
 import torch
@@ -13,7 +15,10 @@ from accelerate.utils import (
     is_sdaa_available,
     is_xpu_available,
 )
-from datasets import load_dataset
+from datasets import ReadInstruction, load_dataset, load_from_disk
+from datasets.config import DATASET_STATE_JSON_FILENAME
+from datasets.download.download_manager import DownloadMode
+from datasets.utils.info_utils import VerificationMode
 from optuna import Trial
 from rich.console import Console
 
@@ -36,7 +41,53 @@ def format_duration(seconds: float) -> str:
 
 
 def load_prompts(specification: DatasetSpecification) -> list[str]:
-    dataset = load_dataset(specification.dataset, split=specification.split)
+    """Return the values of ``specification.column`` from the requested split.
+
+    ``specification.dataset`` is handled in one of three ways:
+
+    * a directory containing the ``datasets`` state file is read with
+      ``load_from_disk`` and ``specification.split`` is applied manually;
+    * any other existing directory is loaded with ``load_dataset``, with
+      verification checks disabled and a forced re-download;
+    * anything else is passed to ``load_dataset`` unchanged.
+    """
+    path = specification.dataset
+    split_str = specification.split
+    if os.path.isdir(path):
+        if Path(path, DATASET_STATE_JSON_FILENAME).exists():
+            # Dataset saved with datasets.save_to_disk; needs special handling.
+            # Path should be the subdirectory for a particular split.
+            dataset = load_from_disk(path)
+            # Parse the split instructions.
+            instruction = ReadInstruction.from_spec(split_str)
+            # Associate the split with its number of examples (lines).
+            # NOTE(review): if the saved dataset carries no split name this is
+            # "None" and will not match the name in the specification - confirm.
+            split_name = str(dataset.split)
+            name2len = {split_name: len(dataset)}
+            # Convert the instructions to absolute index ranges. A specification
+            # joined with "+" yields one range per term, so apply every range
+            # instead of silently dropping all but the first.
+            indices = [
+                index
+                for absolute in instruction.to_absolute(name2len)
+                for index in range(absolute.from_, absolute.to)
+            ]
+            dataset = dataset.select(indices)
+        else:
+            # Path is a local directory.
+            dataset = load_dataset(
+                path,
+                split=split_str,
+                # Don't require the number of examples (lines) per split to be pre-defined.
+                verification_mode=VerificationMode.NO_CHECKS,
+                # But also don't use cached data, as the dataset may have changed on disk.
+                download_mode=DownloadMode.FORCE_REDOWNLOAD,
+            )
+    else:
+        # Probably a repository path; let load_dataset figure it out.
+        dataset = load_dataset(path, split=split_str)
+
     return list(dataset[specification.column])
 
 