damus

nostr ios client
git clone git://jb55.com/damus
Log | Files | Refs | README | LICENSE

RelayPool.swift (11799B)


      1 //
      2 //  RelayPool.swift
      3 //  damus
      4 //
      5 //  Created by William Casarin on 2022-04-11.
      6 //
      7 
      8 import Foundation
      9 import Network
     10 
     11 struct RelayHandler {
     12     let sub_id: String
     13     let callback: (RelayURL, NostrConnectionEvent) -> ()
     14 }
     15 
     16 struct QueuedRequest {
     17     let req: NostrRequestType
     18     let relay: RelayURL
     19     let skip_ephemeral: Bool
     20 }
     21 
     22 struct SeenEvent: Hashable {
     23     let relay_id: RelayURL
     24     let evid: NoteId
     25 }
     26 
     27 class RelayPool {
     28     var relays: [Relay] = []
     29     var handlers: [RelayHandler] = []
     30     var request_queue: [QueuedRequest] = []
     31     var seen: Set<SeenEvent> = Set()
     32     var counts: [RelayURL: UInt64] = [:]
     33     var ndb: Ndb
     34     var keypair: Keypair?
     35     var message_received_function: (((String, RelayDescriptor)) -> Void)?
     36     var message_sent_function: (((String, Relay)) -> Void)?
     37 
     38     private let network_monitor = NWPathMonitor()
     39     private let network_monitor_queue = DispatchQueue(label: "io.damus.network_monitor")
     40     private var last_network_status: NWPath.Status = .unsatisfied
     41 
     42     func close() {
     43         disconnect()
     44         relays = []
     45         handlers = []
     46         request_queue = []
     47         seen.removeAll()
     48         counts = [:]
     49         keypair = nil
     50     }
     51 
     52     init(ndb: Ndb, keypair: Keypair? = nil) {
     53         self.ndb = ndb
     54         self.keypair = keypair
     55 
     56         network_monitor.pathUpdateHandler = { [weak self] path in
     57             if (path.status == .satisfied || path.status == .requiresConnection) && self?.last_network_status != path.status {
     58                 DispatchQueue.main.async {
     59                     self?.connect_to_disconnected()
     60                 }
     61             }
     62             
     63             if let self, path.status != self.last_network_status {
     64                 for relay in self.relays {
     65                     relay.connection.log?.add("Network state: \(path.status)")
     66                 }
     67             }
     68             
     69             self?.last_network_status = path.status
     70         }
     71         network_monitor.start(queue: network_monitor_queue)
     72     }
     73     
     74     var our_descriptors: [RelayDescriptor] {
     75         return all_descriptors.filter { d in !d.ephemeral }
     76     }
     77     
     78     var all_descriptors: [RelayDescriptor] {
     79         relays.map { r in r.descriptor }
     80     }
     81     
     82     var num_connected: Int {
     83         return relays.reduce(0) { n, r in n + (r.connection.isConnected ? 1 : 0) }
     84     }
     85 
     86     func remove_handler(sub_id: String) {
     87         self.handlers = handlers.filter { $0.sub_id != sub_id }
     88         print("removing \(sub_id) handler, current: \(handlers.count)")
     89     }
     90     
     91     func ping() {
     92         Log.info("Pinging %d relays", for: .networking, relays.count)
     93         for relay in relays {
     94             relay.connection.ping()
     95         }
     96     }
     97 
     98     func register_handler(sub_id: String, handler: @escaping (RelayURL, NostrConnectionEvent) -> ()) {
     99         for handler in handlers {
    100             // don't add duplicate handlers
    101             if handler.sub_id == sub_id {
    102                 return
    103             }
    104         }
    105         self.handlers.append(RelayHandler(sub_id: sub_id, callback: handler))
    106         print("registering \(sub_id) handler, current: \(self.handlers.count)")
    107     }
    108 
    109     func remove_relay(_ relay_id: RelayURL) {
    110         var i: Int = 0
    111 
    112         self.disconnect(to: [relay_id])
    113         
    114         for relay in relays {
    115             if relay.id == relay_id {
    116                 relay.connection.disablePermanently()
    117                 relays.remove(at: i)
    118                 break
    119             }
    120             
    121             i += 1
    122         }
    123     }
    124 
    125     func add_relay(_ desc: RelayDescriptor) throws {
    126         let relay_id = desc.url
    127         if get_relay(relay_id) != nil {
    128             throw RelayError.RelayAlreadyExists
    129         }
    130         let conn = RelayConnection(url: desc.url, handleEvent: { event in
    131             self.handle_event(relay_id: relay_id, event: event)
    132         }, processEvent: { wsev in
    133             guard case .message(let msg) = wsev,
    134                   case .string(let str) = msg
    135             else { return }
    136 
    137             let _ = self.ndb.process_event(str)
    138             self.message_received_function?((str, desc))
    139         })
    140         let relay = Relay(descriptor: desc, connection: conn)
    141         self.relays.append(relay)
    142     }
    143 
    144     func setLog(_ log: RelayLog, for relay_id: RelayURL) {
    145         // add the current network state to the log
    146         log.add("Network state: \(network_monitor.currentPath.status)")
    147 
    148         get_relay(relay_id)?.connection.log = log
    149     }
    150     
    151     /// This is used to retry dead connections
    152     func connect_to_disconnected() {
    153         for relay in relays {
    154             let c = relay.connection
    155             
    156             let is_connecting = c.isConnecting
    157 
    158             if is_connecting && (Date.now.timeIntervalSince1970 - c.last_connection_attempt) > 5 {
    159                 print("stale connection detected (\(relay.descriptor.url.absoluteString)). retrying...")
    160                 relay.connection.reconnect()
    161             } else if relay.is_broken || is_connecting || c.isConnected {
    162                 continue
    163             } else {
    164                 relay.connection.reconnect()
    165             }
    166             
    167         }
    168     }
    169 
    170     func reconnect(to: [RelayURL]? = nil) {
    171         let relays = to.map{ get_relays($0) } ?? self.relays
    172         for relay in relays {
    173             // don't try to reconnect to broken relays
    174             relay.connection.reconnect()
    175         }
    176     }
    177 
    178     func connect(to: [RelayURL]? = nil) {
    179         let relays = to.map{ get_relays($0) } ?? self.relays
    180         for relay in relays {
    181             relay.connection.connect()
    182         }
    183     }
    184 
    185     func disconnect(to: [RelayURL]? = nil) {
    186         let relays = to.map{ get_relays($0) } ?? self.relays
    187         for relay in relays {
    188             relay.connection.disconnect()
    189         }
    190     }
    191 
    192     func unsubscribe(sub_id: String, to: [RelayURL]? = nil) {
    193         if to == nil {
    194             self.remove_handler(sub_id: sub_id)
    195         }
    196         self.send(.unsubscribe(sub_id), to: to)
    197     }
    198 
    199     func subscribe(sub_id: String, filters: [NostrFilter], handler: @escaping (RelayURL, NostrConnectionEvent) -> (), to: [RelayURL]? = nil) {
    200         register_handler(sub_id: sub_id, handler: handler)
    201         send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
    202     }
    203 
    204     func subscribe_to(sub_id: String, filters: [NostrFilter], to: [RelayURL]?, handler: @escaping (RelayURL, NostrConnectionEvent) -> ()) {
    205         register_handler(sub_id: sub_id, handler: handler)
    206         send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
    207     }
    208 
    209     func count_queued(relay: RelayURL) -> Int {
    210         var c = 0
    211         for request in request_queue {
    212             if request.relay == relay {
    213                 c += 1
    214             }
    215         }
    216         
    217         return c
    218     }
    219 
    220     func queue_req(r: NostrRequestType, relay: RelayURL, skip_ephemeral: Bool) {
    221         let count = count_queued(relay: relay)
    222         guard count <= 10 else {
    223             print("can't queue, too many queued events for \(relay)")
    224             return
    225         }
    226         
    227         print("queueing request for \(relay)")
    228         request_queue.append(QueuedRequest(req: r, relay: relay, skip_ephemeral: skip_ephemeral))
    229     }
    230     
    231     func send_raw_to_local_ndb(_ req: NostrRequestType) {
    232         // send to local relay (nostrdb)
    233         switch req {
    234             case .typical(let r):
    235                 if case .event = r, let rstr = make_nostr_req(r) {
    236                     let _ = ndb.process_client_event(rstr)
    237                 }
    238             case .custom(let string):
    239                 let _ = ndb.process_client_event(string)
    240         }
    241     }
    242 
    243     func send_raw(_ req: NostrRequestType, to: [RelayURL]? = nil, skip_ephemeral: Bool = true) {
    244         let relays = to.map{ get_relays($0) } ?? self.relays
    245 
    246         self.send_raw_to_local_ndb(req)
    247 
    248         for relay in relays {
    249             if req.is_read && !(relay.descriptor.info.read ?? true) {
    250                 continue
    251             }
    252             
    253             if req.is_write && !(relay.descriptor.info.write ?? true) {
    254                 continue
    255             }
    256             
    257             if relay.descriptor.ephemeral && skip_ephemeral {
    258                 continue
    259             }
    260             
    261             guard relay.connection.isConnected else {
    262                 queue_req(r: req, relay: relay.id, skip_ephemeral: skip_ephemeral)
    263                 continue
    264             }
    265             
    266             relay.connection.send(req, callback: { str in
    267                 self.message_sent_function?((str, relay))
    268             })
    269         }
    270     }
    271 
    272     func send(_ req: NostrRequest, to: [RelayURL]? = nil, skip_ephemeral: Bool = true) {
    273         send_raw(.typical(req), to: to, skip_ephemeral: skip_ephemeral)
    274     }
    275 
    276     func get_relays(_ ids: [RelayURL]) -> [Relay] {
    277         // don't include ephemeral relays in the default list to query
    278         relays.filter { ids.contains($0.id) }
    279     }
    280 
    281     func get_relay(_ id: RelayURL) -> Relay? {
    282         relays.first(where: { $0.id == id })
    283     }
    284 
    285     func run_queue(_ relay_id: RelayURL) {
    286         self.request_queue = request_queue.reduce(into: Array<QueuedRequest>()) { (q, req) in
    287             guard req.relay == relay_id else {
    288                 q.append(req)
    289                 return
    290             }
    291             
    292             print("running queueing request: \(req.req) for \(relay_id)")
    293             self.send_raw(req.req, to: [relay_id], skip_ephemeral: false)
    294         }
    295     }
    296 
    297     func record_seen(relay_id: RelayURL, event: NostrConnectionEvent) {
    298         if case .nostr_event(let ev) = event {
    299             if case .event(_, let nev) = ev {
    300                 let k = SeenEvent(relay_id: relay_id, evid: nev.id)
    301                 if !seen.contains(k) {
    302                     seen.insert(k)
    303                     if counts[relay_id] == nil {
    304                         counts[relay_id] = 1
    305                     } else {
    306                         counts[relay_id] = (counts[relay_id] ?? 0) + 1
    307                     }
    308                 }
    309             }
    310         }
    311     }
    312 
    313     func handle_event(relay_id: RelayURL, event: NostrConnectionEvent) {
    314         record_seen(relay_id: relay_id, event: event)
    315 
    316         // run req queue when we reconnect
    317         if case .ws_event(let ws) = event {
    318             if case .connected = ws {
    319                 run_queue(relay_id)
    320             }
    321         }
    322 
    323         // Handle auth
    324         if case let .nostr_event(nostrResponse) = event,
    325            case let .auth(challenge_string) = nostrResponse {
    326             if let relay = get_relay(relay_id) {
    327                 print("received auth request from \(relay.descriptor.url.id)")
    328                 relay.authentication_state = .pending
    329                 if let keypair {
    330                     if let fullKeypair = keypair.to_full() {
    331                         if let authRequest = make_auth_request(keypair: fullKeypair, challenge_string: challenge_string, relay: relay) {
    332                             send(.auth(authRequest), to: [relay_id], skip_ephemeral: false)
    333                             relay.authentication_state = .verified
    334                         } else {
    335                             print("failed to make auth request")
    336                         }
    337                     } else {
    338                         print("keypair provided did not contain private key, can not sign auth request")
    339                         relay.authentication_state = .error(.no_private_key)
    340                     }
    341                 } else {
    342                     print("no keypair to reply to auth request")
    343                     relay.authentication_state = .error(.no_key)
    344                 }
    345             } else {
    346                 print("no relay found for \(relay_id)")
    347             }
    348         }
    349 
    350         for handler in handlers {
    351             handler.callback(relay_id, event)
    352         }
    353     }
    354 }
    355 
    356 func add_rw_relay(_ pool: RelayPool, _ url: RelayURL) {
    357     try? pool.add_relay(RelayDescriptor(url: url, info: .rw))
    358 }
    359 
    360