Improve support for loading local datasets (#33)
* Handle loading local datasets * Reorder branches to avoid chain of negatives
This commit is contained in:
+36
-2
@@ -2,8 +2,10 @@
|
|||||||
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
# Copyright (C) 2025 Philipp Emanuel Weidmann <pew@worldwidemann.com>
|
||||||
|
|
||||||
import gc
|
import gc
|
||||||
|
import os
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
from pathlib import Path
|
||||||
from typing import TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -13,7 +15,10 @@ from accelerate.utils import (
|
|||||||
is_sdaa_available,
|
is_sdaa_available,
|
||||||
is_xpu_available,
|
is_xpu_available,
|
||||||
)
|
)
|
||||||
from datasets import load_dataset
|
from datasets import ReadInstruction, load_dataset, load_from_disk
|
||||||
|
from datasets.config import DATASET_STATE_JSON_FILENAME
|
||||||
|
from datasets.download.download_manager import DownloadMode
|
||||||
|
from datasets.utils.info_utils import VerificationMode
|
||||||
from optuna import Trial
|
from optuna import Trial
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
|
|
||||||
@@ -36,7 +41,36 @@ def format_duration(seconds: float) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
def load_prompts(specification: DatasetSpecification) -> list[str]:
|
||||||
dataset = load_dataset(specification.dataset, split=specification.split)
|
path = specification.dataset
|
||||||
|
split_str = specification.split
|
||||||
|
if os.path.isdir(path):
|
||||||
|
if Path(path, DATASET_STATE_JSON_FILENAME).exists():
|
||||||
|
# Dataset saved with datasets.save_to_disk; needs special handling.
|
||||||
|
# Path should be the subdirectory for a particular split.
|
||||||
|
dataset = load_from_disk(path)
|
||||||
|
# Parse the split instructions.
|
||||||
|
ri = ReadInstruction.from_spec(split_str)
|
||||||
|
# Associate the split with its number of examples (lines).
|
||||||
|
split_name = str(dataset.split)
|
||||||
|
name2len = {split_name: len(dataset)}
|
||||||
|
# Convert the instructions to absolute indices and select the first one.
|
||||||
|
abs_i = ri.to_absolute(name2len)[0]
|
||||||
|
# Get the dataset by applying the indices.
|
||||||
|
dataset = dataset[abs_i.from_ : abs_i.to]
|
||||||
|
else:
|
||||||
|
# Path is a local directory.
|
||||||
|
dataset = load_dataset(
|
||||||
|
path,
|
||||||
|
split=split_str,
|
||||||
|
# Don't require the number of examples (lines) per split to be pre-defined.
|
||||||
|
verification_mode=VerificationMode.NO_CHECKS,
|
||||||
|
# But also don't use cached data, as the dataset may have changed on disk.
|
||||||
|
download_mode=DownloadMode.FORCE_REDOWNLOAD,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Probably a repository path; let load_dataset figure it out.
|
||||||
|
dataset = load_dataset(path, split=split_str)
|
||||||
|
|
||||||
return list(dataset[specification.column])
|
return list(dataset[specification.column])
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user