Add save and upload functionality
This commit is contained in:
@@ -10,6 +10,7 @@ requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"accelerate>=1.10.0",
|
||||
"datasets>=4.0.0",
|
||||
"huggingface-hub>=0.34.4",
|
||||
"optuna>=4.5.0",
|
||||
"pydantic-settings>=2.10.1",
|
||||
"questionary>=2.1.1",
|
||||
|
||||
+72
-26
@@ -4,7 +4,9 @@
|
||||
import sys
|
||||
import time
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
|
||||
import huggingface_hub
|
||||
import optuna
|
||||
import questionary
|
||||
import torch
|
||||
@@ -258,33 +260,77 @@ def main():
|
||||
],
|
||||
).ask()
|
||||
|
||||
match action:
|
||||
case "Save the model to a local folder":
|
||||
# TODO
|
||||
pass
|
||||
case "Upload the model to Hugging Face":
|
||||
# TODO
|
||||
pass
|
||||
case "Chat with the model":
|
||||
print()
|
||||
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
|
||||
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||
# another action can be tried, instead of the program crashing and losing
|
||||
# the optimized model.
|
||||
try:
|
||||
match action:
|
||||
case "Save the model to a local folder":
|
||||
save_directory = questionary.path("Path to the folder:").ask()
|
||||
if not save_directory:
|
||||
continue
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": settings.system_prompt},
|
||||
]
|
||||
print("Saving model...")
|
||||
model.model.save_pretrained(save_directory)
|
||||
model.tokenizer.save_pretrained(save_directory)
|
||||
print(f"Model saved to [bold]{save_directory}[/].")
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = questionary.text("User:", qmark=">").unsafe_ask()
|
||||
if not message:
|
||||
case "Upload the model to Hugging Face":
|
||||
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||
# and since this program will often be run on rented or shared GPU servers,
|
||||
# it's better to not persist credentials.
|
||||
token = huggingface_hub.get_token()
|
||||
if not token:
|
||||
token = questionary.password("Hugging Face access token:").ask()
|
||||
if not token:
|
||||
continue
|
||||
|
||||
user = huggingface_hub.whoami(token)
|
||||
print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]")
|
||||
|
||||
repo_id = questionary.text(
|
||||
"Name of repository:",
|
||||
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
||||
).ask()
|
||||
|
||||
visibility = questionary.select(
|
||||
"Should the repository be public or private?",
|
||||
choices=[
|
||||
"Public",
|
||||
"Private",
|
||||
],
|
||||
).ask()
|
||||
private = visibility == "Private"
|
||||
|
||||
print("Uploading model...")
|
||||
model.model.push_to_hub(repo_id, private=private, token=token)
|
||||
model.tokenizer.push_to_hub(repo_id, private=private, token=token)
|
||||
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||
|
||||
case "Chat with the model":
|
||||
print()
|
||||
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
|
||||
|
||||
chat = [
|
||||
{"role": "system", "content": settings.system_prompt},
|
||||
]
|
||||
|
||||
while True:
|
||||
try:
|
||||
message = questionary.text("User:", qmark=">").unsafe_ask()
|
||||
if not message:
|
||||
break
|
||||
chat.append({"role": "user", "content": message})
|
||||
|
||||
print("[bold]Assistant:[/] ", end="")
|
||||
response = model.stream_chat_response(chat)
|
||||
chat.append({"role": "assistant", "content": response})
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl+C/Ctrl+D
|
||||
break
|
||||
chat.append({"role": "user", "content": message})
|
||||
|
||||
print("[bold]Assistant:[/] ", end="")
|
||||
response = model.stream_chat_response(chat)
|
||||
chat.append({"role": "assistant", "content": response})
|
||||
except (KeyboardInterrupt, EOFError):
|
||||
# Ctrl+C/Ctrl+D
|
||||
break
|
||||
case "Nothing (Quit)":
|
||||
break
|
||||
case "Nothing (Quit)":
|
||||
break
|
||||
|
||||
except Exception as error:
|
||||
print(f"[red]Error: {error}[/]")
|
||||
|
||||
@@ -479,6 +479,7 @@ source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "accelerate" },
|
||||
{ name = "datasets" },
|
||||
{ name = "huggingface-hub" },
|
||||
{ name = "optuna" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "questionary" },
|
||||
@@ -490,6 +491,7 @@ dependencies = [
|
||||
requires-dist = [
|
||||
{ name = "accelerate", specifier = ">=1.10.0" },
|
||||
{ name = "datasets", specifier = ">=4.0.0" },
|
||||
{ name = "huggingface-hub", specifier = ">=0.34.4" },
|
||||
{ name = "optuna", specifier = ">=4.5.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.10.1" },
|
||||
{ name = "questionary", specifier = ">=2.1.1" },
|
||||
|
||||
Reference in New Issue
Block a user