notedeck

One damus client to rule them all
git clone git://jb55.com/notedeck
Log | Files | Refs | README | LICENSE

commit d6c065694a66b6ae59a10b67fdbc183a8d0f5fc1
parent 2a9c5c7848b27545b42dc3e9289f20913dc1dae4
Author: William Casarin <jb55@jb55.com>
Date:   Mon, 14 Apr 2025 12:37:25 -0700

dave: organize

move more things into their own modules

Signed-off-by: William Casarin <jb55@jb55.com>

Diffstat:
Mcrates/notedeck_dave/src/lib.rs | 528++++---------------------------------------------------------------------------
Acrates/notedeck_dave/src/messages.rs | 64++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
Acrates/notedeck_dave/src/query.rs | 0
Acrates/notedeck_dave/src/tools.rs | 484+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 571 insertions(+), 505 deletions(-)

diff --git a/crates/notedeck_dave/src/lib.rs b/crates/notedeck_dave/src/lib.rs @@ -1,330 +1,35 @@ use async_openai::{ config::OpenAIConfig, - types::{ - ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage, - ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage, - ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent, - ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent, - ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent, - ChatCompletionTool, ChatCompletionToolType, CreateChatCompletionRequest, FunctionCall, - FunctionObject, - }, + types::{ChatCompletionRequestMessage, CreateChatCompletionRequest}, Client, }; -use chrono::{DateTime, Duration, Local}; +use chrono::{Duration, Local}; use egui_wgpu::RenderState; use futures::StreamExt; -use nostrdb::{Ndb, NoteKey, Transaction}; +use nostrdb::Transaction; use notedeck::AppContext; use notedeck_ui::icons::search_icon; -use serde::{Deserialize, Serialize}; -use serde_json::{json, Value}; use std::collections::HashMap; use std::sync::mpsc::{self, Receiver}; use std::sync::Arc; pub use avatar::DaveAvatar; pub use config::ModelConfig; +pub use messages::{DaveResponse, Message}; pub use quaternion::Quaternion; +pub use tools::{ + PartialToolCall, QueryCall, QueryContext, QueryResponse, Tool, ToolCall, ToolCalls, + ToolResponse, ToolResponses, +}; pub use vec3::Vec3; mod avatar; mod config; +mod messages; mod quaternion; +mod tools; mod vec3; -#[derive(Debug, Clone)] -pub enum Message { - User(String), - Assistant(String), - System(String), - ToolCalls(Vec<ToolCall>), - ToolResponse(ToolResponse), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct QueryResponse { - context: QueryContext, - notes: Vec<u64>, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ToolResponses { - Query(QueryResponse), -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolResponse { - id: String, - typ: ToolResponses, -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ToolCall { - id: String, - typ: ToolCalls, -} - -#[derive(Default, Debug, Clone)] -pub struct PartialToolCall { - id: Option<String>, - name: Option<String>, - arguments: Option<String>, -} - -#[derive(Debug, Clone)] -pub struct UnknownToolCall { - id: String, - name: String, - arguments: String, -} - -impl UnknownToolCall { - pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> { - let Some(tool) = tools.get(&self.name) else { - return Err(ToolCallError::NotFound(self.name.to_owned())); - }; - - let parsed_args = (tool.parse_call)(&self.arguments)?; - Ok(ToolCall { - id: self.id.clone(), - typ: parsed_args, - }) - } -} - -impl PartialToolCall { - pub fn complete(&self) -> Option<UnknownToolCall> { - Some(UnknownToolCall { - id: self.id.clone()?, - name: self.name.clone()?, - arguments: self.arguments.clone()?, - }) - } -} - -impl ToolCall { - pub fn to_api(&self) -> ChatCompletionMessageToolCall { - ChatCompletionMessageToolCall { - id: self.id.clone(), - r#type: ChatCompletionToolType::Function, - function: self.typ.to_api(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub enum ToolCalls { - Query(QueryCall), -} - -impl ToolCalls { - pub fn to_api(&self) -> FunctionCall { - FunctionCall { - name: self.name().to_owned(), - arguments: self.arguments(), - } - } - - fn name(&self) -> &'static str { - match self { - Self::Query(_) => "search", - } - } - - fn arguments(&self) -> String { - match self { - Self::Query(search) => serde_json::to_string(search).unwrap(), - } - } -} - -pub enum DaveResponse { - ToolCalls(Vec<ToolCall>), - Token(String), -} - -impl Message { - pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage { - match self { - Message::User(msg) => { - ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { - name: None, - content: ChatCompletionRequestUserMessageContent::Text(msg.clone()), - }) - } - - Message::Assistant(msg) => { - ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { - content: Some(ChatCompletionRequestAssistantMessageContent::Text( - msg.clone(), - )), - ..Default::default() - }) - } - - Message::System(msg) => { - ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { - content: ChatCompletionRequestSystemMessageContent::Text(msg.clone()), - ..Default::default() - }) - } - - Message::ToolCalls(calls) => { - ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { - tool_calls: Some(calls.iter().map(|c| c.to_api()).collect()), - ..Default::default() - }) - } - - Message::ToolResponse(resp) => { - let tool_response = format_tool_response_for_ai(txn, ndb, &resp.typ); - - ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage { - tool_call_id: resp.id.clone(), - content: ChatCompletionRequestToolMessageContent::Text(tool_response), - }) - } - } - } -} - -#[derive(Debug, Serialize)] -struct SimpleNote { - pubkey: String, - name: String, - content: String, - created_at: String, - note_kind: String, // todo: add replying to -} - -fn note_kind_desc(kind: u64) -> String { - match kind { - 1 => "microblog".to_string(), - 0 => "profile".to_string(), - _ => kind.to_string(), - } -} - -/// Take the result of a tool response and present it to the ai so that -/// it can interepret it and take further action -fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { - match resp { - ToolResponses::Query(search_r) => { - let simple_notes: Vec<SimpleNote> = search_r - .notes - .iter() - .filter_map(|nkey| { - let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else { - return None; - }; - - let name = ndb - .get_profile_by_pubkey(txn, note.pubkey()) - .ok() - .and_then(|p| p.record().profile()) - .and_then(|p| p.name().or_else(|| p.display_name())) - .unwrap_or("Anonymous") - .to_string(); - - let content = note.content().to_owned(); - let pubkey = hex::encode(note.pubkey()); - let note_kind = note_kind_desc(note.kind() as u64); - - let created_at = { - let datetime = - DateTime::from_timestamp(note.created_at() as i64, 0).unwrap(); - datetime.format("%Y-%m-%d %H:%M:%S").to_string() - }; - - Some(SimpleNote { - pubkey, - name, - content, - created_at, - note_kind, - }) - }) - .collect(); - - serde_json::to_string(&json!({"search_results": simple_notes})).unwrap() - } - } -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -#[serde(rename_all = "lowercase")] -pub enum QueryContext { - Home, - Profile, - Any, -} - -#[derive(Debug, Deserialize, Serialize, Clone)] -pub struct QueryCall { - context: Option<QueryContext>, - limit: Option<u64>, - since: Option<u64>, - kind: Option<u64>, - until: Option<u64>, - author: Option<String>, - search: Option<String>, -} - -impl QueryCall { - pub fn to_filter(&self) -> nostrdb::Filter { - let mut filter = nostrdb::Filter::new() - .limit(self.limit()) - .kinds([self.kind.unwrap_or(1)]); - - if let Some(search) = &self.search { - filter = filter.search(search); - } - - if let Some(until) = self.until { - filter = filter.until(until); - } - - if let Some(since) = self.since { - filter = filter.since(since); - } - - filter.build() - } - - fn limit(&self) -> u64 { - self.limit.unwrap_or(10) - } - - fn context(&self) -> QueryContext { - self.context.clone().unwrap_or(QueryContext::Any) - } - - pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse { - let notes = { - if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) { - results.into_iter().map(|r| r.note_key.as_u64()).collect() - } else { - vec![] - } - }; - QueryResponse { - context: self.context.clone().unwrap_or(QueryContext::Any), - notes, - } - } - - pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { - match serde_json::from_str::<QueryCall>(args) { - Ok(call) => Ok(ToolCalls::Query(call)), - Err(e) => Err(ToolCallError::ArgParseFailure(format!( - "Failed to parse args: '{}', error: {}", - args, e - ))), - } - } -} - pub struct Dave { chat: Vec<Message>, /// A 3d representation of dave. @@ -351,8 +56,8 @@ impl Dave { let pubkey = "32e1827635450ebb3c5a7d12c1f8e7b2b514439ac10a67eef3d9fd9c5c68e245".to_string(); let avatar = render_state.map(DaveAvatar::new); let mut tools: HashMap<String, Tool> = HashMap::new(); - for tool in dave_tools() { - tools.insert(tool.name.to_string(), tool); + for tool in tools::dave_tools() { + tools.insert(tool.name().to_string(), tool); } let now = Local::now(); @@ -414,13 +119,13 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr let txn = Transaction::new(app_ctx.ndb).unwrap(); for call in &toolcalls { // execute toolcall - match &call.typ { + match call.calls() { ToolCalls::Query(search_call) => { let resp = search_call.execute(&txn, app_ctx.ndb); - self.chat.push(Message::ToolResponse(ToolResponse { - id: call.id.clone(), - typ: ToolResponses::Query(resp), - })) + self.chat.push(Message::ToolResponse(ToolResponse::new( + call.id().to_owned(), + ToolResponses::Query(resp), + ))) } } } @@ -499,7 +204,7 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr }; //TODO: fix this to support any query - if let Some(search) = &query_call.search { + if let Some(search) = query_call.search() { ui.label(format!("Querying {context}for '{search}'")); } else { ui.label(format!("Querying {:?}", &query_call)); @@ -509,7 +214,7 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr fn tool_call_ui(toolcalls: &[ToolCall], ui: &mut egui::Ui) { ui.vertical(|ui| { for call in toolcalls { - match &call.typ { + match call.calls() { ToolCalls::Query(search_call) => { ui.horizontal(|ui| { egui::Frame::new() @@ -581,7 +286,7 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr model: model_name, stream: Some(true), messages, - tools: Some(dave_tools().iter().map(|t| t.to_api()).collect()), + tools: Some(tools::dave_tools().iter().map(|t| t.to_api()).collect()), user: Some(pubkey), ..Default::default() }) @@ -615,19 +320,19 @@ You are an AI agent for the nostr protocol called Dave, created by Damus. nostr let entry = all_tool_calls.entry(tool.index).or_default(); if let Some(id) = &tool.id { - entry.id.get_or_insert(id.to_string()); + entry.id().get_or_insert(id); } if let Some(name) = tool.function.as_ref().and_then(|f| f.name.as_ref()) { - entry.name.get_or_insert(name.to_string()); + entry.name().get_or_insert(name); } if let Some(argchunk) = tool.function.as_ref().and_then(|f| f.arguments.as_ref()) { entry - .arguments + .arguments_mut() .get_or_insert_with(String::new) .push_str(argchunk); } @@ -687,190 +392,3 @@ impl notedeck::App for Dave { self.render(ctx, ui); } } - -#[derive(Debug, Clone)] -enum ArgType { - String, - Number, - Enum(Vec<&'static str>), -} - -impl ArgType { - pub fn type_string(&self) -> &'static str { - match self { - Self::String => "string", - Self::Number => "number", - Self::Enum(_) => "string", - } - } -} - -#[derive(Debug, Clone)] -struct ToolArg { - typ: ArgType, - name: &'static str, - required: bool, - description: &'static str, - default: Option<Value>, -} - -#[derive(Debug, Clone)] -pub struct Tool { - parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>, - name: &'static str, - description: &'static str, - arguments: Vec<ToolArg>, -} - -impl Tool { - pub fn to_function_object(&self) -> FunctionObject { - let required_args = self - .arguments - .iter() - .filter_map(|arg| { - if arg.required { - Some(Value::String(arg.name.to_owned())) - } else { - None - } - }) - .collect(); - - let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new(); - parameters.insert("type".to_string(), Value::String("object".to_string())); - parameters.insert("required".to_string(), Value::Array(required_args)); - parameters.insert("additionalProperties".to_string(), Value::Bool(false)); - - let mut properties: serde_json::Map<String, Value> = serde_json::Map::new(); - - for arg in &self.arguments { - let mut props: serde_json::Map<String, Value> = serde_json::Map::new(); - props.insert( - "type".to_string(), - Value::String(arg.typ.type_string().to_string()), - ); - - let description = if let Some(default) = &arg.default { - format!("{} (Default: {default}))", arg.description) - } else { - arg.description.to_owned() - }; - - props.insert("description".to_string(), Value::String(description)); - if let ArgType::Enum(enums) = &arg.typ { - props.insert( - "enum".to_string(), - Value::Array( - enums - .iter() - .map(|s| Value::String((*s).to_owned())) - .collect(), - ), - ); - } - - properties.insert(arg.name.to_owned(), Value::Object(props)); - } - - parameters.insert("properties".to_string(), Value::Object(properties)); - - FunctionObject { - name: self.name.to_owned(), - description: Some(self.description.to_owned()), - strict: Some(false), - parameters: Some(Value::Object(parameters)), - } - } - - pub fn to_api(&self) -> ChatCompletionTool { - ChatCompletionTool { - r#type: ChatCompletionToolType::Function, - function: self.to_function_object(), - } - } -} - -fn query_tool() -> Tool { - Tool { - name: "query", - parse_call: QueryCall::parse, - description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.", - arguments: vec![ - ToolArg { - name: "search", - typ: ArgType::String, - required: false, - default: None, - description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc", - }, - - ToolArg { - name: "limit", - typ: ArgType::Number, - required: true, - default: Some(Value::Number(serde_json::Number::from_i128(10).unwrap())), - description: "The number of results to return.", - }, - - ToolArg { - name: "since", - typ: ArgType::Number, - required: false, - default: None, - description: "Only pull notes after this unix timestamp", - }, - - ToolArg { - name: "until", - typ: ArgType::Number, - required: false, - default: None, - description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).", - }, - - ToolArg { - name: "kind", - typ: ArgType::Number, - required: false, - default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())), - description: r#"The kind of note. Kind list: - - 0: profiles - - 1: microblogs/\"tweets\"/posts - - 6: reposts of kind 1 notes - - 7: emoji reactions/likes - - 9735: zaps (bitcoin micropayment receipts) - - 30023: longform articles, blog posts, etc - - "#, - }, - - ToolArg { - name: "author", - typ: ArgType::String, - required: false, - default: None, - description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to use, you can query for kind 0 profiles with the search argument.", - }, - - ToolArg { - name: "context", - typ: ArgType::Enum(vec!["home", "profile", "any"]), - required: false, - default: Some(Value::String("any".to_string())), - description: "The context in which the search is occuring. valid options are 'home', 'profile', 'any'", - } - ] - } -} - -#[derive(Debug)] -pub enum ToolCallError { - EmptyName, - EmptyArgs, - NotFound(String), - ArgParseFailure(String), -} - -fn dave_tools() -> Vec<Tool> { - vec![query_tool()] -} diff --git a/crates/notedeck_dave/src/messages.rs b/crates/notedeck_dave/src/messages.rs @@ -0,0 +1,64 @@ +use crate::tools::{ToolCall, ToolResponse}; +use async_openai::types::*; +use nostrdb::{Ndb, Transaction}; + +#[derive(Debug, Clone)] +pub enum Message { + User(String), + Assistant(String), + System(String), + ToolCalls(Vec<ToolCall>), + ToolResponse(ToolResponse), +} + +/// The ai backends response. Since we are using streaming APIs these are +/// represented as individual tokens or tool calls +pub enum DaveResponse { + ToolCalls(Vec<ToolCall>), + Token(String), +} + +impl Message { + pub fn to_api_msg(&self, txn: &Transaction, ndb: &Ndb) -> ChatCompletionRequestMessage { + match self { + Message::User(msg) => { + ChatCompletionRequestMessage::User(ChatCompletionRequestUserMessage { + name: None, + content: ChatCompletionRequestUserMessageContent::Text(msg.clone()), + }) + } + + Message::Assistant(msg) => { + ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { + content: Some(ChatCompletionRequestAssistantMessageContent::Text( + msg.clone(), + )), + ..Default::default() + }) + } + + Message::System(msg) => { + ChatCompletionRequestMessage::System(ChatCompletionRequestSystemMessage { + content: ChatCompletionRequestSystemMessageContent::Text(msg.clone()), + ..Default::default() + }) + } + + Message::ToolCalls(calls) => { + ChatCompletionRequestMessage::Assistant(ChatCompletionRequestAssistantMessage { + tool_calls: Some(calls.iter().map(|c| c.to_api()).collect()), + ..Default::default() + }) + } + + Message::ToolResponse(resp) => { + let tool_response = resp.responses().format_for_dave(txn, ndb); + + ChatCompletionRequestMessage::Tool(ChatCompletionRequestToolMessage { + tool_call_id: resp.id().to_owned(), + content: ChatCompletionRequestToolMessageContent::Text(tool_response), + }) + } + } + } +} diff --git a/crates/notedeck_dave/src/query.rs b/crates/notedeck_dave/src/query.rs diff --git a/crates/notedeck_dave/src/tools.rs b/crates/notedeck_dave/src/tools.rs @@ -0,0 +1,484 @@ +use async_openai::types::*; +use chrono::DateTime; +use nostrdb::{Ndb, NoteKey, Transaction}; +use serde::{Deserialize, Serialize}; +use serde_json::{json, Value}; +use std::collections::HashMap; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + id: String, + typ: ToolCalls, +} + +impl ToolCall { + pub fn id(&self) -> &str { + &self.id + } + + pub fn calls(&self) -> &ToolCalls { + &self.typ + } + + pub fn to_api(&self) -> ChatCompletionMessageToolCall { + ChatCompletionMessageToolCall { + id: self.id.clone(), + r#type: ChatCompletionToolType::Function, + function: self.typ.to_api(), + } + } +} + +#[derive(Default, Debug, Clone)] +pub struct PartialToolCall { + id: Option<String>, + name: Option<String>, + arguments: Option<String>, +} + +impl PartialToolCall { + pub fn id(&self) -> Option<&str> { + self.id.as_deref() + } + + pub fn name(&self) -> Option<&str> { + self.name.as_deref() + } + + pub fn arguments(&self) -> Option<&str> { + self.arguments.as_deref() + } + + pub fn arguments_mut(&mut self) -> &mut Option<String> { + &mut self.arguments + } +} + +/// The query response from nostrdb for a given context +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct QueryResponse { + context: QueryContext, + notes: Vec<u64>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolResponses { + Query(QueryResponse), +} + +#[derive(Debug, Clone)] +pub struct UnknownToolCall { + id: String, + name: String, + arguments: String, +} + +impl UnknownToolCall { + pub fn parse(&self, tools: &HashMap<String, Tool>) -> Result<ToolCall, ToolCallError> { + let Some(tool) = tools.get(&self.name) else { + return Err(ToolCallError::NotFound(self.name.to_owned())); + }; + + let parsed_args = (tool.parse_call)(&self.arguments)?; + Ok(ToolCall { + id: self.id.clone(), + typ: parsed_args, + }) + } +} + +impl PartialToolCall { + pub fn complete(&self) -> Option<UnknownToolCall> { + Some(UnknownToolCall { + id: self.id.clone()?, + name: self.name.clone()?, + arguments: self.arguments.clone()?, + }) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum ToolCalls { + Query(QueryCall), +} + +impl ToolCalls { + pub fn to_api(&self) -> FunctionCall { + FunctionCall { + name: self.name().to_owned(), + arguments: self.arguments(), + } + } + + fn name(&self) -> &'static str { + match self { + Self::Query(_) => "search", + } + } + + fn arguments(&self) -> String { + match self { + Self::Query(search) => serde_json::to_string(search).unwrap(), + } + } +} + +#[derive(Debug)] +pub enum ToolCallError { + EmptyName, + EmptyArgs, + NotFound(String), + ArgParseFailure(String), +} + +#[derive(Debug, Clone)] +enum ArgType { + String, + Number, + Enum(Vec<&'static str>), +} + +impl ArgType { + pub fn type_string(&self) -> &'static str { + match self { + Self::String => "string", + Self::Number => "number", + Self::Enum(_) => "string", + } + } +} + +#[derive(Debug, Clone)] +struct ToolArg { + typ: ArgType, + name: &'static str, + required: bool, + description: &'static str, + default: Option<Value>, +} + +#[derive(Debug, Clone)] +pub struct Tool { + parse_call: fn(&str) -> Result<ToolCalls, ToolCallError>, + name: &'static str, + description: &'static str, + arguments: Vec<ToolArg>, +} + +impl Tool { + pub fn name(&self) -> &'static str { + self.name + } + + pub fn to_function_object(&self) -> FunctionObject { + let required_args = self + .arguments + .iter() + .filter_map(|arg| { + if arg.required { + Some(Value::String(arg.name.to_owned())) + } else { + None + } + }) + .collect(); + + let mut parameters: serde_json::Map<String, Value> = serde_json::Map::new(); + parameters.insert("type".to_string(), Value::String("object".to_string())); + parameters.insert("required".to_string(), Value::Array(required_args)); + parameters.insert("additionalProperties".to_string(), Value::Bool(false)); + + let mut properties: serde_json::Map<String, Value> = serde_json::Map::new(); + + for arg in &self.arguments { + let mut props: serde_json::Map<String, Value> = serde_json::Map::new(); + props.insert( + "type".to_string(), + Value::String(arg.typ.type_string().to_string()), + ); + + let description = if let Some(default) = &arg.default { + format!("{} (Default: {default}))", arg.description) + } else { + arg.description.to_owned() + }; + + props.insert("description".to_string(), Value::String(description)); + if let ArgType::Enum(enums) = &arg.typ { + props.insert( + "enum".to_string(), + Value::Array( + enums + .iter() + .map(|s| Value::String((*s).to_owned())) + .collect(), + ), + ); + } + + properties.insert(arg.name.to_owned(), Value::Object(props)); + } + + parameters.insert("properties".to_string(), Value::Object(properties)); + + FunctionObject { + name: self.name.to_owned(), + description: Some(self.description.to_owned()), + strict: Some(false), + parameters: Some(Value::Object(parameters)), + } + } + + pub fn to_api(&self) -> ChatCompletionTool { + ChatCompletionTool { + r#type: ChatCompletionToolType::Function, + function: self.to_function_object(), + } + } +} + +impl ToolResponses { + pub fn format_for_dave(&self, txn: &Transaction, ndb: &Ndb) -> String { + format_tool_response_for_ai(txn, ndb, self) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResponse { + id: String, + typ: ToolResponses, +} + +impl ToolResponse { + pub fn new(id: String, responses: ToolResponses) -> Self { + Self { id, typ: responses } + } + + pub fn responses(&self) -> &ToolResponses { + &self.typ + } + + pub fn id(&self) -> &str { + &self.id + } +} + +#[derive(Debug, Deserialize, Serialize, Clone)] +#[serde(rename_all = "lowercase")] +pub enum QueryContext { + Home, + Profile, + Any, +} + +/// The parsed query that dave wants to use to satisfy a request +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct QueryCall { + context: Option<QueryContext>, + limit: Option<u64>, + since: Option<u64>, + kind: Option<u64>, + until: Option<u64>, + author: Option<String>, + search: Option<String>, +} + +impl QueryCall { + pub fn to_filter(&self) -> nostrdb::Filter { + let mut filter = nostrdb::Filter::new() + .limit(self.limit()) + .kinds([self.kind.unwrap_or(1)]); + + if let Some(search) = &self.search { + filter = filter.search(search); + } + + if let Some(until) = self.until { + filter = filter.until(until); + } + + if let Some(since) = self.since { + filter = filter.since(since); + } + + filter.build() + } + + fn limit(&self) -> u64 { + self.limit.unwrap_or(10) + } + + pub fn search(&self) -> Option<&str> { + self.search.as_deref() + } + + pub fn context(&self) -> QueryContext { + self.context.clone().unwrap_or(QueryContext::Any) + } + + pub fn execute(&self, txn: &Transaction, ndb: &Ndb) -> QueryResponse { + let notes = { + if let Ok(results) = ndb.query(txn, &[self.to_filter()], self.limit() as i32) { + results.into_iter().map(|r| r.note_key.as_u64()).collect() + } else { + vec![] + } + }; + QueryResponse { + context: self.context.clone().unwrap_or(QueryContext::Any), + notes, + } + } + + pub fn parse(args: &str) -> Result<ToolCalls, ToolCallError> { + match serde_json::from_str::<QueryCall>(args) { + Ok(call) => Ok(ToolCalls::Query(call)), + Err(e) => Err(ToolCallError::ArgParseFailure(format!( + "Failed to parse args: '{}', error: {}", + args, e + ))), + } + } +} + +/// A simple note format for use when formatting +/// tool responses +#[derive(Debug, Serialize)] +struct SimpleNote { + pubkey: String, + name: String, + content: String, + created_at: String, + note_kind: String, // todo: add replying to +} + +/// Take the result of a tool response and present it to the ai so that +/// it can interepret it and take further action +fn format_tool_response_for_ai(txn: &Transaction, ndb: &Ndb, resp: &ToolResponses) -> String { + match resp { + ToolResponses::Query(search_r) => { + let simple_notes: Vec<SimpleNote> = search_r + .notes + .iter() + .filter_map(|nkey| { + let Ok(note) = ndb.get_note_by_key(txn, NoteKey::new(*nkey)) else { + return None; + }; + + let name = ndb + .get_profile_by_pubkey(txn, note.pubkey()) + .ok() + .and_then(|p| p.record().profile()) + .and_then(|p| p.name().or_else(|| p.display_name())) + .unwrap_or("Anonymous") + .to_string(); + + let content = note.content().to_owned(); + let pubkey = hex::encode(note.pubkey()); + let note_kind = note_kind_desc(note.kind() as u64); + + let created_at = { + let datetime = + DateTime::from_timestamp(note.created_at() as i64, 0).unwrap(); + datetime.format("%Y-%m-%d %H:%M:%S").to_string() + }; + + Some(SimpleNote { + pubkey, + name, + content, + created_at, + note_kind, + }) + }) + .collect(); + + serde_json::to_string(&json!({"search_results": simple_notes})).unwrap() + } + } +} + +fn note_kind_desc(kind: u64) -> String { + match kind { + 1 => "microblog".to_string(), + 0 => "profile".to_string(), + _ => kind.to_string(), + } +} + +fn query_tool() -> Tool { + Tool { + name: "query", + parse_call: QueryCall::parse, + description: "Note query functionality. Used for finding notes using full-text search terms, scoped by different contexts. You can use a combination of limit, since, and until to pull notes from any time range.", + arguments: vec![ + ToolArg { + name: "search", + typ: ArgType::String, + required: false, + default: None, + description: "A fulltext search query. Queries with multiple words will only return results with notes that have all of those words. Don't include filler words/symbols like 'and', punctuation, etc", + }, + + ToolArg { + name: "limit", + typ: ArgType::Number, + required: true, + default: Some(Value::Number(serde_json::Number::from_i128(10).unwrap())), + description: "The number of results to return.", + }, + + ToolArg { + name: "since", + typ: ArgType::Number, + required: false, + default: None, + description: "Only pull notes after this unix timestamp", + }, + + ToolArg { + name: "until", + typ: ArgType::Number, + required: false, + default: None, + description: "Only pull notes up until this unix timestamp. Always include this when searching notes within some date range (yesterday, last week, etc).", + }, + + ToolArg { + name: "kind", + typ: ArgType::Number, + required: false, + default: Some(Value::Number(serde_json::Number::from_i128(1).unwrap())), + description: r#"The kind of note. Kind list: + - 0: profiles + - 1: microblogs/\"tweets\"/posts + - 6: reposts of kind 1 notes + - 7: emoji reactions/likes + - 9735: zaps (bitcoin micropayment receipts) + - 30023: longform articles, blog posts, etc + + "#, + }, + + ToolArg { + name: "author", + typ: ArgType::String, + required: false, + default: None, + description: "An author *pubkey* to constrain the query on. Can be used to search for notes from individual users. If unsure what pubkey to use, you can query for kind 0 profiles with the search argument.", + }, + + ToolArg { + name: "context", + typ: ArgType::Enum(vec!["home", "profile", "any"]), + required: false, + default: Some(Value::String("any".to_string())), + description: "The context in which the search is occuring. valid options are 'home', 'profile', 'any'", + } + ] + } +} + +pub fn dave_tools() -> Vec<Tool> { + vec![query_tool()] +}