diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 8dc5e29..8da2c57 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -17,7 +17,7 @@ from accelerate.utils import ( is_sdaa_available, is_xpu_available, ) -from datasets import ReadInstruction, load_dataset, load_from_disk +from datasets import DatasetDict, ReadInstruction, load_dataset, load_from_disk from datasets.config import DATASET_STATE_JSON_FILENAME from datasets.download.download_manager import DownloadMode from datasets.utils.info_utils import VerificationMode @@ -145,6 +145,9 @@ def load_prompts(specification: DatasetSpecification) -> list[str]: # Dataset saved with datasets.save_to_disk; needs special handling. # Path should be the subdirectory for a particular split. dataset = load_from_disk(path) + assert not isinstance(dataset, DatasetDict), ( + "Loading dataset dicts is not supported" + ) # Parse the split instructions. instruction = ReadInstruction.from_spec(split_str) # Associate the split with its number of examples (lines).