damus

Nostr iOS client
git clone git://jb55.com/damus
Log | Files | Refs | README | LICENSE

RelayPool.swift (11596B)


      1 //
      2 //  RelayPool.swift
      3 //  damus
      4 //
      5 //  Created by William Casarin on 2022-04-11.
      6 //
      7 
      8 import Foundation
      9 import Network
     10 
     11 struct RelayHandler {
     12     let sub_id: String
     13     let callback: (RelayURL, NostrConnectionEvent) -> ()
     14 }
     15 
     16 struct QueuedRequest {
     17     let req: NostrRequestType
     18     let relay: RelayURL
     19     let skip_ephemeral: Bool
     20 }
     21 
     22 struct SeenEvent: Hashable {
     23     let relay_id: RelayURL
     24     let evid: NoteId
     25 }
     26 
     27 class RelayPool {
     28     var relays: [Relay] = []
     29     var handlers: [RelayHandler] = []
     30     var request_queue: [QueuedRequest] = []
     31     var seen: Set<SeenEvent> = Set()
     32     var counts: [RelayURL: UInt64] = [:]
     33     var ndb: Ndb
     34     var keypair: Keypair?
     35     var message_received_function: (((String, RelayDescriptor)) -> Void)?
     36     var message_sent_function: (((String, Relay)) -> Void)?
     37 
     38     private let network_monitor = NWPathMonitor()
     39     private let network_monitor_queue = DispatchQueue(label: "io.damus.network_monitor")
     40     private var last_network_status: NWPath.Status = .unsatisfied
     41 
     42     func close() {
     43         disconnect()
     44         relays = []
     45         handlers = []
     46         request_queue = []
     47         seen.removeAll()
     48         counts = [:]
     49         keypair = nil
     50     }
     51 
     52     init(ndb: Ndb, keypair: Keypair? = nil) {
     53         self.ndb = ndb
     54         self.keypair = keypair
     55 
     56         network_monitor.pathUpdateHandler = { [weak self] path in
     57             if (path.status == .satisfied || path.status == .requiresConnection) && self?.last_network_status != path.status {
     58                 DispatchQueue.main.async {
     59                     self?.connect_to_disconnected()
     60                 }
     61             }
     62             
     63             if let self, path.status != self.last_network_status {
     64                 for relay in self.relays {
     65                     relay.connection.log?.add("Network state: \(path.status)")
     66                 }
     67             }
     68             
     69             self?.last_network_status = path.status
     70         }
     71         network_monitor.start(queue: network_monitor_queue)
     72     }
     73     
     74     var our_descriptors: [RelayDescriptor] {
     75         return all_descriptors.filter { d in !d.ephemeral }
     76     }
     77     
     78     var all_descriptors: [RelayDescriptor] {
     79         relays.map { r in r.descriptor }
     80     }
     81     
     82     var num_connected: Int {
     83         return relays.reduce(0) { n, r in n + (r.connection.isConnected ? 1 : 0) }
     84     }
     85 
     86     func remove_handler(sub_id: String) {
     87         self.handlers = handlers.filter { $0.sub_id != sub_id }
     88         print("removing \(sub_id) handler, current: \(handlers.count)")
     89     }
     90     
     91     func ping() {
     92         for relay in relays {
     93             relay.connection.ping()
     94         }
     95     }
     96 
     97     func register_handler(sub_id: String, handler: @escaping (RelayURL, NostrConnectionEvent) -> ()) {
     98         for handler in handlers {
     99             // don't add duplicate handlers
    100             if handler.sub_id == sub_id {
    101                 return
    102             }
    103         }
    104         self.handlers.append(RelayHandler(sub_id: sub_id, callback: handler))
    105         print("registering \(sub_id) handler, current: \(self.handlers.count)")
    106     }
    107 
    108     func remove_relay(_ relay_id: RelayURL) {
    109         var i: Int = 0
    110 
    111         self.disconnect(to: [relay_id])
    112         
    113         for relay in relays {
    114             if relay.id == relay_id {
    115                 relay.connection.disablePermanently()
    116                 relays.remove(at: i)
    117                 break
    118             }
    119             
    120             i += 1
    121         }
    122     }
    123 
    124     func add_relay(_ desc: RelayDescriptor) throws {
    125         let relay_id = desc.url
    126         if get_relay(relay_id) != nil {
    127             throw RelayError.RelayAlreadyExists
    128         }
    129         let conn = RelayConnection(url: desc.url, handleEvent: { event in
    130             self.handle_event(relay_id: relay_id, event: event)
    131         }, processEvent: { wsev in
    132             guard case .message(let msg) = wsev,
    133                   case .string(let str) = msg
    134             else { return }
    135 
    136             let _ = self.ndb.process_event(str)
    137             self.message_received_function?((str, desc))
    138         })
    139         let relay = Relay(descriptor: desc, connection: conn)
    140         self.relays.append(relay)
    141     }
    142 
    143     func setLog(_ log: RelayLog, for relay_id: RelayURL) {
    144         // add the current network state to the log
    145         log.add("Network state: \(network_monitor.currentPath.status)")
    146 
    147         get_relay(relay_id)?.connection.log = log
    148     }
    149     
    150     /// This is used to retry dead connections
    151     func connect_to_disconnected() {
    152         for relay in relays {
    153             let c = relay.connection
    154             
    155             let is_connecting = c.isConnecting
    156 
    157             if is_connecting && (Date.now.timeIntervalSince1970 - c.last_connection_attempt) > 5 {
    158                 print("stale connection detected (\(relay.descriptor.url.absoluteString)). retrying...")
    159                 relay.connection.reconnect()
    160             } else if relay.is_broken || is_connecting || c.isConnected {
    161                 continue
    162             } else {
    163                 relay.connection.reconnect()
    164             }
    165             
    166         }
    167     }
    168 
    169     func reconnect(to: [RelayURL]? = nil) {
    170         let relays = to.map{ get_relays($0) } ?? self.relays
    171         for relay in relays {
    172             // don't try to reconnect to broken relays
    173             relay.connection.reconnect()
    174         }
    175     }
    176 
    177     func connect(to: [RelayURL]? = nil) {
    178         let relays = to.map{ get_relays($0) } ?? self.relays
    179         for relay in relays {
    180             relay.connection.connect()
    181         }
    182     }
    183 
    184     func disconnect(to: [RelayURL]? = nil) {
    185         let relays = to.map{ get_relays($0) } ?? self.relays
    186         for relay in relays {
    187             relay.connection.disconnect()
    188         }
    189     }
    190 
    191     func unsubscribe(sub_id: String, to: [RelayURL]? = nil) {
    192         if to == nil {
    193             self.remove_handler(sub_id: sub_id)
    194         }
    195         self.send(.unsubscribe(sub_id), to: to)
    196     }
    197 
    198     func subscribe(sub_id: String, filters: [NostrFilter], handler: @escaping (RelayURL, NostrConnectionEvent) -> (), to: [RelayURL]? = nil) {
    199         register_handler(sub_id: sub_id, handler: handler)
    200         send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
    201     }
    202 
    203     func subscribe_to(sub_id: String, filters: [NostrFilter], to: [RelayURL]?, handler: @escaping (RelayURL, NostrConnectionEvent) -> ()) {
    204         register_handler(sub_id: sub_id, handler: handler)
    205         send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
    206     }
    207 
    208     func count_queued(relay: RelayURL) -> Int {
    209         var c = 0
    210         for request in request_queue {
    211             if request.relay == relay {
    212                 c += 1
    213             }
    214         }
    215         
    216         return c
    217     }
    218 
    219     func queue_req(r: NostrRequestType, relay: RelayURL, skip_ephemeral: Bool) {
    220         let count = count_queued(relay: relay)
    221         guard count <= 10 else {
    222             print("can't queue, too many queued events for \(relay)")
    223             return
    224         }
    225         
    226         print("queueing request for \(relay)")
    227         request_queue.append(QueuedRequest(req: r, relay: relay, skip_ephemeral: skip_ephemeral))
    228     }
    229 
    230     func send_raw(_ req: NostrRequestType, to: [RelayURL]? = nil, skip_ephemeral: Bool = true) {
    231         let relays = to.map{ get_relays($0) } ?? self.relays
    232 
    233         // send to local relay (nostrdb)
    234         switch req {
    235         case .typical(let r):
    236             if case .event = r, let rstr = make_nostr_req(r) {
    237                 let _ = ndb.process_client_event(rstr)
    238             }
    239         case .custom(let string):
    240             let _ = ndb.process_client_event(string)
    241         }
    242 
    243         for relay in relays {
    244             if req.is_read && !(relay.descriptor.info.read ?? true) {
    245                 continue
    246             }
    247             
    248             if req.is_write && !(relay.descriptor.info.write ?? true) {
    249                 continue
    250             }
    251             
    252             if relay.descriptor.ephemeral && skip_ephemeral {
    253                 continue
    254             }
    255             
    256             guard relay.connection.isConnected else {
    257                 queue_req(r: req, relay: relay.id, skip_ephemeral: skip_ephemeral)
    258                 continue
    259             }
    260             
    261             relay.connection.send(req, callback: { str in
    262                 self.message_sent_function?((str, relay))
    263             })
    264         }
    265     }
    266 
    267     func send(_ req: NostrRequest, to: [RelayURL]? = nil, skip_ephemeral: Bool = true) {
    268         send_raw(.typical(req), to: to, skip_ephemeral: skip_ephemeral)
    269     }
    270 
    271     func get_relays(_ ids: [RelayURL]) -> [Relay] {
    272         // don't include ephemeral relays in the default list to query
    273         relays.filter { ids.contains($0.id) }
    274     }
    275 
    276     func get_relay(_ id: RelayURL) -> Relay? {
    277         relays.first(where: { $0.id == id })
    278     }
    279 
    280     func run_queue(_ relay_id: RelayURL) {
    281         self.request_queue = request_queue.reduce(into: Array<QueuedRequest>()) { (q, req) in
    282             guard req.relay == relay_id else {
    283                 q.append(req)
    284                 return
    285             }
    286             
    287             print("running queueing request: \(req.req) for \(relay_id)")
    288             self.send_raw(req.req, to: [relay_id], skip_ephemeral: false)
    289         }
    290     }
    291 
    292     func record_seen(relay_id: RelayURL, event: NostrConnectionEvent) {
    293         if case .nostr_event(let ev) = event {
    294             if case .event(_, let nev) = ev {
    295                 let k = SeenEvent(relay_id: relay_id, evid: nev.id)
    296                 if !seen.contains(k) {
    297                     seen.insert(k)
    298                     if counts[relay_id] == nil {
    299                         counts[relay_id] = 1
    300                     } else {
    301                         counts[relay_id] = (counts[relay_id] ?? 0) + 1
    302                     }
    303                 }
    304             }
    305         }
    306     }
    307 
    308     func handle_event(relay_id: RelayURL, event: NostrConnectionEvent) {
    309         record_seen(relay_id: relay_id, event: event)
    310 
    311         // run req queue when we reconnect
    312         if case .ws_event(let ws) = event {
    313             if case .connected = ws {
    314                 run_queue(relay_id)
    315             }
    316         }
    317 
    318         // Handle auth
    319         if case let .nostr_event(nostrResponse) = event,
    320            case let .auth(challenge_string) = nostrResponse {
    321             if let relay = get_relay(relay_id) {
    322                 print("received auth request from \(relay.descriptor.url.id)")
    323                 relay.authentication_state = .pending
    324                 if let keypair {
    325                     if let fullKeypair = keypair.to_full() {
    326                         if let authRequest = make_auth_request(keypair: fullKeypair, challenge_string: challenge_string, relay: relay) {
    327                             send(.auth(authRequest), to: [relay_id], skip_ephemeral: false)
    328                             relay.authentication_state = .verified
    329                         } else {
    330                             print("failed to make auth request")
    331                         }
    332                     } else {
    333                         print("keypair provided did not contain private key, can not sign auth request")
    334                         relay.authentication_state = .error(.no_private_key)
    335                     }
    336                 } else {
    337                     print("no keypair to reply to auth request")
    338                     relay.authentication_state = .error(.no_key)
    339                 }
    340             } else {
    341                 print("no relay found for \(relay_id)")
    342             }
    343         }
    344 
    345         for handler in handlers {
    346             handler.callback(relay_id, event)
    347         }
    348     }
    349 }
    350 
    351 func add_rw_relay(_ pool: RelayPool, _ url: RelayURL) {
    352     try? pool.add_relay(RelayDescriptor(url: url, info: .rw))
    353 }
    354 
    355