commit 252a77fd975c6b74becbf1ba47f899e971ad2011
parent a611a5d25240220a3302e28e0b7db57fbb7c82c9
Author: Bryan Montz <bryanmontz@me.com>
Date: Wed, 15 Mar 2023 10:37:37 -0600
Reduce battery usage by using exp backoff on connections
Changelog-Changed: Reduce battery usage by using exp backoff on connections
Diffstat:
12 files changed, 323 insertions(+), 161 deletions(-)
diff --git a/damus.xcodeproj/project.pbxproj b/damus.xcodeproj/project.pbxproj
@@ -222,6 +222,7 @@
4CF0ABF029857E9200D66079 /* Bech32Object.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4CF0ABEF29857E9200D66079 /* Bech32Object.swift */; };
4CF0ABF62985CD5500D66079 /* UserSearch.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4CF0ABF52985CD5500D66079 /* UserSearch.swift */; };
4FE60CDD295E1C5E00105A1F /* Wallet.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4FE60CDC295E1C5E00105A1F /* Wallet.swift */; };
+ 5023E76329AA3627007D3D50 /* RelayPoolTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5023E76229AA3627007D3D50 /* RelayPoolTests.swift */; };
50A50A8D29A09E1C00C01BE7 /* RequestTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 50A50A8C29A09E1C00C01BE7 /* RequestTests.swift */; };
5C513FBA297F72980072348F /* CustomPicker.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5C513FB9297F72980072348F /* CustomPicker.swift */; };
5C513FCC2984ACA60072348F /* QRCodeView.swift in Sources */ = {isa = PBXBuildFile; fileRef = 5C513FCB2984ACA60072348F /* QRCodeView.swift */; };
@@ -581,6 +582,7 @@
4CF0ABEF29857E9200D66079 /* Bech32Object.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Bech32Object.swift; sourceTree = "<group>"; };
4CF0ABF52985CD5500D66079 /* UserSearch.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = UserSearch.swift; sourceTree = "<group>"; };
4FE60CDC295E1C5E00105A1F /* Wallet.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Wallet.swift; sourceTree = "<group>"; };
+ 5023E76229AA3627007D3D50 /* RelayPoolTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RelayPoolTests.swift; sourceTree = "<group>"; };
50A50A8C29A09E1C00C01BE7 /* RequestTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = RequestTests.swift; sourceTree = "<group>"; };
5C513FB9297F72980072348F /* CustomPicker.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CustomPicker.swift; sourceTree = "<group>"; };
5C513FCB2984ACA60072348F /* QRCodeView.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = QRCodeView.swift; sourceTree = "<group>"; };
@@ -1070,6 +1072,7 @@
4CE6DEF627F7A08200C66700 /* damusTests */ = {
isa = PBXGroup;
children = (
+ 5023E76229AA3627007D3D50 /* RelayPoolTests.swift */,
50A50A8C29A09E1C00C01BE7 /* RequestTests.swift */,
DD597CBC2963D85A00C64D32 /* MarkdownTests.swift */,
4C90BD1B283AC38E008EE7EF /* Bech32Tests.swift */,
@@ -1571,6 +1574,7 @@
DD597CBD2963D85A00C64D32 /* MarkdownTests.swift in Sources */,
3A3040EF29A8FEE9008A0F29 /* EventDetailBarTests.swift in Sources */,
4C3EA67B28FF7B3900C48A62 /* InvoiceTests.swift in Sources */,
+ 5023E76329AA3627007D3D50 /* RelayPoolTests.swift in Sources */,
4C363A9E2828A822006E126D /* ReplyTests.swift in Sources */,
4CB883AA297612FF00DC99E7 /* ZapTests.swift in Sources */,
4CB8839A297322D200DC99E7 /* DMTests.swift in Sources */,
diff --git a/damus/Models/EventsModel.swift b/damus/Models/EventsModel.swift
@@ -31,9 +31,9 @@ class EventsModel: ObservableObject {
}
func subscribe() {
- state.pool.subscribe(sub_id: sub_id,
- filters: [get_filter()],
- handler: handle_nostr_event)
+ state.pool.subscribe_to(sub_id: sub_id,
+ filters: [get_filter()],
+ handler: handle_nostr_event)
}
func unsubscribe() {
diff --git a/damus/Models/FollowersModel.swift b/damus/Models/FollowersModel.swift
@@ -40,7 +40,7 @@ class FollowersModel: ObservableObject {
let filter = get_filter()
let filters = [filter]
print_filters(relay_id: "following", filters: [filters])
- self.damus_state.pool.subscribe(sub_id: sub_id, filters: filters, handler: handle_event)
+ self.damus_state.pool.subscribe_to(sub_id: sub_id, filters: filters, handler: handle_event)
}
func unsubscribe() {
diff --git a/damus/Models/FollowingModel.swift b/damus/Models/FollowingModel.swift
@@ -41,7 +41,7 @@ class FollowingModel {
}
let filters = [filter]
print_filters(relay_id: "following", filters: [filters])
- self.damus_state.pool.subscribe(sub_id: sub_id, filters: filters, handler: handle_event)
+ self.damus_state.pool.subscribe_to(sub_id: sub_id, filters: filters, handler: handle_event)
}
func unsubscribe() {
diff --git a/damus/Models/ProfileModel.swift b/damus/Models/ProfileModel.swift
@@ -83,8 +83,8 @@ class ProfileModel: ObservableObject, Equatable {
print("subscribing to profile \(pubkey) with sub_id \(sub_id)")
print_filters(relay_id: "profile", filters: [[text_filter], [profile_filter]])
- damus.pool.subscribe(sub_id: sub_id, filters: [text_filter], handler: handle_event)
- damus.pool.subscribe(sub_id: prof_subid, filters: [profile_filter], handler: handle_event)
+ damus.pool.subscribe_to(sub_id: sub_id, filters: [text_filter], handler: handle_event)
+ damus.pool.subscribe_to(sub_id: prof_subid, filters: [profile_filter], handler: handle_event)
}
func handle_profile_contact_event(_ ev: NostrEvent) {
diff --git a/damus/Models/SearchHomeModel.swift b/damus/Models/SearchHomeModel.swift
@@ -38,7 +38,7 @@ class SearchHomeModel: ObservableObject {
func subscribe() {
loading = true
let to_relays = determine_to_relays(pool: damus_state.pool, filters: damus_state.relay_filters)
- damus_state.pool.subscribe(sub_id: base_subid, filters: [get_base_filter()], handler: handle_event, to: to_relays)
+ damus_state.pool.subscribe_to(sub_id: base_subid, filters: [get_base_filter()], to: to_relays, handler: handle_event)
}
func unsubscribe(to: String? = nil) {
diff --git a/damus/Models/ThreadModel.swift b/damus/Models/ThreadModel.swift
@@ -104,8 +104,8 @@ class ThreadModel: ObservableObject {
print("subscribing to thread \(event.id) with sub_id \(base_subid)")
loading = true
- damus_state.pool.subscribe(sub_id: base_subid, filters: base_filters, handler: handle_event)
- damus_state.pool.subscribe(sub_id: meta_subid, filters: meta_filters, handler: handle_event)
+ damus_state.pool.subscribe_to(sub_id: base_subid, filters: base_filters, handler: handle_event)
+ damus_state.pool.subscribe_to(sub_id: meta_subid, filters: meta_filters, handler: handle_event)
}
func add_event(_ ev: NostrEvent, privkey: String?) {
diff --git a/damus/Models/ZapsModel.swift b/damus/Models/ZapsModel.swift
@@ -29,7 +29,7 @@ class ZapsModel: ObservableObject {
case .note(let note_target):
filter.referenced_ids = [note_target.note_id]
}
- state.pool.subscribe(sub_id: zaps_subid, filters: [filter], handler: handle_event)
+ state.pool.subscribe_to(sub_id: zaps_subid, filters: [filter], handler: handle_event)
}
func unsubscribe() {
diff --git a/damus/Nostr/RelayConnection.swift b/damus/Nostr/RelayConnection.swift
@@ -14,9 +14,15 @@ enum NostrConnectionEvent {
}
final class RelayConnection: WebSocketDelegate {
- private(set) var isConnected = false
- private(set) var isConnecting = false
- private(set) var isReconnecting = false
+ enum State {
+ case notConnected
+ case connecting
+ case reconnecting
+ case connected
+ case failed
+ }
+
+ private(set) var state: State = .notConnected
private(set) var last_connection_attempt: TimeInterval = 0
private lazy var socket = {
@@ -25,38 +31,36 @@ final class RelayConnection: WebSocketDelegate {
socket.delegate = self
return socket
}()
- private var handleEvent: (NostrConnectionEvent) -> ()
- private let url: URL
-
- init(url: URL, handleEvent: @escaping (NostrConnectionEvent) -> ()) {
+ private let eventHandler: (NostrConnectionEvent) -> ()
+ let url: URL
+
+ init(url: URL, eventHandler: @escaping (NostrConnectionEvent) -> ()) {
self.url = url
- self.handleEvent = handleEvent
+ self.eventHandler = eventHandler
}
func reconnect() {
- if isConnected {
- isReconnecting = true
+ if state == .connected {
+ state = .reconnecting
disconnect()
} else {
// we're already disconnected, so just connect
- connect(force: true)
+ connect()
}
}
func connect(force: Bool = false) {
- if !force && (isConnected || isConnecting) {
+ if !force && (state == .connected || state == .connecting) {
return
}
- isConnecting = true
+ state = .connecting
last_connection_attempt = Date().timeIntervalSince1970
socket.connect()
}
func disconnect() {
socket.disconnect()
- isConnected = false
- isConnecting = false
}
func send(_ req: NostrRequest) {
@@ -68,51 +72,52 @@ final class RelayConnection: WebSocketDelegate {
socket.write(string: req)
}
+ private func decodeEvent(_ txt: String) throws -> NostrConnectionEvent {
+ if let ev = decode_nostr_event(txt: txt) {
+ return .nostr_event(ev)
+ } else {
+ throw DecodingError.dataCorrupted(.init(codingPath: [], debugDescription: "decoding event failed"))
+ }
+ }
+
+ @MainActor
+ private func handleEvent(_ event: NostrConnectionEvent) async {
+ eventHandler(event)
+ }
+
// MARK: - WebSocketDelegate
func didReceive(event: WebSocketEvent, client: WebSocket) {
switch event {
case .connected:
- self.isConnected = true
- self.isConnecting = false
+ state = .connected
case .disconnected:
- self.isConnecting = false
- self.isConnected = false
- if self.isReconnecting {
- self.isReconnecting = false
- self.connect()
+ if state == .reconnecting {
+ connect()
+ } else {
+ state = .notConnected
}
case .cancelled, .error:
- self.isConnecting = false
- self.isConnected = false
+ state = .failed
case .text(let txt):
- if txt.count > 2000 {
- DispatchQueue.global(qos: .default).async {
- if let ev = decode_nostr_event(txt: txt) {
- DispatchQueue.main.async {
- self.handleEvent(.nostr_event(ev))
- }
- return
- }
- }
- } else {
- if let ev = decode_nostr_event(txt: txt) {
- handleEvent(.nostr_event(ev))
- return
+ Task(priority: .userInitiated) {
+ do {
+ let event = try decodeEvent(txt)
+ await handleEvent(event)
+ } catch {
+ print("decode failed for \(txt): \(error)")
+ // TODO: trigger event error
}
}
- print("decode failed for \(txt)")
- // TODO: trigger event error
-
default:
break
}
- handleEvent(.ws_event(event))
+ eventHandler(.ws_event(event))
}
}
diff --git a/damus/Nostr/RelayPool.swift b/damus/Nostr/RelayPool.swift
@@ -7,22 +7,6 @@
import Foundation
-struct SubscriptionId: Identifiable, CustomStringConvertible {
- let id: String
-
- var description: String {
- id
- }
-}
-
-struct RelayId: Identifiable, CustomStringConvertible {
- let id: String
-
- var description: String {
- id
- }
-}
-
struct RelayHandler {
let sub_id: String
let callback: (String, NostrConnectionEvent) -> ()
@@ -33,58 +17,58 @@ struct QueuedRequest {
let relay: String
}
-struct NostrRequestId: Equatable, Hashable {
- let relay: String?
- let sub_id: String
-}
-
-class RelayPool {
- var relays: [Relay] = []
- var handlers: [RelayHandler] = []
- var request_queue: [QueuedRequest] = []
- var seen: Set<String> = Set()
- var counts: [String: UInt64] = [:]
+final class RelayPool {
+ enum Constants {
+ /// Used for an exponential backoff algorithm when retrying stale connections
+ /// Each retry attempt will be delayed by raising this base delay to an exponent
+ /// equal to the number of previous retries.
+ static let base_reconnect_delay: TimeInterval = 2
+ static let max_queued_requests = 10
+ static let max_retry_attempts = 3
+ }
+
+ private(set) var relays: [Relay] = []
+ private(set) var handlers: [RelayHandler] = []
+ private var request_queue: [QueuedRequest] = []
+ private(set) var seen: Set<String> = Set()
+ private(set) var counts: [String: UInt64] = [:]
+ private var retry_attempts_per_relay: [URL: Int] = [:]
var descriptors: [RelayDescriptor] {
relays.map { $0.descriptor }
}
var num_connecting: Int {
- return relays.reduce(0) { n, r in n + (r.connection.isConnecting ? 1 : 0) }
+ relays.reduce(0) { n, r in n + (r.connection.state == .connecting ? 1 : 0) }
}
func remove_handler(sub_id: String) {
- self.handlers = handlers.filter { $0.sub_id != sub_id }
+ guard let index = handlers.firstIndex(where: { $0.sub_id == sub_id }) else {
+ return
+ }
+ handlers.remove(at: index)
print("removing \(sub_id) handler, current: \(handlers.count)")
}
func register_handler(sub_id: String, handler: @escaping (String, NostrConnectionEvent) -> ()) {
- for handler in handlers {
- // don't add duplicate handlers
- if handler.sub_id == sub_id {
- return
- }
+ guard !handlers.contains(where: { $0.sub_id == sub_id }) else {
+ return // don't add duplicate handlers
}
- self.handlers.append(RelayHandler(sub_id: sub_id, callback: handler))
+
+ handlers.append(RelayHandler(sub_id: sub_id, callback: handler))
print("registering \(sub_id) handler, current: \(self.handlers.count)")
}
func remove_relay(_ relay_id: String) {
- var i: Int = 0
+ disconnect(from: [relay_id])
- self.disconnect(to: [relay_id])
-
- for relay in relays {
- if relay.id == relay_id {
- relays.remove(at: i)
- break
- }
-
- i += 1
+ if let index = relays.firstIndex(where: { $0.id == relay_id }) {
+ relays.remove(at: index)
}
}
- func add_relay(_ url: URL, info: RelayInfo) throws {
+ @discardableResult
+ func add_relay(_ url: URL, info: RelayInfo) throws -> Relay {
let relay_id = get_relay_id(url)
if get_relay(relay_id) != nil {
throw RelayError.RelayAlreadyExists
@@ -94,40 +78,57 @@ class RelayPool {
}
let descriptor = RelayDescriptor(url: url, info: info)
let relay = Relay(descriptor: descriptor, connection: conn)
- self.relays.append(relay)
+ relays.append(relay)
+ return relay
}
/// This is used to retry dead connections
func connect_to_disconnected() {
- for relay in relays {
+ for relay in relays where !relay.is_broken && relay.connection.state != .connected {
let c = relay.connection
- let is_connecting = c.isReconnecting || c.isConnecting
+ let is_connecting = c.state == .reconnecting || c.state == .connecting
+
+ let retry_attempts = retry_attempts_per_relay[c.url] ?? 0
- if is_connecting && (Date.now.timeIntervalSince1970 - c.last_connection_attempt) > 5 {
- print("stale connection detected (\(relay.descriptor.url.absoluteString)). retrying...")
- relay.connection.connect(force: true)
- } else if relay.is_broken || is_connecting || c.isConnected {
+ let delay = pow(Constants.base_reconnect_delay, TimeInterval(retry_attempts + 1)) // the + 1 helps us avoid a 1-second delay for the first retry
+ if is_connecting && (Date.now.timeIntervalSince1970 - c.last_connection_attempt) > delay {
+ if retry_attempts > Constants.max_retry_attempts {
+ if c.state != .notConnected {
+ c.disconnect()
+ print("exceeded max connection attempts with \(relay.descriptor.url.absoluteString)")
+ relay.mark_broken()
+ }
+ continue
+ } else {
+ print("stale connection detected (\(relay.descriptor.url.absoluteString)). retrying after \(delay) seconds...")
+ c.connect(force: true)
+ retry_attempts_per_relay[c.url] = retry_attempts + 1
+ }
+ } else if is_connecting {
continue
} else {
- relay.connection.reconnect()
+ c.reconnect()
}
-
}
}
- func reconnect(to: [String]? = nil) {
- let relays = to.map{ get_relays($0) } ?? self.relays
- for relay in relays {
+ func reconnect(to relay_ids: [String]? = nil) {
+ let relays: [Relay]
+ if let relay_ids {
+ relays = get_relays(relay_ids)
+ } else {
+ relays = self.relays
+ }
+
+ for relay in relays where !relay.is_broken {
// don't try to reconnect to broken relays
relay.connection.reconnect()
}
}
func mark_broken(_ relay_id: String) {
- for relay in relays {
- relay.mark_broken()
- }
+ relays.first(where: { $0.id == relay_id })?.mark_broken()
}
func connect(to: [String]? = nil) {
@@ -137,8 +138,8 @@ class RelayPool {
}
}
- func disconnect(to: [String]? = nil) {
- let relays = to.map{ get_relays($0) } ?? self.relays
+ private func disconnect(from: [String]? = nil) {
+ let relays = from.map{ get_relays($0) } ?? self.relays
for relay in relays {
relay.connection.disconnect()
}
@@ -146,35 +147,23 @@ class RelayPool {
func unsubscribe(sub_id: String, to: [String]? = nil) {
if to == nil {
- self.remove_handler(sub_id: sub_id)
+ remove_handler(sub_id: sub_id)
}
- self.send(.unsubscribe(sub_id), to: to)
- }
-
- func subscribe(sub_id: String, filters: [NostrFilter], handler: @escaping (String, NostrConnectionEvent) -> (), to: [String]? = nil) {
- register_handler(sub_id: sub_id, handler: handler)
- send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
+ send(.unsubscribe(sub_id), to: to)
}
- func subscribe_to(sub_id: String, filters: [NostrFilter], to: [String]?, handler: @escaping (String, NostrConnectionEvent) -> ()) {
+ func subscribe_to(sub_id: String, filters: [NostrFilter], to: [String]? = nil, handler: @escaping (String, NostrConnectionEvent) -> ()) {
register_handler(sub_id: sub_id, handler: handler)
send(.subscribe(.init(filters: filters, sub_id: sub_id)), to: to)
}
func count_queued(relay: String) -> Int {
- var c = 0
- for request in request_queue {
- if request.relay == relay {
- c += 1
- }
- }
-
- return c
+ request_queue.filter({ $0.relay == relay }).count
}
func queue_req(r: NostrRequest, relay: String) {
let count = count_queued(relay: relay)
- guard count <= 10 else {
+ guard count < Constants.max_queued_requests else {
print("can't queue, too many queued events for \(relay)")
return
}
@@ -184,10 +173,10 @@ class RelayPool {
}
func send(_ req: NostrRequest, to: [String]? = nil) {
- let relays = to.map{ get_relays($0) } ?? self.relays
-
+ let relays = to.map { get_relays($0) } ?? self.relays
+
for relay in relays {
- guard relay.connection.isConnected else {
+ guard relay.connection.state == .connected else {
queue_req(r: req, relay: relay.id)
continue
}
@@ -207,17 +196,14 @@ class RelayPool {
func record_last_pong(relay_id: String, event: NostrConnectionEvent) {
if case .ws_event(let ws_event) = event {
if case .pong = ws_event {
- for relay in relays {
- if relay.id == relay_id {
- relay.last_pong = UInt32(Date.now.timeIntervalSince1970)
- return
- }
+ if let relay = relays.first(where: { $0.id == relay_id }) {
+ relay.last_pong = UInt32(Date.now.timeIntervalSince1970)
}
}
}
}
- func run_queue(_ relay_id: String) {
+ private func run_queue(_ relay_id: String) {
self.request_queue = request_queue.reduce(into: Array<QueuedRequest>()) { (q, req) in
guard req.relay == relay_id else {
q.append(req)
@@ -235,17 +221,14 @@ class RelayPool {
let k = relay_id + nev.id
if !seen.contains(k) {
seen.insert(k)
- if counts[relay_id] == nil {
- counts[relay_id] = 1
- } else {
- counts[relay_id] = (counts[relay_id] ?? 0) + 1
- }
+ let prev_count = counts[relay_id] ?? 0
+ counts[relay_id] = prev_count + 1
}
}
}
}
- func handle_event(relay_id: String, event: NostrConnectionEvent) {
+ private func handle_event(relay_id: String, event: NostrConnectionEvent) {
record_last_pong(relay_id: relay_id, event: event)
record_seen(relay_id: relay_id, event: event)
@@ -265,7 +248,5 @@ class RelayPool {
func add_rw_relay(_ pool: RelayPool, _ url: String) {
let url_ = URL(string: url)!
- try? pool.add_relay(url_, info: RelayInfo.rw)
+ let _ = try? pool.add_relay(url_, info: RelayInfo.rw)
}
-
-
diff --git a/damus/Views/Relays/RelayStatus.swift b/damus/Views/Relays/RelayStatus.swift
@@ -7,6 +7,16 @@
import SwiftUI
+extension RelayConnection.State {
+ var indicatorColor: Color {
+ switch self {
+ case .connected: return .green
+ case .connecting, .reconnecting: return .yellow
+ default: return .red
+ }
+ }
+}
+
struct RelayStatus: View {
let pool: RelayPool
let relay: String
@@ -16,18 +26,10 @@ struct RelayStatus: View {
@State var conn_color: Color = .gray
func update_connection_color() {
- for relay in pool.relays {
- if relay.id == self.relay {
- let c = relay.connection
- if c.isConnected {
- conn_color = .green
- } else if c.isConnecting || c.isReconnecting {
- conn_color = .yellow
- } else {
- conn_color = .red
- }
- }
+ guard let relay = pool.relays.first(where: { $0.id == relay }) else {
+ return
}
+ conn_color = relay.connection.state.indicatorColor
}
var body: some View {
diff --git a/damusTests/RelayPoolTests.swift b/damusTests/RelayPoolTests.swift
@@ -0,0 +1,170 @@
+//
+// RelayPoolTests.swift
+// damusTests
+//
+// Created by Bryan Montz on 2/25/23.
+//
+
+import XCTest
+@testable import damus
+
+final class RelayPoolTests: XCTestCase {
+
+ private let fakeRelayURL = URL(string: "wss://some.relay.com")!
+
+ private func setUpPool() throws -> RelayPool {
+ let pool = RelayPool()
+ XCTAssertTrue(pool.relays.isEmpty)
+
+ try pool.add_relay(fakeRelayURL, info: RelayInfo.rw)
+ return pool
+ }
+
+ // MARK: - Relay Add/Remove
+
+ func testAddRelay() throws {
+ let pool = try setUpPool()
+
+ XCTAssertEqual(pool.relays.count, 1)
+ }
+
+ func testRejectDuplicateRelay() throws {
+ let pool = try setUpPool()
+
+ XCTAssertThrowsError(try pool.add_relay(fakeRelayURL, info: RelayInfo.rw)) { error in
+ XCTAssertEqual(error as? RelayError, RelayError.RelayAlreadyExists)
+ }
+ }
+
+ func testRemoveRelay() throws {
+ let pool = try setUpPool()
+
+ XCTAssertEqual(pool.relays.count, 1)
+
+ pool.remove_relay(fakeRelayURL.absoluteString)
+
+ XCTAssertTrue(pool.relays.isEmpty)
+ }
+
+ func testMarkRelayBroken() throws {
+ let pool = try setUpPool()
+
+ let relay = try XCTUnwrap(pool.relays.first(where: { $0.id == fakeRelayURL.absoluteString }))
+ XCTAssertFalse(relay.is_broken)
+
+ pool.mark_broken(fakeRelayURL.absoluteString)
+ XCTAssertTrue(relay.is_broken)
+ }
+
+ func testGetRelay() throws {
+ let pool = try setUpPool()
+ XCTAssertNotNil(pool.get_relay(fakeRelayURL.absoluteString))
+ }
+
+ func testGetRelays() throws {
+ let pool = try setUpPool()
+
+ try pool.add_relay(URL(string: "wss://second.relay.com")!, info: RelayInfo.rw)
+
+ let allRelays = pool.get_relays([fakeRelayURL.absoluteString, "wss://second.relay.com"])
+ XCTAssertEqual(allRelays.count, 2)
+
+ let relays = pool.get_relays(["wss://second.relay.com"])
+ XCTAssertEqual(relays.count, 1)
+ }
+
+ // MARK: - Handler Add/Remove
+
+ private func setUpPoolWithHandler(sub_id: String) -> RelayPool {
+ let pool = RelayPool()
+ XCTAssertTrue(pool.handlers.isEmpty)
+
+ pool.register_handler(sub_id: sub_id) { _, _ in }
+ return pool
+ }
+
+ func testAddHandler() {
+ let sub_id = "123"
+ let pool = setUpPoolWithHandler(sub_id: sub_id)
+
+ XCTAssertEqual(pool.handlers.count, 1)
+ }
+
+ func testRejectDuplicateHandler() {
+ let sub_id = "123"
+ let pool = setUpPoolWithHandler(sub_id: sub_id)
+ XCTAssertEqual(pool.handlers.count, 1)
+
+ pool.register_handler(sub_id: sub_id) { _, _ in }
+
+ XCTAssertEqual(pool.handlers.count, 1)
+ }
+
+ func testRemoveHandler() {
+ let sub_id = "123"
+ let pool = setUpPoolWithHandler(sub_id: sub_id)
+ XCTAssertEqual(pool.handlers.count, 1)
+ pool.remove_handler(sub_id: sub_id)
+ XCTAssertTrue(pool.handlers.isEmpty)
+ }
+
+ func testRecordLastPong() throws {
+ let pool = try setUpPool()
+ let relayId = fakeRelayURL.absoluteString
+ let relay = try XCTUnwrap(pool.get_relay(relayId))
+ XCTAssertEqual(relay.last_pong, 0)
+
+ let pongEvent = NostrConnectionEvent.ws_event(.pong(nil))
+ pool.record_last_pong(relay_id: relayId, event: pongEvent)
+ XCTAssertNotEqual(relay.last_pong, 0)
+ }
+
+ func testSeenAndCounts() throws {
+ let pool = try setUpPool()
+
+ XCTAssertTrue(pool.seen.isEmpty)
+ XCTAssertTrue(pool.counts.isEmpty)
+
+ let event = NostrEvent(id: "123", content: "", pubkey: "")
+ let connectionEvent = NostrConnectionEvent.nostr_event(NostrResponse.event("", event))
+ let relay_id = fakeRelayURL.absoluteString
+ pool.record_seen(relay_id: relay_id, event: connectionEvent)
+
+ XCTAssertTrue(pool.seen.contains("wss://some.relay.com123"))
+
+ XCTAssertEqual(pool.counts[relay_id], 1)
+
+ pool.record_seen(relay_id: relay_id, event: connectionEvent)
+ // don't count the same event twice
+ XCTAssertEqual(pool.counts[relay_id], 1)
+ }
+
+ func testAddQueuedRequest() throws {
+ let pool = try setUpPool()
+
+ XCTAssertEqual(pool.count_queued(relay: fakeRelayURL.absoluteString), 0)
+
+ let req = NostrRequest.unsubscribe("")
+ pool.queue_req(r: req, relay: fakeRelayURL.absoluteString)
+
+ XCTAssertEqual(pool.count_queued(relay: fakeRelayURL.absoluteString), 1)
+ }
+
+ func testRejectTooManyQueuedRequests() throws {
+ let pool = try setUpPool()
+
+ let maxRequests = RelayPool.Constants.max_queued_requests
+ for _ in 0..<maxRequests {
+ let req = NostrRequest.unsubscribe("")
+ pool.queue_req(r: req, relay: fakeRelayURL.absoluteString)
+ }
+
+ XCTAssertEqual(pool.count_queued(relay: fakeRelayURL.absoluteString), maxRequests)
+
+ // try to add one beyond the maximum
+ let req = NostrRequest.unsubscribe("")
+ pool.queue_req(r: req, relay: fakeRelayURL.absoluteString)
+
+ XCTAssertEqual(pool.count_queued(relay: fakeRelayURL.absoluteString), maxRequests)
+ }
+}