Add save and upload functionality
This commit is contained in:
@@ -10,6 +10,7 @@ requires-python = ">=3.10"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"accelerate>=1.10.0",
|
"accelerate>=1.10.0",
|
||||||
"datasets>=4.0.0",
|
"datasets>=4.0.0",
|
||||||
|
"huggingface-hub>=0.34.4",
|
||||||
"optuna>=4.5.0",
|
"optuna>=4.5.0",
|
||||||
"pydantic-settings>=2.10.1",
|
"pydantic-settings>=2.10.1",
|
||||||
"questionary>=2.1.1",
|
"questionary>=2.1.1",
|
||||||
|
|||||||
+72
-26
@@ -4,7 +4,9 @@
|
|||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
from importlib.metadata import version
|
from importlib.metadata import version
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import huggingface_hub
|
||||||
import optuna
|
import optuna
|
||||||
import questionary
|
import questionary
|
||||||
import torch
|
import torch
|
||||||
@@ -258,33 +260,77 @@ def main():
|
|||||||
],
|
],
|
||||||
).ask()
|
).ask()
|
||||||
|
|
||||||
match action:
|
# All actions are wrapped in a try/except block so that if an error occurs,
|
||||||
case "Save the model to a local folder":
|
# another action can be tried, instead of the program crashing and losing
|
||||||
# TODO
|
# the optimized model.
|
||||||
pass
|
try:
|
||||||
case "Upload the model to Hugging Face":
|
match action:
|
||||||
# TODO
|
case "Save the model to a local folder":
|
||||||
pass
|
save_directory = questionary.path("Path to the folder:").ask()
|
||||||
case "Chat with the model":
|
if not save_directory:
|
||||||
print()
|
continue
|
||||||
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
|
|
||||||
|
|
||||||
chat = [
|
print("Saving model...")
|
||||||
{"role": "system", "content": settings.system_prompt},
|
model.model.save_pretrained(save_directory)
|
||||||
]
|
model.tokenizer.save_pretrained(save_directory)
|
||||||
|
print(f"Model saved to [bold]{save_directory}[/].")
|
||||||
|
|
||||||
while True:
|
case "Upload the model to Hugging Face":
|
||||||
try:
|
# We don't use huggingface_hub.login() because that stores the token on disk,
|
||||||
message = questionary.text("User:", qmark=">").unsafe_ask()
|
# and since this program will often be run on rented or shared GPU servers,
|
||||||
if not message:
|
# it's better to not persist credentials.
|
||||||
|
token = huggingface_hub.get_token()
|
||||||
|
if not token:
|
||||||
|
token = questionary.password("Hugging Face access token:").ask()
|
||||||
|
if not token:
|
||||||
|
continue
|
||||||
|
|
||||||
|
user = huggingface_hub.whoami(token)
|
||||||
|
print(f"Logged in as [bold]{user['fullname']} ({user['email']})[/]")
|
||||||
|
|
||||||
|
repo_id = questionary.text(
|
||||||
|
"Name of repository:",
|
||||||
|
default=f"{user['name']}/{Path(settings.model).name}-heretic",
|
||||||
|
).ask()
|
||||||
|
|
||||||
|
visibility = questionary.select(
|
||||||
|
"Should the repository be public or private?",
|
||||||
|
choices=[
|
||||||
|
"Public",
|
||||||
|
"Private",
|
||||||
|
],
|
||||||
|
).ask()
|
||||||
|
private = visibility == "Private"
|
||||||
|
|
||||||
|
print("Uploading model...")
|
||||||
|
model.model.push_to_hub(repo_id, private=private, token=token)
|
||||||
|
model.tokenizer.push_to_hub(repo_id, private=private, token=token)
|
||||||
|
print(f"Model uploaded to [bold]{repo_id}[/].")
|
||||||
|
|
||||||
|
case "Chat with the model":
|
||||||
|
print()
|
||||||
|
print("[cyan]Press Ctrl+C at any time to return to the menu.[/]")
|
||||||
|
|
||||||
|
chat = [
|
||||||
|
{"role": "system", "content": settings.system_prompt},
|
||||||
|
]
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
message = questionary.text("User:", qmark=">").unsafe_ask()
|
||||||
|
if not message:
|
||||||
|
break
|
||||||
|
chat.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
print("[bold]Assistant:[/] ", end="")
|
||||||
|
response = model.stream_chat_response(chat)
|
||||||
|
chat.append({"role": "assistant", "content": response})
|
||||||
|
except (KeyboardInterrupt, EOFError):
|
||||||
|
# Ctrl+C/Ctrl+D
|
||||||
break
|
break
|
||||||
chat.append({"role": "user", "content": message})
|
|
||||||
|
|
||||||
print("[bold]Assistant:[/] ", end="")
|
case "Nothing (Quit)":
|
||||||
response = model.stream_chat_response(chat)
|
break
|
||||||
chat.append({"role": "assistant", "content": response})
|
|
||||||
except (KeyboardInterrupt, EOFError):
|
except Exception as error:
|
||||||
# Ctrl+C/Ctrl+D
|
print(f"[red]Error: {error}[/]")
|
||||||
break
|
|
||||||
case "Nothing (Quit)":
|
|
||||||
break
|
|
||||||
|
|||||||
@@ -479,6 +479,7 @@ source = { editable = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "accelerate" },
|
{ name = "accelerate" },
|
||||||
{ name = "datasets" },
|
{ name = "datasets" },
|
||||||
|
{ name = "huggingface-hub" },
|
||||||
{ name = "optuna" },
|
{ name = "optuna" },
|
||||||
{ name = "pydantic-settings" },
|
{ name = "pydantic-settings" },
|
||||||
{ name = "questionary" },
|
{ name = "questionary" },
|
||||||
@@ -490,6 +491,7 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "accelerate", specifier = ">=1.10.0" },
|
{ name = "accelerate", specifier = ">=1.10.0" },
|
||||||
{ name = "datasets", specifier = ">=4.0.0" },
|
{ name = "datasets", specifier = ">=4.0.0" },
|
||||||
|
{ name = "huggingface-hub", specifier = ">=0.34.4" },
|
||||||
{ name = "optuna", specifier = ">=4.5.0" },
|
{ name = "optuna", specifier = ">=4.5.0" },
|
||||||
{ name = "pydantic-settings", specifier = ">=2.10.1" },
|
{ name = "pydantic-settings", specifier = ">=2.10.1" },
|
||||||
{ name = "questionary", specifier = ">=2.1.1" },
|
{ name = "questionary", specifier = ">=2.1.1" },
|
||||||
|
|||||||
Reference in New Issue
Block a user