notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck

commit 633cba8331683a34516c164ce96f0dcae8825ef3
parent 366827d335ddc95fb8d0e05dc24adcc3fce8c2d7
Author: William Casarin <jb55@jb55.com>
Date:   Sat, 29 Mar 2025 10:10:14 -0700

dave: introduce model config

so you can switch between openai and ollama models

Signed-off-by: William Casarin <jb55@jb55.com>

Diffstat:
M crates/notedeck_dave/src/lib.rs | 54 ++++++++++++++++++++++++++++++++++++++++++++++--------
1 file changed, 46 insertions(+), 8 deletions(-)

diff --git a/crates/notedeck_dave/src/lib.rs b/crates/notedeck_dave/src/lib.rs
@@ -304,6 +304,46 @@ pub struct Dave {
     tools: Arc<HashMap<String, Tool>>,
     client: async_openai::Client<OpenAIConfig>,
     incoming_tokens: Option<Receiver<DaveResponse>>,
+    model_config: ModelConfig,
+}
+
+pub struct ModelConfig {
+    endpoint: Option<String>,
+    model: String,
+    api_key: Option<String>,
+}
+
+impl Default for ModelConfig {
+    fn default() -> Self {
+        ModelConfig {
+            endpoint: None,
+            model: "gpt-4o".to_string(),
+            api_key: std::env::var("OPENAI_API_KEY").ok(),
+        }
+    }
+}
+
+impl ModelConfig {
+    pub fn ollama() -> Self {
+        ModelConfig {
+            endpoint: std::env::var("OLLAMA_HOST").ok(),
+            model: "hhao/qwen2.5-coder-tools:latest".to_string(),
+            api_key: None,
+        }
+    }
+
+    pub fn to_api(&self) -> OpenAIConfig {
+        let mut cfg = OpenAIConfig::new();
+        if let Some(endpoint) = &self.endpoint {
+            cfg = cfg.with_api_base(endpoint.to_owned());
+        }
+
+        if let Some(api_key) = &self.api_key {
+            cfg = cfg.with_api_key(api_key.to_owned());
+        }
+
+        cfg
+    }
 }
 
 impl Dave {
@@ -312,12 +352,9 @@ impl Dave {
     }
 
     pub fn new(render_state: Option<&RenderState>) -> Self {
-        let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
-        if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
-            config = config.with_api_key(api_key);
-        }
-
-        let client = Client::with_config(config);
+        //let mut config = OpenAIConfig::new(); //.with_api_base("http://ollama.jb55.com/v1");
+        let model_config = ModelConfig::default();
+        let client = Client::with_config(model_config.to_api());
 
         let input = "".to_string();
         let pubkey = "32e1827635450ebb3c5a7d12c1f8e7b2b514439ac10a67eef3d9fd9c5c68e245".to_string();
@@ -336,6 +373,7 @@ impl Dave {
             incoming_tokens: None,
             tools: Arc::new(tools),
             input,
+            model_config,
             chat: vec![system_prompt],
         }
     }
@@ -513,6 +551,7 @@ impl Dave {
         let ctx = ctx.clone();
         let client = self.client.clone();
         let tools = self.tools.clone();
+        let model_name = self.model_config.model.clone();
 
         let (tx, rx) = mpsc::channel();
         self.incoming_tokens = Some(rx);
@@ -521,8 +560,7 @@ impl Dave {
             let mut token_stream = match client
                 .chat()
                 .create_stream(CreateChatCompletionRequest {
-                    model: "gpt-4o".to_string(),
-                    //model: "gpt-4o".to_string(),
+                    model: model_name,
                     stream: Some(true),
                     messages,
                     tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()),
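
As a usage sketch (assuming the ModelConfig type added above and the async_openai crate Dave already depends on), the snippet below shows how a caller could pick between the OpenAI and Ollama presets when building the chat client. Dave::new() in this commit hardcodes ModelConfig::default(); the DAVE_BACKEND variable and the helper names here are hypothetical, not part of the commit.

use async_openai::{config::OpenAIConfig, Client};

// Hypothetical helper: choose a preset from an environment variable.
fn select_model_config() -> ModelConfig {
    match std::env::var("DAVE_BACKEND").as_deref() {
        // Ollama preset: endpoint taken from OLLAMA_HOST, no API key required.
        Ok("ollama") => ModelConfig::ollama(),
        // Default preset: OpenAI gpt-4o, key taken from OPENAI_API_KEY.
        _ => ModelConfig::default(),
    }
}

// Hypothetical helper: to_api() maps the ModelConfig onto async_openai's
// OpenAIConfig, so the same client construction works for either backend.
fn make_client(cfg: &ModelConfig) -> Client<OpenAIConfig> {
    Client::with_config(cfg.to_api())
}

Because the model name is stored in ModelConfig rather than hardcoded, the streaming request later clones self.model_config.model into the CreateChatCompletionRequest, which is what lets one code path serve both gpt-4o and the Ollama-hosted model.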