damus

nostr ios client
git clone git://jb55.com/damus
Log | Files | Refs | README | LICENSE

RelayPool.swift (15428B)


      1 //
      2 //  RelayPool.swift
      3 //  damus
      4 //
      5 //  Created by William Casarin on 2022-04-11.
      6 //
      7 
      8 import Foundation
      9 import Network
     10 
     11 struct RelayHandler {
     12     let sub_id: String
     13     let callback: (RelayURL, NostrConnectionEvent) -> ()
     14 }
     15 
     16 struct QueuedRequest {
     17     let req: NostrRequestType
     18     let relay: RelayURL
     19     let skip_ephemeral: Bool
     20 }
     21 
     22 /// Establishes and manages connections and subscriptions to a list of relays.
     23 class RelayPool {
     24     private(set) var relays: [Relay] = []
     25     var handlers: [RelayHandler] = []
     26     var request_queue: [QueuedRequest] = []
     27     var seen: [NoteId: Set<RelayURL>] = [:]
     28     var counts: [RelayURL: UInt64] = [:]
     29     var ndb: Ndb
     30     /// The keypair used to authenticate with relays
     31     var keypair: Keypair?
     32     var message_received_function: (((String, RelayDescriptor)) -> Void)?
     33     var message_sent_function: (((String, Relay)) -> Void)?
     34 
     35     private let network_monitor = NWPathMonitor()
     36     private let network_monitor_queue = DispatchQueue(label: "io.damus.network_monitor")
     37     private var last_network_status: NWPath.Status = .unsatisfied
     38 
     39     func close() {
     40         disconnect()
     41         relays = []
     42         handlers = []
     43         request_queue = []
     44         seen.removeAll()
     45         counts = [:]
     46         keypair = nil
     47     }
     48 
     49     init(ndb: Ndb, keypair: Keypair? = nil) {
     50         self.ndb = ndb
     51         self.keypair = keypair
     52 
     53         network_monitor.pathUpdateHandler = { [weak self] path in
     54             if (path.status == .satisfied || path.status == .requiresConnection) && self?.last_network_status != path.status {
     55                 DispatchQueue.main.async {
     56                     self?.connect_to_disconnected()
     57                 }
     58             }
     59             
     60             if let self, path.status != self.last_network_status {
     61                 for relay in self.relays {
     62                     relay.connection.log?.add("Network state: \(path.status)")
     63                 }
     64             }
     65             
     66             self?.last_network_status = path.status
     67         }
     68         network_monitor.start(queue: network_monitor_queue)
     69     }
     70     
     71     var our_descriptors: [RelayDescriptor] {
     72         return all_descriptors.filter { d in !d.ephemeral }
     73     }
     74     
     75     var all_descriptors: [RelayDescriptor] {
     76         relays.map { r in r.descriptor }
     77     }
     78     
     79     var num_connected: Int {
     80         return relays.reduce(0) { n, r in n + (r.connection.isConnected ? 1 : 0) }
     81     }
     82 
     83     func remove_handler(sub_id: String) {
     84         self.handlers = handlers.filter { $0.sub_id != sub_id }
     85         print("removing \(sub_id) handler, current: \(handlers.count)")
     86     }
     87     
     88     func ping() {
     89         Log.info("Pinging %d relays", for: .networking, relays.count)
     90         for relay in relays {
     91             relay.connection.ping()
     92         }
     93     }
     94 
     95     func register_handler(sub_id: String, handler: @escaping (RelayURL, NostrConnectionEvent) -> ()) {
     96         for handler in handlers {
     97             // don't add duplicate handlers
     98             if handler.sub_id == sub_id {
     99                 return
    100             }
    101         }
    102         self.handlers.append(RelayHandler(sub_id: sub_id, callback: handler))
    103         print("registering \(sub_id) handler, current: \(self.handlers.count)")
    104     }
    105 
    106     func remove_relay(_ relay_id: RelayURL) {
    107         var i: Int = 0
    108 
    109         self.disconnect(to: [relay_id])
    110         
    111         for relay in relays {
    112             if relay.id == relay_id {
    113                 relay.connection.disablePermanently()
    114                 relays.remove(at: i)
    115                 break
    116             }
    117             
    118             i += 1
    119         }
    120     }
    121 
    122     func add_relay(_ desc: RelayDescriptor) throws(RelayError) {
    123         let relay_id = desc.url
    124         if get_relay(relay_id) != nil {
    125             throw RelayError.RelayAlreadyExists
    126         }
    127         let conn = RelayConnection(url: desc.url, handleEvent: { event in
    128             self.handle_event(relay_id: relay_id, event: event)
    129         }, processUnverifiedWSEvent: { wsev in
    130             guard case .message(let msg) = wsev,
    131                   case .string(let str) = msg
    132             else { return }
    133 
    134             let _ = self.ndb.process_event(str)
    135             self.message_received_function?((str, desc))
    136         })
    137         let relay = Relay(descriptor: desc, connection: conn)
    138         self.relays.append(relay)
    139     }
    140 
    141     func setLog(_ log: RelayLog, for relay_id: RelayURL) {
    142         // add the current network state to the log
    143         log.add("Network state: \(network_monitor.currentPath.status)")
    144 
    145         get_relay(relay_id)?.connection.log = log
    146     }
    147     
    148     /// This is used to retry dead connections
    149     func connect_to_disconnected() {
    150         for relay in relays {
    151             let c = relay.connection
    152             
    153             let is_connecting = c.isConnecting
    154 
    155             if is_connecting && (Date.now.timeIntervalSince1970 - c.last_connection_attempt) > 5 {
    156                 print("stale connection detected (\(relay.descriptor.url.absoluteString)). retrying...")
    157                 relay.connection.reconnect()
    158             } else if relay.is_broken || is_connecting || c.isConnected {
    159                 continue
    160             } else {
    161                 relay.connection.reconnect()
    162             }
    163             
    164         }
    165     }
    166 
    167     func reconnect(to: [RelayURL]? = nil) {
    168         let relays = to.map{ get_relays($0) } ?? self.relays
    169         for relay in relays {
    170             // don't try to reconnect to broken relays
    171             relay.connection.reconnect()
    172         }
    173     }
    174 
    175     func connect(to: [RelayURL]? = nil) {
    176         let relays = to.map{ get_relays($0) } ?? self.relays
    177         for relay in relays {
    178             relay.connection.connect()
    179         }
    180     }
    181 
    182     func disconnect(to: [RelayURL]? = nil) {
    183         let relays = to.map{ get_relays($0) } ?? self.relays
    184         for relay in relays {
    185             relay.connection.disconnect()
    186         }
    187     }
    188 
    189     func unsubscribe(sub_id: String, to: [RelayURL]? = nil) {
    190         if to == nil {
    191             self.remove_handler(sub_id: sub_id)
    192         }
    193         self.send(.unsubscribe(sub_id), to: to)
    194     }
    195 
    196     func subscribe(sub_id: String, filters: [NostrFilter], handler: @escaping (RelayURL, NostrConnectionEvent) -> (), to: [RelayURL]? = nil) {
    197         register_handler(sub_id: sub_id, handler: handler)
    198         send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
    199     }
    200     
    201     /// Subscribes to data from the `RelayPool` based on a filter and a list of desired relays.
    202     /// 
    203     /// - Parameters:
    204     ///   - filters: The filters specifying the desired content.
    205     ///   - desiredRelays: The desired relays which to subsctibe to. If `nil`, it defaults to the `RelayPool`'s default list
    206     ///   - eoseTimeout: The maximum timeout which to give up waiting for the eoseSignal, in seconds
    207     /// - Returns: Returns an async stream that callers can easily consume via a for-loop
    208     func subscribe(filters: [NostrFilter], to desiredRelays: [RelayURL]? = nil, eoseTimeout: TimeInterval = 10) -> AsyncStream<StreamItem> {
    209         let desiredRelays = desiredRelays ?? self.relays.map({ $0.descriptor.url })
    210         return AsyncStream<StreamItem> { continuation in
    211             let sub_id = UUID().uuidString
    212             var seenEvents: Set<NoteId> = []
    213             var relaysWhoFinishedInitialResults: Set<RelayURL> = []
    214             var eoseSent = false
    215             self.subscribe(sub_id: sub_id, filters: filters, handler: { (relayUrl, connectionEvent) in
    216                 switch connectionEvent {
    217                 case .ws_connection_event(let ev):
    218                     // Websocket events such as connect/disconnect/error are already handled in `RelayConnection`. Do not perform any handling here.
    219                     // For the future, perhaps we should abstract away `.ws_connection_event` in `RelayPool`? Seems like something to be handled on the `RelayConnection` layer.
    220                     break
    221                 case .nostr_event(let nostrResponse):
    222                     guard nostrResponse.subid == sub_id else { return } // Do not stream items that do not belong in this subscription
    223                     switch nostrResponse {
    224                     case .event(_, let nostrEvent):
    225                         if seenEvents.contains(nostrEvent.id) { break } // Don't send two of the same events.
    226                         continuation.yield(with: .success(.event(nostrEvent)))
    227                         seenEvents.insert(nostrEvent.id)
    228                     case .notice(let note):
    229                         break   // We do not support handling these yet
    230                     case .eose(_):
    231                         relaysWhoFinishedInitialResults.insert(relayUrl)
    232                         if relaysWhoFinishedInitialResults == Set(desiredRelays) {
    233                             continuation.yield(with: .success(.eose))
    234                             eoseSent = true
    235                         }
    236                     case .ok(_): break    // No need to handle this, we are not sending an event to the relay
    237                     case .auth(_): break    // Handled in a separate function in RelayPool
    238                     }
    239                 }
    240             }, to: desiredRelays)
    241             Task {
    242                 try? await Task.sleep(nanoseconds: 1_000_000_000 * UInt64(eoseTimeout))
    243                 if !eoseSent { continuation.yield(with: .success(.eose)) }
    244             }
    245             continuation.onTermination = { @Sendable _ in
    246                 self.unsubscribe(sub_id: sub_id, to: desiredRelays)
    247                 self.remove_handler(sub_id: sub_id)
    248             }
    249         }
    250     }
    251     
    252     enum StreamItem {
    253         /// A Nostr event
    254         case event(NostrEvent)
    255         /// The "end of stored events" signal
    256         case eose
    257     }
    258 
    259     func subscribe_to(sub_id: String, filters: [NostrFilter], to: [RelayURL]?, handler: @escaping (RelayURL, NostrConnectionEvent) -> ()) {
    260         register_handler(sub_id: sub_id, handler: handler)
    261         send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
    262     }
    263 
    264     func count_queued(relay: RelayURL) -> Int {
    265         var c = 0
    266         for request in request_queue {
    267             if request.relay == relay {
    268                 c += 1
    269             }
    270         }
    271         
    272         return c
    273     }
    274 
    275     func queue_req(r: NostrRequestType, relay: RelayURL, skip_ephemeral: Bool) {
    276         let count = count_queued(relay: relay)
    277         guard count <= 10 else {
    278             print("can't queue, too many queued events for \(relay)")
    279             return
    280         }
    281         
    282         print("queueing request for \(relay)")
    283         request_queue.append(QueuedRequest(req: r, relay: relay, skip_ephemeral: skip_ephemeral))
    284     }
    285     
    286     func send_raw_to_local_ndb(_ req: NostrRequestType) {
    287         // send to local relay (nostrdb)
    288         switch req {
    289             case .typical(let r):
    290                 if case .event = r, let rstr = make_nostr_req(r) {
    291                     let _ = ndb.process_client_event(rstr)
    292                 }
    293             case .custom(let string):
    294                 let _ = ndb.process_client_event(string)
    295         }
    296     }
    297 
    298     func send_raw(_ req: NostrRequestType, to: [RelayURL]? = nil, skip_ephemeral: Bool = true) {
    299         let relays = to.map{ get_relays($0) } ?? self.relays
    300 
    301         self.send_raw_to_local_ndb(req)     // Always send Nostr events and data to NostrDB for a local copy
    302 
    303         for relay in relays {
    304             if req.is_read && !(relay.descriptor.info.canRead) {
    305                 continue    // Do not send read requests to relays that are not READ relays
    306             }
    307             
    308             if req.is_write && !(relay.descriptor.info.canWrite) {
    309                 continue    // Do not send write requests to relays that are not WRITE relays
    310             }
    311             
    312             if relay.descriptor.ephemeral && skip_ephemeral {
    313                 continue    // Do not send requests to ephemeral relays if we want to skip them
    314             }
    315             
    316             guard relay.connection.isConnected else {
    317                 queue_req(r: req, relay: relay.id, skip_ephemeral: skip_ephemeral)
    318                 continue
    319             }
    320             
    321             relay.connection.send(req, callback: { str in
    322                 self.message_sent_function?((str, relay))
    323             })
    324         }
    325     }
    326 
    327     func send(_ req: NostrRequest, to: [RelayURL]? = nil, skip_ephemeral: Bool = true) {
    328         send_raw(.typical(req), to: to, skip_ephemeral: skip_ephemeral)
    329     }
    330 
    331     func get_relays(_ ids: [RelayURL]) -> [Relay] {
    332         // don't include ephemeral relays in the default list to query
    333         relays.filter { ids.contains($0.id) }
    334     }
    335 
    336     func get_relay(_ id: RelayURL) -> Relay? {
    337         relays.first(where: { $0.id == id })
    338     }
    339 
    340     func run_queue(_ relay_id: RelayURL) {
    341         self.request_queue = request_queue.reduce(into: Array<QueuedRequest>()) { (q, req) in
    342             guard req.relay == relay_id else {
    343                 q.append(req)
    344                 return
    345             }
    346             
    347             print("running queueing request: \(req.req) for \(relay_id)")
    348             self.send_raw(req.req, to: [relay_id], skip_ephemeral: false)
    349         }
    350     }
    351 
    352     func record_seen(relay_id: RelayURL, event: NostrConnectionEvent) {
    353         if case .nostr_event(let ev) = event {
    354             if case .event(_, let nev) = ev {
    355                 if seen[nev.id]?.contains(relay_id) == true {
    356                     return
    357                 }
    358                 seen[nev.id, default: Set()].insert(relay_id)
    359                 counts[relay_id, default: 0] += 1
    360                 notify(.update_stats(note_id: nev.id))
    361             }
    362         }
    363     }
    364 
    365     func handle_event(relay_id: RelayURL, event: NostrConnectionEvent) {
    366         record_seen(relay_id: relay_id, event: event)
    367 
    368         // run req queue when we reconnect
    369         if case .ws_connection_event(let ws) = event {
    370             if case .connected = ws {
    371                 run_queue(relay_id)
    372             }
    373         }
    374 
    375         // Handle auth
    376         if case let .nostr_event(nostrResponse) = event,
    377            case let .auth(challenge_string) = nostrResponse {
    378             if let relay = get_relay(relay_id) {
    379                 print("received auth request from \(relay.descriptor.url.id)")
    380                 relay.authentication_state = .pending
    381                 if let keypair {
    382                     if let fullKeypair = keypair.to_full() {
    383                         if let authRequest = make_auth_request(keypair: fullKeypair, challenge_string: challenge_string, relay: relay) {
    384                             send(.auth(authRequest), to: [relay_id], skip_ephemeral: false)
    385                             relay.authentication_state = .verified
    386                         } else {
    387                             print("failed to make auth request")
    388                         }
    389                     } else {
    390                         print("keypair provided did not contain private key, can not sign auth request")
    391                         relay.authentication_state = .error(.no_private_key)
    392                     }
    393                 } else {
    394                     print("no keypair to reply to auth request")
    395                     relay.authentication_state = .error(.no_key)
    396                 }
    397             } else {
    398                 print("no relay found for \(relay_id)")
    399             }
    400         }
    401 
    402         for handler in handlers {
    403             handler.callback(relay_id, event)
    404         }
    405     }
    406 }
    407 
    408 func add_rw_relay(_ pool: RelayPool, _ url: RelayURL) {
    409     try? pool.add_relay(RelayPool.RelayDescriptor(url: url, info: .readWrite))
    410 }
    411 
    412