From a87015598ceb3fe739b7572d32f936f956f83e0e Mon Sep 17 00:00:00 2001 From: Calvin Rose Date: Fri, 24 Apr 2020 16:18:31 -0500 Subject: [PATCH] Make janet_equals and janet_compare non recursive This makes these operatios use constant stack space rather than linear stackspace given the size of the inputs. This is important to prevent certain parser input from causing a stack overflow - in general, we try to avoid unbounded recursion. --- src/boot/system_test.c | 15 +++ src/core/state.h | 11 ++ src/core/struct.c | 48 --------- src/core/tuple.c | 39 ------- src/core/value.c | 202 +++++++++++++++++++++++++++-------- src/core/vm.c | 5 + src/include/janet.h | 7 +- test/fuzzers/fuzz_dostring.c | 5 +- 8 files changed, 196 insertions(+), 136 deletions(-) diff --git a/src/boot/system_test.c b/src/boot/system_test.c index 3a1de65d..f3e9af78 100644 --- a/src/boot/system_test.c +++ b/src/boot/system_test.c @@ -50,5 +50,20 @@ int system_test() { assert(janet_equals(janet_cstringv("a string."), janet_cstringv("a string."))); assert(janet_equals(janet_csymbolv("sym"), janet_csymbolv("sym"))); + Janet *t1 = janet_tuple_begin(3); + t1[0] = janet_wrap_nil(); + t1[1] = janet_wrap_integer(4); + t1[2] = janet_cstringv("hi"); + Janet tuple1 = janet_wrap_tuple(janet_tuple_end(t1)); + + Janet *t2 = janet_tuple_begin(3); + t2[0] = janet_wrap_nil(); + t2[1] = janet_wrap_integer(4); + t2[2] = janet_cstringv("hi"); + Janet tuple2 = janet_wrap_tuple(janet_tuple_end(t2)); + + assert(janet_equals(tuple1, tuple2)); + + return 0; } diff --git a/src/core/state.h b/src/core/state.h index 0ed8b31d..649b3785 100644 --- a/src/core/state.h +++ b/src/core/state.h @@ -79,6 +79,17 @@ extern JANET_THREAD_LOCAL JanetScratch **janet_scratch_mem; extern JANET_THREAD_LOCAL size_t janet_scratch_cap; extern JANET_THREAD_LOCAL size_t janet_scratch_len; +/* Recursionless traversal of data structures */ +typedef struct { + JanetGCObject *self; + JanetGCObject *other; + int32_t index; + int32_t index2; +} JanetTraversalNode; +extern JANET_THREAD_LOCAL JanetTraversalNode *janet_vm_traversal; +extern JANET_THREAD_LOCAL JanetTraversalNode *janet_vm_traversal_top; +extern JANET_THREAD_LOCAL JanetTraversalNode *janet_vm_traversal_base; + /* Setup / teardown */ #ifdef JANET_THREADS void janet_threads_init(void); diff --git a/src/core/struct.c b/src/core/struct.c index 18356925..ef34b99e 100644 --- a/src/core/struct.c +++ b/src/core/struct.c @@ -167,51 +167,3 @@ JanetTable *janet_struct_to_table(const JanetKV *st) { } return table; } - -/* Check if two structs are equal */ -int janet_struct_equal(const JanetKV *lhs, const JanetKV *rhs) { - int32_t index; - int32_t llen = janet_struct_capacity(lhs); - int32_t rlen = janet_struct_capacity(rhs); - int32_t lhash = janet_struct_hash(lhs); - int32_t rhash = janet_struct_hash(rhs); - if (llen != rlen) - return 0; - if (lhash != rhash) - return 0; - for (index = 0; index < llen; index++) { - const JanetKV *l = lhs + index; - const JanetKV *r = rhs + index; - if (!janet_equals(l->key, r->key)) - return 0; - if (!janet_equals(l->value, r->value)) - return 0; - } - return 1; -} - -/* Compare structs */ -int janet_struct_compare(const JanetKV *lhs, const JanetKV *rhs) { - int32_t i; - int32_t lhash = janet_struct_hash(lhs); - int32_t rhash = janet_struct_hash(rhs); - int32_t llen = janet_struct_capacity(lhs); - int32_t rlen = janet_struct_capacity(rhs); - if (llen < rlen) - return -1; - if (llen > rlen) - return 1; - if (lhash < rhash) - return -1; - if (lhash > rhash) - return 1; - for (i = 0; i < llen; ++i) { - const JanetKV *l = lhs + i; - const JanetKV *r = rhs + i; - int comp = janet_compare(l->key, r->key); - if (comp != 0) return comp; - comp = janet_compare(l->value, r->value); - if (comp != 0) return comp; - } - return 0; -} diff --git a/src/core/tuple.c b/src/core/tuple.c index 9a3415a2..cd1b1aa4 100644 --- a/src/core/tuple.c +++ b/src/core/tuple.c @@ -53,45 +53,6 @@ const Janet *janet_tuple_n(const Janet *values, int32_t n) { return janet_tuple_end(t); } -/* Check if two tuples are equal */ -int janet_tuple_equal(const Janet *lhs, const Janet *rhs) { - int32_t index; - int32_t llen = janet_tuple_length(lhs); - int32_t rlen = janet_tuple_length(rhs); - int32_t lhash = janet_tuple_hash(lhs); - int32_t rhash = janet_tuple_hash(rhs); - if (lhash == 0) - lhash = janet_tuple_hash(lhs) = janet_array_calchash(lhs, llen); - if (rhash == 0) - rhash = janet_tuple_hash(rhs) = janet_array_calchash(rhs, rlen); - if (lhash != rhash) - return 0; - if (llen != rlen) - return 0; - for (index = 0; index < llen; index++) { - if (!janet_equals(lhs[index], rhs[index])) - return 0; - } - return 1; -} - -/* Compare tuples */ -int janet_tuple_compare(const Janet *lhs, const Janet *rhs) { - int32_t i; - int32_t llen = janet_tuple_length(lhs); - int32_t rlen = janet_tuple_length(rhs); - int32_t count = llen < rlen ? llen : rlen; - for (i = 0; i < count; ++i) { - int comp = janet_compare(lhs[i], rhs[i]); - if (comp != 0) return comp; - } - if (llen < rlen) - return -1; - else if (llen > rlen) - return 1; - return 0; -} - /* C Functions */ static Janet cfun_tuple_brackets(int32_t argc, Janet *argv) { diff --git a/src/core/value.c b/src/core/value.c index fff30ac7..8488f786 100644 --- a/src/core/value.c +++ b/src/core/value.c @@ -23,9 +23,81 @@ #ifndef JANET_AMALG #include "features.h" #include "util.h" +#include "state.h" +#include "gc.h" #include #endif +JANET_THREAD_LOCAL JanetTraversalNode *janet_vm_traversal = NULL; +JANET_THREAD_LOCAL JanetTraversalNode *janet_vm_traversal_top = NULL; +JANET_THREAD_LOCAL JanetTraversalNode *janet_vm_traversal_base = NULL; + +static void push_traversal_node(void *lhs, void *rhs) { + JanetTraversalNode node; + node.self = (JanetGCObject *) lhs; + node.other = (JanetGCObject *) rhs; + node.index = 0; + node.index2 = 0; + if (janet_vm_traversal + 1 >= janet_vm_traversal_top) { + size_t oldsize = janet_vm_traversal - janet_vm_traversal_base; + size_t newsize = 2 * oldsize + 1; + if (newsize < 128) { + newsize = 128; + } + JanetTraversalNode *tn = realloc(janet_vm_traversal_base, newsize * sizeof(JanetTraversalNode)); + if (tn == NULL) { + JANET_OUT_OF_MEMORY; + } + janet_vm_traversal_base = tn; + janet_vm_traversal_top = janet_vm_traversal_base + newsize; + janet_vm_traversal = janet_vm_traversal_base + oldsize; + } + *(++janet_vm_traversal) = node; +} + +/* Used for travsersing structs and tuples without recursion */ +static int traversal_next(Janet *x, Janet *y) { + JanetTraversalNode *t = janet_vm_traversal; + while (t && t > janet_vm_traversal_base) { + JanetGCObject *self = t->self; + JanetTupleHead *tself = (JanetTupleHead *)self; + JanetStructHead *sself = (JanetStructHead *)self; + JanetGCObject *other = t->other; + JanetTupleHead *tother = (JanetTupleHead *)other; + JanetStructHead *sother = (JanetStructHead *)other; + if ((self->flags & JANET_MEM_TYPEBITS) == JANET_MEMORY_TUPLE) { + /* Node is a tuple at index t->index */ + if (t->index < tself->length) { + int32_t index = t->index++; + *x = tself->data[index]; + *y = tother->data[index]; + janet_vm_traversal = t; + return 1; + } + } else { + /* Node is a struct at index t->index: if t->index2 is true, we should return the values. */ + if (t->index2) { + t->index2 = 0; + int32_t index = t->index++; + *x = sself->data[index].value; + *y = sother->data[index].value; + janet_vm_traversal = t; + return 1; + } + for (int32_t i = t->index; i < sself->capacity; i++) { + t->index2 = 1; + *x = sself->data[t->index].key; + *y = sother->data[t->index].key; + janet_vm_traversal = t; + return 1; + } + } + t--; + } + janet_vm_traversal = t; + return 0; +} + /* * Define a number of functions that can be used internally on ANY Janet. */ @@ -111,41 +183,51 @@ static int janet_compare_abstract(JanetAbstract xx, JanetAbstract yy) { return xt->compare(xx, yy); } -/* Check if two values are equal. This is strict equality with no conversion. */ int janet_equals(Janet x, Janet y) { - int result = 0; - if (janet_type(x) != janet_type(y)) { - result = 0; - } else { + janet_vm_traversal = janet_vm_traversal_base; + do { + if (janet_type(x) != janet_type(y)) return 0; switch (janet_type(x)) { case JANET_NIL: - result = 1; break; case JANET_BOOLEAN: - result = (janet_unwrap_boolean(x) == janet_unwrap_boolean(y)); + if (janet_unwrap_boolean(x) != janet_unwrap_boolean(y)) return 0; break; case JANET_NUMBER: - result = (janet_unwrap_number(x) == janet_unwrap_number(y)); + if (janet_unwrap_number(x) != janet_unwrap_number(y)) return 0; break; case JANET_STRING: - result = janet_string_equal(janet_unwrap_string(x), janet_unwrap_string(y)); - break; - case JANET_TUPLE: - result = janet_tuple_equal(janet_unwrap_tuple(x), janet_unwrap_tuple(y)); - break; - case JANET_STRUCT: - result = janet_struct_equal(janet_unwrap_struct(x), janet_unwrap_struct(y)); + if (!janet_string_equal(janet_unwrap_string(x), janet_unwrap_string(y))) return 0; break; case JANET_ABSTRACT: - result = !janet_compare_abstract(janet_unwrap_abstract(x), janet_unwrap_abstract(y)); + if (janet_compare_abstract(janet_unwrap_abstract(x), janet_unwrap_abstract(y))) return 0; break; default: - /* compare pointers */ - result = (janet_unwrap_pointer(x) == janet_unwrap_pointer(y)); + if (janet_unwrap_pointer(x) != janet_unwrap_pointer(y)) return 0; break; + case JANET_TUPLE: { + const Janet *t1 = janet_unwrap_tuple(x); + const Janet *t2 = janet_unwrap_tuple(y); + if (t1 == t2) break; + if (janet_tuple_hash(t1) != janet_tuple_hash(t2)) return 0; + if (janet_tuple_length(t1) != janet_tuple_length(t2)) return 0; + push_traversal_node(janet_tuple_head(t1), janet_tuple_head(t2)); + break; + } + break; + case JANET_STRUCT: { + const JanetKV *s1 = janet_unwrap_struct(x); + const JanetKV *s2 = janet_unwrap_struct(y); + if (s1 == s2) break; + if (janet_struct_hash(s1) != janet_struct_hash(s2)) return 0; + if (janet_struct_length(s1) != janet_struct_length(s2)) return 0; + push_traversal_node(janet_struct_head(s1), janet_struct_head(s2)); + break; + } + break; } - } - return result; + } while (traversal_next(&x, &y)); + return 1; } /* Computes a hash value for a function */ @@ -201,38 +283,74 @@ int32_t janet_hash(Janet x) { * If y is less, returns 1. All types are comparable * and should have strict ordering, excepts NaNs. */ int janet_compare(Janet x, Janet y) { - if (janet_type(x) == janet_type(y)) { - switch (janet_type(x)) { + janet_vm_traversal = janet_vm_traversal_base; + do { + JanetType tx = janet_type(x); + JanetType ty = janet_type(y); + if (tx != ty) return tx < ty ? -1 : 1; + switch (tx) { case JANET_NIL: - return 0; - case JANET_BOOLEAN: - return janet_unwrap_boolean(x) - janet_unwrap_boolean(y); + break; + case JANET_BOOLEAN: { + int diff = janet_unwrap_boolean(x) - janet_unwrap_boolean(y); + if (diff) return diff; + break; + } case JANET_NUMBER: { double xx = janet_unwrap_number(x); double yy = janet_unwrap_number(y); - return xx == yy - ? 0 - : (xx < yy) ? -1 : 1; + if (xx == yy) { + break; + } else { + return (xx < yy) ? -1 : 1; + } } case JANET_STRING: case JANET_SYMBOL: - case JANET_KEYWORD: - return janet_string_compare(janet_unwrap_string(x), janet_unwrap_string(y)); - case JANET_TUPLE: - return janet_tuple_compare(janet_unwrap_tuple(x), janet_unwrap_tuple(y)); - case JANET_STRUCT: - return janet_struct_compare(janet_unwrap_struct(x), janet_unwrap_struct(y)); - case JANET_ABSTRACT: - return janet_compare_abstract(janet_unwrap_abstract(x), janet_unwrap_abstract(y)); - default: - if (janet_unwrap_string(x) == janet_unwrap_string(y)) { - return 0; + case JANET_KEYWORD: { + int diff = janet_string_compare(janet_unwrap_string(x), janet_unwrap_string(y)); + if (diff) return diff; + break; + } + case JANET_ABSTRACT: { + int diff = janet_compare_abstract(janet_unwrap_abstract(x), janet_unwrap_abstract(y)); + if (diff) return diff; + break; + } + default: { + if (janet_unwrap_pointer(x) == janet_unwrap_pointer(y)) { + break; } else { - return janet_unwrap_string(x) > janet_unwrap_string(y) ? 1 : -1; + return janet_unwrap_pointer(x) > janet_unwrap_pointer(y) ? 1 : -1; } + } + case JANET_TUPLE: { + const Janet *lhs = janet_unwrap_tuple(x); + const Janet *rhs = janet_unwrap_tuple(y); + int32_t llen = janet_tuple_length(lhs); + int32_t rlen = janet_tuple_length(rhs); + if (llen < rlen) return -1; + if (llen > rlen) return 1; + push_traversal_node(janet_tuple_head(lhs), janet_tuple_head(rhs)); + break; + } + case JANET_STRUCT: { + const JanetKV *lhs = janet_unwrap_struct(x); + const JanetKV *rhs = janet_unwrap_struct(y); + int32_t llen = janet_struct_capacity(lhs); + int32_t rlen = janet_struct_capacity(rhs); + int32_t lhash = janet_struct_hash(lhs); + int32_t rhash = janet_struct_hash(rhs); + if (llen < rlen) return -1; + if (llen > rlen) return 1; + if (lhash < rhash) return -1; + if (lhash > rhash) return 1; + push_traversal_node(janet_struct_head(lhs), janet_struct_head(rhs)); + break; + } } - } - return (janet_type(x) < janet_type(y)) ? -1 : 1; + } while (traversal_next(&x, &y)); + return 0; } static int32_t getter_checkint(Janet key, int32_t max) { diff --git a/src/core/vm.c b/src/core/vm.c index 4355d650..b4e2c557 100644 --- a/src/core/vm.c +++ b/src/core/vm.c @@ -1406,6 +1406,10 @@ int janet_init(void) { janet_vm_abstract_registry = janet_table(0); janet_gcroot(janet_wrap_table(janet_vm_registry)); janet_gcroot(janet_wrap_table(janet_vm_abstract_registry)); + /* Traversal */ + janet_vm_traversal = NULL; + janet_vm_traversal_base = NULL; + janet_vm_traversal_top = NULL; /* Core env */ janet_vm_core_env = NULL; /* Seed RNG */ @@ -1428,6 +1432,7 @@ void janet_deinit(void) { janet_vm_registry = NULL; janet_vm_abstract_registry = NULL; janet_vm_core_env = NULL; + free(janet_vm_traversal_base); #ifdef JANET_THREADS janet_threads_deinit(); #endif diff --git a/src/include/janet.h b/src/include/janet.h index cbf8e58f..0a3f7f43 100644 --- a/src/include/janet.h +++ b/src/include/janet.h @@ -1209,6 +1209,7 @@ JANET_API void janet_buffer_push_u64(JanetBuffer *buffer, uint64_t x); #define JANET_TUPLE_FLAG_BRACKETCTOR 0x10000 #define janet_tuple_head(t) ((JanetTupleHead *)((char *)t - offsetof(JanetTupleHead, data))) +#define janet_tuple_from_head(gcobject) ((const Janet *)((char *)gcobject + offsetof(JanetTupleHead, data))) #define janet_tuple_length(t) (janet_tuple_head(t)->length) #define janet_tuple_hash(t) (janet_tuple_head(t)->hash) #define janet_tuple_sm_line(t) (janet_tuple_head(t)->sm_line) @@ -1217,8 +1218,6 @@ JANET_API void janet_buffer_push_u64(JanetBuffer *buffer, uint64_t x); JANET_API Janet *janet_tuple_begin(int32_t length); JANET_API JanetTuple janet_tuple_end(Janet *tuple); JANET_API JanetTuple janet_tuple_n(const Janet *values, int32_t n); -JANET_API int janet_tuple_equal(JanetTuple lhs, JanetTuple rhs); -JANET_API int janet_tuple_compare(JanetTuple lhs, JanetTuple rhs); /* String/Symbol functions */ #define janet_string_head(s) ((JanetStringHead *)((char *)s - offsetof(JanetStringHead, data))) @@ -1256,6 +1255,7 @@ JANET_API JanetSymbol janet_symbol_gen(void); /* Structs */ #define janet_struct_head(t) ((JanetStructHead *)((char *)t - offsetof(JanetStructHead, data))) +#define janet_struct_from_head(t) ((const JanetKV *)((char *)gcobject + offsetof(JanetStructHead, data))) #define janet_struct_length(t) (janet_struct_head(t)->length) #define janet_struct_capacity(t) (janet_struct_head(t)->capacity) #define janet_struct_hash(t) (janet_struct_head(t)->hash) @@ -1264,8 +1264,6 @@ JANET_API void janet_struct_put(JanetKV *st, Janet key, Janet value); JANET_API JanetStruct janet_struct_end(JanetKV *st); JANET_API Janet janet_struct_get(JanetStruct st, Janet key); JANET_API JanetTable *janet_struct_to_table(JanetStruct st); -JANET_API int janet_struct_equal(JanetStruct lhs, JanetStruct rhs); -JANET_API int janet_struct_compare(JanetStruct lhs, JanetStruct rhs); JANET_API const JanetKV *janet_struct_find(JanetStruct st, Janet key); /* Table functions */ @@ -1298,6 +1296,7 @@ JANET_API const JanetKV *janet_dictionary_next(const JanetKV *kvs, int32_t cap, /* Abstract */ #define janet_abstract_head(u) ((JanetAbstractHead *)((char *)u - offsetof(JanetAbstractHead, data))) +#define janet_abstract_from_head(gcobject) ((JanetAbstract)((char *)gcobject + offsetof(JanetAbstractHead, data))) #define janet_abstract_type(u) (janet_abstract_head(u)->type) #define janet_abstract_size(u) (janet_abstract_head(u)->size) JANET_API void *janet_abstract_begin(const JanetAbstractType *type, size_t size); diff --git a/test/fuzzers/fuzz_dostring.c b/test/fuzzers/fuzz_dostring.c index 17834432..9dedd91c 100644 --- a/test/fuzzers/fuzz_dostring.c +++ b/test/fuzzers/fuzz_dostring.c @@ -10,9 +10,8 @@ int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size) { /* fuzz the parser */ JanetParser parser; janet_parser_init(&parser); - for (int i=0, done = 0; i < size; i++) - { - switch (janet_parser_status(&parser)) { + for (int i = 0, done = 0; i < size; i++) { + switch (janet_parser_status(&parser)) { case JANET_PARSE_DEAD: case JANET_PARSE_ERROR: done = 1;