commit d3b9705c5200f40772fb2a883ff9daff4c20d68c
parent 603bc30e1b191b21433752320d67337a3bbcc545
Author: William Casarin <jb55@jb55.com>
Date: Sun, 20 Mar 2022 18:09:40 -0700
lnsocket/js: use read_all
We need this if we want to load large requests reliably
Signed-off-by: William Casarin <jb55@jb55.com>
Diffstat:
3 files changed, 85 insertions(+), 47 deletions(-)
diff --git a/lnsocket.c b/lnsocket.c
@@ -205,37 +205,19 @@ int lnsocket_recv(struct lnsocket *ln, u16 *msg_type, unsigned char **payload, u
return 1;
}
-int EXPORT lnsocket_decrypt(struct lnsocket *ln, unsigned char *packet, int len)
+int EXPORT lnsocket_decrypt(struct lnsocket *ln, unsigned char *packet, int size)
{
- struct cursor read, enc, dec;
- u8 hdr[18];
- u16 size;
-
- make_cursor(packet, packet + len, &read);
- if (!cursor_pull(&read, hdr, 18)) {
- return note_error(&ln->errs, "not enough bytes in header, have %d, need 18", len);
- }
-
- if (!cryptomsg_decrypt_header(&ln->crypto_state, hdr, &size)) {
- return note_error(&ln->errs,
- "Failed hdr decrypt with rn=%"PRIu64,
- ln->crypto_state.rn-1);
- }
+ struct cursor enc, dec;
+ make_cursor(packet, packet + size, &enc);
reset_cursor(&ln->msgbuf);
- if (!cursor_slice(&ln->msgbuf, &dec, size))
+ if (!cursor_slice(&ln->msgbuf, &dec, size - 16))
return note_error(&ln->errs, "out of memory: %d + %d = %d > %d",
ln->msgbuf.end - ln->msgbuf.p, size,
ln->msgbuf.end - ln->msgbuf.p + size,
MSGBUF_MEM
);
- if (size + 16 != read.end - read.p)
- return note_error(&ln->errs, "expected enc body size of %d, got %d",
- size + 16, read.end - read.p);
-
- make_cursor(read.p, read.end, &enc);
-
if (!cryptomsg_decrypt_body(&ln->crypto_state,
enc.start, enc.end - enc.start,
dec.start, dec.end - dec.start))
@@ -244,6 +226,18 @@ int EXPORT lnsocket_decrypt(struct lnsocket *ln, unsigned char *packet, int len)
return dec.end - dec.start;
}
+// this is used in js
+int EXPORT lnsocket_decrypt_header(struct lnsocket *ln, unsigned char *hdr)
+{
+ u16 size;
+ if (!cryptomsg_decrypt_header(&ln->crypto_state, hdr, &size))
+ return note_error(&ln->errs,
+ "Failed hdr decrypt with rn=%"PRIu64,
+ ln->crypto_state.rn-1);
+
+ return size;
+}
+
int lnsocket_read(struct lnsocket *ln, unsigned char **buf, unsigned short *len)
{
struct cursor enc, dec;
diff --git a/lnsocket.h b/lnsocket.h
@@ -90,5 +90,6 @@ void EXPORT lnsocket_destroy(struct lnsocket *);
void EXPORT lnsocket_print_errors(struct lnsocket *);
int EXPORT lnsocket_make_default_initmsg(unsigned char *msgbuf, int buflen);
int EXPORT lnsocket_encrypt(struct lnsocket *ln, const unsigned char *msg, unsigned short msglen);
+int EXPORT lnsocket_decrypt_header(struct lnsocket *ln, unsigned char *hdr);
#endif /* LNSOCKET_H */
diff --git a/lnsocket_lib.js b/lnsocket_lib.js
@@ -66,6 +66,7 @@ async function lnsocket_init() {
const lnsocket_destroy = module.cwrap("lnsocket_destroy", "number")
const lnsocket_encrypt = module.cwrap("lnsocket_encrypt", "number", ["int", "array", "int", "int"])
const lnsocket_decrypt = module.cwrap("lnsocket_decrypt", "number", ["int", "array", "int"])
+ const lnsocket_decrypt_header = module.cwrap("lnsocket_decrypt_header", "number", ["number", "array"])
const lnsocket_msgbuf = module.cwrap("lnsocket_msgbuf", "number", ["int"])
const lnsocket_act_one = module.cwrap("lnsocket_act_one", "number", ["number", "string"])
const lnsocket_act_two = module.cwrap("lnsocket_act_two", "number", ["number", "array"])
@@ -77,7 +78,9 @@ async function lnsocket_init() {
function concat_u8_arrays(arrays) {
// sum of individual array lengths
- let totalLength = arrays.reduce((acc, value) => acc + value.length, 0);
+ let totalLength = arrays.reduce((acc, value) =>
+ acc + (value.length || value.byteLength)
+ , 0);
if (!arrays.length) return null;
@@ -85,25 +88,17 @@ async function lnsocket_init() {
let length = 0;
for (let array of arrays) {
- result.set(array, length);
- length += array.length;
+ if (array instanceof ArrayBuffer)
+ result.set(new Uint8Array(array), length);
+ else
+ result.set(array, length);
+
+ length += (array.length || array.byteLength);
}
return result;
}
- function queue_recv(queue) {
- return new Promise((resolve, reject) => {
- const checker = setInterval(() => {
- const val = queue.shift()
- if (val) {
- clearInterval(checker)
- resolve(val)
- }
- }, 5);
- })
- }
-
function parse_msgtype(buf) {
return buf[0] << 8 | buf[1]
}
@@ -123,6 +118,22 @@ async function lnsocket_init() {
this.ln = lnsocket_create()
}
+ LNSocket.prototype.queue_recv = function() {
+ let self = this
+ return new Promise((resolve, reject) => {
+ const checker = setInterval(() => {
+ const val = self.queue.shift()
+ if (val) {
+ clearInterval(checker)
+ resolve(val)
+ } else if (!self.connected) {
+ clearInterval(checker)
+ reject()
+ }
+ }, 5);
+ })
+ }
+
LNSocket.prototype.print_errors = function _lnsocket_print_errors() {
lnsocket_print_errors(this.ln)
}
@@ -152,8 +163,8 @@ async function lnsocket_init() {
const act1 = this.act_one_data(node_id)
this.ws.send(act1)
- const act2 = await this.read_clear()
- if (act2.byteLength != ACT_TWO_SIZE) {
+ const act2 = await this.read_all(ACT_TWO_SIZE)
+ if (act2.length != ACT_TWO_SIZE) {
throw new Error(`expected act2 to be ${ACT_TWO_SIZE} long, got ${act2.length}`)
}
const act3 = this.act_two(act2)
@@ -165,16 +176,51 @@ async function lnsocket_init() {
await this.perform_init()
}
+ LNSocket.prototype.read_all = async function read_all(n) {
+ let count = 0
+ let chunks = []
+ if (!this.connected)
+ throw new Error("read_all: not connected")
+ while (true) {
+ const res = await this.queue_recv()
+ count += res.byteLength
+ if (count > n) {
+ //console.log("count %d > n %d, queue: %d", count, n, this.queue.length)
+ chunks.push(res.slice(0, n))
+ this.queue.unshift(res.slice(n))
+ break
+ } else if (count === n) {
+ //console.log("count %d === n %d, queue: %d", count, n, this.queue.length)
+ chunks.push(res)
+ break
+ } else {
+ //console.log("count %d < n %d, queue: %d", count, n, this.queue.length)
+ chunks.push(res)
+ }
+ }
+
+ return concat_u8_arrays(chunks)
+ }
+
+ LNSocket.prototype.read_header = async function read_header() {
+ const header = await this.read_all(18)
+ if (header.length != 18)
+ throw new Error("Failed to read header")
+ return lnsocket_decrypt_header(this.ln, header)
+ }
+
LNSocket.prototype.rpc = async function lnsocket_rpc(opts) {
const msg = this.make_commando_msg(opts)
this.write(msg)
- return JSON.parse(await this.read_all_rpc())
+ const res = await this.read_all_rpc()
+ return JSON.parse(res)
}
LNSocket.prototype.recv = async function lnsocket_recv() {
const msg = await this.read()
const msgtype = parse_msgtype(msg.slice(0,2))
- return [msgtype, msg.slice(2)]
+ const res = [msgtype, msg.slice(2)]
+ return res
}
LNSocket.prototype.read_all_rpc = async function read_all_rpc() {
@@ -249,13 +295,10 @@ async function lnsocket_init() {
this.ws.send(this.encrypt(dat))
}
- LNSocket.prototype.read_clear = async function _lnsocket_read() {
- return (await queue_recv(this.queue))
- }
-
LNSocket.prototype.read = async function _lnsocket_read() {
- const enc = await this.read_clear()
- return this.decrypt(new Uint8Array(enc))
+ const size = await this.read_header()
+ const enc = await this.read_all(size+16)
+ return this.decrypt(enc)
}
LNSocket.prototype.make_default_initmsg = function _lnsocket_make_default_initmsg() {