feat: allow injecting prefixes and suffixes into prompts
This commit is contained in:
@@ -27,6 +27,16 @@ class DatasetSpecification(BaseModel):
|
|||||||
|
|
||||||
column: str = Field(description="Column in the dataset that contains the prompts.")
|
column: str = Field(description="Column in the dataset that contains the prompts.")
|
||||||
|
|
||||||
|
prefix: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Text to prepend to each prompt.",
|
||||||
|
)
|
||||||
|
|
||||||
|
suffix: str = Field(
|
||||||
|
default="",
|
||||||
|
description="Text to append to each prompt.",
|
||||||
|
)
|
||||||
|
|
||||||
residual_plot_label: str | None = Field(
|
residual_plot_label: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Label to use for the dataset in plots of residual vectors.",
|
description="Label to use for the dataset in plots of residual vectors.",
|
||||||
|
|||||||
@@ -171,7 +171,15 @@ def load_prompts(specification: DatasetSpecification) -> list[str]:
|
|||||||
# Probably a repository path; let load_dataset figure it out.
|
# Probably a repository path; let load_dataset figure it out.
|
||||||
dataset = load_dataset(path, split=split_str)
|
dataset = load_dataset(path, split=split_str)
|
||||||
|
|
||||||
return list(dataset[specification.column])
|
prompts = list(dataset[specification.column])
|
||||||
|
|
||||||
|
if specification.prefix:
|
||||||
|
prompts = [f"{specification.prefix} {prompt}" for prompt in prompts]
|
||||||
|
|
||||||
|
if specification.suffix:
|
||||||
|
prompts = [f"{prompt} {specification.suffix}" for prompt in prompts]
|
||||||
|
|
||||||
|
return prompts
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|||||||
Reference in New Issue
Block a user