Add chat functionality
This commit is contained in:
+54
-10
@@ -222,14 +222,58 @@ def main():
|
|||||||
)
|
)
|
||||||
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
print(f" * Score: [bold]{-study.best_value:.4f}[/]")
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
print()
|
print()
|
||||||
action = questionary.select(
|
print("Restoring best model...")
|
||||||
"What do you want to do with the optimized model?",
|
print("* Reloading model...")
|
||||||
choices=[
|
model.reload_model()
|
||||||
"Save to a local folder",
|
print("* Abliterating...")
|
||||||
"Upload to Hugging Face",
|
model.abliterate(
|
||||||
"Nothing (discard the model)",
|
refusal_directions,
|
||||||
],
|
study.best_params["max_weight"],
|
||||||
).ask()
|
study.best_params["max_weight_position"],
|
||||||
|
study.best_params["min_weight"],
|
||||||
|
study.best_params["min_weight_distance"],
|
||||||
|
)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
print()
|
||||||
|
action = questionary.select(
|
||||||
|
"What do you want to do with the optimized model?",
|
||||||
|
choices=[
|
||||||
|
"Save the model to a local folder",
|
||||||
|
"Upload the model to Hugging Face",
|
||||||
|
"Chat with the model",
|
||||||
|
"Nothing (Quit)",
|
||||||
|
],
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
match action:
|
||||||
|
case "Save the model to a local folder":
|
||||||
|
# TODO
|
||||||
|
pass
|
||||||
|
case "Upload the model to Hugging Face":
|
||||||
|
# TODO
|
||||||
|
pass
|
||||||
|
case "Chat with the model":
|
||||||
|
print()
|
||||||
|
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
{"role": "system", "content": settings.system_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
message = questionary.text("User:", qmark=">").unsafe_ask()
|
||||||
|
if not message:
|
||||||
|
break
|
||||||
|
chat.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
print("[bold]Assistant:[/] ", end="")
|
||||||
|
response = model.stream_chat_response(chat)
|
||||||
|
chat.append({"role": "assistant", "content": response})
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
# Ctrl+C/Ctrl+D
|
||||||
|
break
|
||||||
|
case "Nothing (Quit)":
|
||||||
|
break
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from transformers import (
|
|||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
BatchEncoding,
|
BatchEncoding,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
|
TextStreamer,
|
||||||
)
|
)
|
||||||
from transformers.generation.utils import GenerateOutput
|
from transformers.generation.utils import GenerateOutput
|
||||||
|
|
||||||
@@ -251,3 +252,31 @@ class Model:
|
|||||||
logprobs.append(self.get_logprobs(batch))
|
logprobs.append(self.get_logprobs(batch))
|
||||||
|
|
||||||
return torch.cat(logprobs, dim=0)
|
return torch.cat(logprobs, dim=0)
|
||||||
|
|
||||||
|
def stream_chat_response(
    self,
    chat: list[dict[str, str]],
    max_new_tokens: int = 4096,
) -> str:
    """Generate the next assistant turn for *chat*, streaming it as it is produced.

    The conversation is rendered to a prompt string with the tokenizer's chat
    template, tokenized, and passed to ``self.model.generate``. A
    ``TextStreamer`` is attached so the text is emitted incrementally while
    generation runs (the prompt itself is not echoed).

    Args:
        chat: Conversation so far, as a list of ``{"role": ..., "content": ...}``
            dicts in chronological order.
        max_new_tokens: Upper bound on the number of tokens generated for this
            reply. Passed explicitly so the reply length does not depend on
            the model's generation config, whose fallback limit can be
            shorter than the chat prompt itself.

    Returns:
        The decoded reply text (prompt tokens and special tokens removed),
        suitable for appending to *chat* as the assistant message.
    """
    chat_prompt: str = self.tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )

    # NOTE(review): many chat templates already emit the BOS token, and
    # tokenizing the rendered string again may prepend a second one.
    # Confirm whether add_special_tokens=False is wanted here; left unchanged
    # to stay consistent with how the rest of the class tokenizes prompts.
    inputs = self.tokenizer(
        chat_prompt,
        return_tensors="pt",
        # Some tokenizers return token_type_ids, which generate() rejects as
        # an unused model kwarg for models that do not accept them.
        return_token_type_ids=False,
    ).to(self.model.device)

    streamer = TextStreamer(
        self.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    outputs = self.model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = inputs["input_ids"].shape[1]
    return self.tokenizer.decode(
        outputs[0, prompt_length:],
        skip_special_tokens=True,
    )
|
||||||
|
|||||||
Reference in New Issue
Block a user