Improve support for loading local datasets (#33)

* Handle loading local datasets

* Reorder branches to avoid chain of negatives
This commit is contained in:
Spiky Moth
2025-11-23 06:45:34 +01:00
committed by GitHub
parent 83cbf0612a
commit b79b8b1475
+36 -2
View File
@@ -2,8 +2,10 @@
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com> # Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
import gc import gc
import os
from dataclasses import asdict from dataclasses import asdict
from importlib.metadata import version from importlib.metadata import version
from pathlib import Path
from typing import TypeVar from typing import TypeVar
import torch import torch
@@ -13,7 +15,10 @@ from accelerate.utils import (
is_sdaa_available, is_sdaa_available,
is_xpu_available, is_xpu_available,
) )
from datasets import load_dataset from datasets import ReadInstruction, load_dataset, load_from_disk
from datasets.config import DATASET_STATE_JSON_FILENAME
from datasets.download.download_manager import DownloadMode
from datasets.utils.info_utils import VerificationMode
from optuna import Trial from optuna import Trial
from rich.console import Console from rich.console import Console
@@ -36,7 +41,36 @@ def format_duration(seconds: float) -> str:
def load_prompts(specification: DatasetSpecification) -> list[str]: def load_prompts(specification: DatasetSpecification) -> list[str]:
dataset = load_dataset(specification.dataset, split=specification.split) path = specification.dataset
split_str = specification.split
if os.path.isdir(path):
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
# Dataset saved with datasets.save_to_disk; needs special handling.
# Path should be the subdirectory for a particular split.
dataset = load_from_disk(path)
# Parse the split instructions.
ri = ReadInstruction.from_spec(split_str)
# Associate the split with its number of examples (lines).
split_name = str(dataset.split)
name2len = {split_name: len(dataset)}
# Convert the instructions to absolute indices and select the first one.
abs_i = ri.to_absolute(name2len)[0]
# Get the dataset by applying the indices.
dataset = dataset[abs_i.from_ : abs_i.to]
else:
# Path is a local directory.
dataset = load_dataset(
path,
split=split_str,
# Don't require the number of examples (lines) per split to be pre-defined.
verification_mode=VerificationMode.NO_CHECKS,
# But also don't use cached data, as the dataset may have changed on disk.
download_mode=DownloadMode.FORCE_REDOWNLOAD,
)
else:
# Probably a repository path; let load_dataset figure it out.
dataset = load_dataset(path, split=split_str)
return list(dataset[specification.column]) return list(dataset[specification.column])