diff --git a/src/heretic/config.py b/src/heretic/config.py index fd2cf62..33c4976 100644 --- a/src/heretic/config.py +++ b/src/heretic/config.py @@ -27,6 +27,16 @@ class DatasetSpecification(BaseModel): column: str = Field(description="Column in the dataset that contains the prompts.") + prefix: str = Field( + default="", + description="Text to prepend to each prompt.", + ) + + suffix: str = Field( + default="", + description="Text to append to each prompt.", + ) + residual_plot_label: str | None = Field( default=None, description="Label to use for the dataset in plots of residual vectors.", diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 8da2c57..e350293 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -171,7 +171,15 @@ def load_prompts(specification: DatasetSpecification) -> list[str]: # Probably a repository path; let load_dataset figure it out. dataset = load_dataset(path, split=split_str) - return list(dataset[specification.column]) + prompts = list(dataset[specification.column]) + + if specification.prefix: + prompts = [f"{specification.prefix} {prompt}" for prompt in prompts] + + if specification.suffix: + prompts = [f"{prompt} {specification.suffix}" for prompt in prompts] + + return prompts T = TypeVar("T")