diff --git a/src/heretic/main.py b/src/heretic/main.py index 19bc3da..2f7f2aa 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -222,14 +222,58 @@ def main(): ) print(f" * Score: [bold]{-study.best_value:.4f}[/]") - return - print() - action = questionary.select( - "What do you want to do with the optimized model?", - choices=[ - "Save to a local folder", - "Upload to Hugging Face", - "Nothing (discard the model)", - ], - ).ask() + print("Restoring best model...") + print("* Reloading model...") + model.reload_model() + print("* Abliterating...") + model.abliterate( + refusal_directions, + study.best_params["max_weight"], + study.best_params["max_weight_position"], + study.best_params["min_weight"], + study.best_params["min_weight_distance"], + ) + + while True: + print() + action = questionary.select( + "What do you want to do with the optimized model?", + choices=[ + "Save the model to a local folder", + "Upload the model to Hugging Face", + "Chat with the model", + "Nothing (Quit)", + ], + ).ask() + + match action: + case "Save the model to a local folder": + # TODO + pass + case "Upload the model to Hugging Face": + # TODO + pass + case "Chat with the model": + print() + print("[cyan]Press Ctrl+C at any time to return to the menu.[/]") + + chat = [ + {"role": "system", "content": settings.system_prompt}, + ] + + while True: + try: + message = questionary.text("User:", qmark=">").unsafe_ask() + if not message: + break + chat.append({"role": "user", "content": message}) + + print("[bold]Assistant:[/] ", end="") + response = model.stream_chat_response(chat) + chat.append({"role": "assistant", "content": response}) + except (KeyboardInterrupt, EOFError): + # Ctrl+C/Ctrl+D + break + case "Nothing (Quit)": + break diff --git a/src/heretic/model.py b/src/heretic/model.py index 04cf60f..fa8ae59 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -12,6 +12,7 @@ from transformers import ( AutoTokenizer, BatchEncoding, PreTrainedTokenizerBase, + 
def stream_chat_response(
    self,
    chat: list[dict[str, str]],
    *,
    max_new_tokens: int = 4096,
) -> str:
    """Generate the assistant's next turn for *chat*, streaming it as it is produced.

    The conversation is rendered with the tokenizer's chat template (including
    the generation prompt), tokenized, and passed to ``self.model.generate``
    together with a ``TextStreamer`` so the reply is emitted incrementally
    while it is being generated.

    Args:
        chat: Conversation so far, as a list of ``{"role": ..., "content": ...}``
            messages in the format expected by ``apply_chat_template``.
        max_new_tokens: Upper bound on the number of tokens generated for the
            reply. Without an explicit budget, ``generate`` falls back to the
            generation config, whose library default caps the *total* length
            (prompt included) at a very small value.

    Returns:
        The decoded reply (prompt tokens stripped, special tokens skipped).
    """
    chat_prompt: str = self.tokenizer.apply_chat_template(
        chat,
        add_generation_prompt=True,
        tokenize=False,
    )

    inputs = self.tokenizer(
        chat_prompt,
        return_tensors="pt",
        # The rendered template already contains every special token it
        # needs; letting the tokenizer add its own would duplicate them
        # (for example a second BOS token at the start of the prompt).
        add_special_tokens=False,
        # NOTE(review): some tokenizers emit token_type_ids, which would be
        # forwarded to generate() via **inputs below; they are presumably
        # not wanted for causal generation, so they are not requested.
        return_token_type_ids=False,
    ).to(self.model.device)

    streamer = TextStreamer(
        self.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    outputs = self.model.generate(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
    )

    # generate() returns prompt + continuation; keep only the continuation
    # of the single sequence in the batch.
    prompt_length = inputs["input_ids"].shape[1]
    return self.tokenizer.decode(
        outputs[0, prompt_length:],
        skip_special_tokens=True,
    )