1
0
mirror of https://github.com/janet-lang/janet synced 2024-11-28 19:19:53 +00:00

New unmarshal proposal.

Gives more control over unmarshalling
abstract types. This should also
make it possible/easy to write abstract types that cannot
cause unmarshal to segfault.
This commit is contained in:
Calvin Rose 2019-12-06 22:12:18 -06:00
parent 4a0ee5df7d
commit 546669082f
6 changed files with 69 additions and 30 deletions

View File

@ -40,11 +40,14 @@ static Janet it_s64_get(void *p, Janet key);
static Janet it_u64_get(void *p, Janet key); static Janet it_u64_get(void *p, Janet key);
static void int64_marshal(void *p, JanetMarshalContext *ctx) { static void int64_marshal(void *p, JanetMarshalContext *ctx) {
janet_marshal_abstract(ctx, p);
janet_marshal_int64(ctx, *((int64_t *)p)); janet_marshal_int64(ctx, *((int64_t *)p));
} }
static void int64_unmarshal(void *p, JanetMarshalContext *ctx) { static void *int64_unmarshal(JanetMarshalContext *ctx) {
*((int64_t *)p) = janet_unmarshal_int64(ctx); int64_t *p = janet_unmarshal_abstract(ctx, sizeof(int64_t));
p[0] = janet_unmarshal_int64(ctx);
return p;
} }
static void it_s64_tostring(void *p, JanetBuffer *buffer) { static void it_s64_tostring(void *p, JanetBuffer *buffer) {

View File

@ -338,6 +338,13 @@ void janet_marshal_janet(JanetMarshalContext *ctx, Janet x) {
marshal_one(st, x, ctx->flags + 1); marshal_one(st, x, ctx->flags + 1);
} }
void janet_marshal_abstract(JanetMarshalContext *ctx, void *abstract) {
MarshalState *st = (MarshalState *)(ctx->m_state);
janet_table_put(&st->seen,
janet_wrap_abstract(abstract),
janet_wrap_integer(st->nextid++));
}
#define MARK_SEEN() \ #define MARK_SEEN() \
janet_table_put(&st->seen, x, janet_wrap_integer(st->nextid++)) janet_table_put(&st->seen, x, janet_wrap_integer(st->nextid++))
@ -345,11 +352,9 @@ static void marshal_one_abstract(MarshalState *st, Janet x, int flags) {
void *abstract = janet_unwrap_abstract(x); void *abstract = janet_unwrap_abstract(x);
const JanetAbstractType *at = janet_abstract_type(abstract); const JanetAbstractType *at = janet_abstract_type(abstract);
if (at->marshal) { if (at->marshal) {
JanetMarshalContext context = {st, NULL, flags, NULL};
pushbyte(st, LB_ABSTRACT); pushbyte(st, LB_ABSTRACT);
marshal_one(st, janet_csymbolv(at->name), flags + 1); marshal_one(st, janet_csymbolv(at->name), flags + 1);
push64(st, (uint64_t) janet_abstract_size(abstract)); JanetMarshalContext context = {st, NULL, flags, NULL, at};
MARK_SEEN();
at->marshal(abstract, &context); at->marshal(abstract, &context);
} else { } else {
janet_panicf("try to marshal unregistered abstract type, cannot marshal %p", x); janet_panicf("try to marshal unregistered abstract type, cannot marshal %p", x);
@ -983,6 +988,11 @@ static const uint8_t *unmarshal_one_fiber(
return data; return data;
} }
void janet_unmarshal_ensure(JanetMarshalContext *ctx, size_t size) {
UnmarshalState *st = (UnmarshalState *)(ctx->u_state);
MARSH_EOS(st, ctx->data + size);
}
int32_t janet_unmarshal_int(JanetMarshalContext *ctx) { int32_t janet_unmarshal_int(JanetMarshalContext *ctx) {
UnmarshalState *st = (UnmarshalState *)(ctx->u_state); UnmarshalState *st = (UnmarshalState *)(ctx->u_state);
return readint(st, &(ctx->data)); return readint(st, &(ctx->data));
@ -1017,17 +1027,28 @@ Janet janet_unmarshal_janet(JanetMarshalContext *ctx) {
return ret; return ret;
} }
void *janet_unmarshal_abstract(JanetMarshalContext *ctx, size_t size) {
UnmarshalState *st = (UnmarshalState *)(ctx->u_state);
if (ctx->at == NULL) {
janet_panicf("janet_unmarshal_abstract called more than once");
}
void *p = janet_abstract(ctx->at, size);
janet_v_push(st->lookup, janet_wrap_abstract(p));
ctx->at = NULL;
return p;
}
static const uint8_t *unmarshal_one_abstract(UnmarshalState *st, const uint8_t *data, Janet *out, int flags) { static const uint8_t *unmarshal_one_abstract(UnmarshalState *st, const uint8_t *data, Janet *out, int flags) {
Janet key; Janet key;
data = unmarshal_one(st, data, &key, flags + 1); data = unmarshal_one(st, data, &key, flags + 1);
const JanetAbstractType *at = janet_get_abstract_type(key); const JanetAbstractType *at = janet_get_abstract_type(key);
if (at == NULL) return NULL; if (at == NULL) return NULL;
if (at->unmarshal) { if (at->unmarshal) {
void *p = janet_abstract(at, (size_t) read64(st, &data)); JanetMarshalContext context = {NULL, st, flags, data, at};
*out = janet_wrap_abstract(p); *out = janet_wrap_abstract(at->unmarshal(&context));
JanetMarshalContext context = {NULL, st, flags, data}; if (context.at != NULL) {
janet_v_push(st->lookup, *out); janet_panicf("janet_unmarshal_abstract not called");
at->unmarshal(p, &context); }
return context.data; return context.data;
} }
return NULL; return NULL;

View File

@ -33,6 +33,7 @@ static Janet janet_rng_get(void *p, Janet key);
static void janet_rng_marshal(void *p, JanetMarshalContext *ctx) { static void janet_rng_marshal(void *p, JanetMarshalContext *ctx) {
JanetRNG *rng = (JanetRNG *)p; JanetRNG *rng = (JanetRNG *)p;
janet_marshal_abstract(ctx, p);
janet_marshal_int(ctx, (int32_t) rng->a); janet_marshal_int(ctx, (int32_t) rng->a);
janet_marshal_int(ctx, (int32_t) rng->b); janet_marshal_int(ctx, (int32_t) rng->b);
janet_marshal_int(ctx, (int32_t) rng->c); janet_marshal_int(ctx, (int32_t) rng->c);
@ -40,13 +41,14 @@ static void janet_rng_marshal(void *p, JanetMarshalContext *ctx) {
janet_marshal_int(ctx, (int32_t) rng->counter); janet_marshal_int(ctx, (int32_t) rng->counter);
} }
static void janet_rng_unmarshal(void *p, JanetMarshalContext *ctx) { static void *janet_rng_unmarshal(JanetMarshalContext *ctx) {
JanetRNG *rng = (JanetRNG *)p; JanetRNG *rng = janet_unmarshal_abstract(ctx, sizeof(JanetRNG));
rng->a = (uint32_t) janet_unmarshal_int(ctx); rng->a = (uint32_t) janet_unmarshal_int(ctx);
rng->b = (uint32_t) janet_unmarshal_int(ctx); rng->b = (uint32_t) janet_unmarshal_int(ctx);
rng->c = (uint32_t) janet_unmarshal_int(ctx); rng->c = (uint32_t) janet_unmarshal_int(ctx);
rng->d = (uint32_t) janet_unmarshal_int(ctx); rng->d = (uint32_t) janet_unmarshal_int(ctx);
rng->counter = (uint32_t) janet_unmarshal_int(ctx); rng->counter = (uint32_t) janet_unmarshal_int(ctx);
return rng;
} }
static JanetAbstractType JanetRNG_type = { static JanetAbstractType JanetRNG_type = {

View File

@ -1017,6 +1017,7 @@ static void peg_marshal(void *p, JanetMarshalContext *ctx) {
Peg *peg = (Peg *)p; Peg *peg = (Peg *)p;
janet_marshal_size(ctx, peg->bytecode_len); janet_marshal_size(ctx, peg->bytecode_len);
janet_marshal_int(ctx, (int32_t)peg->num_constants); janet_marshal_int(ctx, (int32_t)peg->num_constants);
janet_marshal_abstract(ctx, p);
for (size_t i = 0; i < peg->bytecode_len; i++) for (size_t i = 0; i < peg->bytecode_len; i++)
janet_marshal_int(ctx, (int32_t) peg->bytecode[i]); janet_marshal_int(ctx, (int32_t) peg->bytecode[i]);
for (uint32_t j = 0; j < peg->num_constants; j++) for (uint32_t j = 0; j < peg->num_constants; j++)
@ -1030,25 +1031,28 @@ static size_t size_padded(size_t offset, size_t size) {
return x - (x % size); return x - (x % size);
} }
static void peg_unmarshal(void *p, JanetMarshalContext *ctx) { static void *peg_unmarshal(JanetMarshalContext *ctx) {
char *mem = p; size_t bytecode_len = janet_unmarshal_size(ctx);
Peg *peg = (Peg *)p; uint32_t num_constants = (uint32_t) janet_unmarshal_int(ctx);
peg->bytecode_len = janet_unmarshal_size(ctx);
peg->num_constants = (uint32_t) janet_unmarshal_int(ctx);
/* Calculate offsets. Should match those in make_peg */ /* Calculate offsets. Should match those in make_peg */
size_t bytecode_start = size_padded(sizeof(Peg), sizeof(uint32_t)); size_t bytecode_start = size_padded(sizeof(Peg), sizeof(uint32_t));
size_t bytecode_size = peg->bytecode_len * sizeof(uint32_t); size_t bytecode_size = bytecode_len * sizeof(uint32_t);
size_t constants_start = size_padded(bytecode_start + bytecode_size, sizeof(Janet)); size_t constants_start = size_padded(bytecode_start + bytecode_size, sizeof(Janet));
size_t total_size = constants_start + sizeof(Janet) * num_constants;
/* DOS prevention? I.E. we could read bytecode and constants before
* hand so we don't allocated a ton of memory on bad, short input */
/* Allocate PEG */
char *mem = janet_unmarshal_abstract(ctx, total_size);
Peg *peg = (Peg *)mem;
uint32_t *bytecode = (uint32_t *)(mem + bytecode_start); uint32_t *bytecode = (uint32_t *)(mem + bytecode_start);
Janet *constants = (Janet *)(mem + constants_start); Janet *constants = (Janet *)(mem + constants_start);
peg->bytecode = NULL; peg->bytecode = NULL;
peg->constants = NULL; peg->constants = NULL;
peg->bytecode_len = bytecode_len;
/* Ensure not too large */ peg->num_constants = num_constants;
if (constants_start + sizeof(Janet) * peg->num_constants > janet_abstract_size(p)) {
janet_panic("size mismatch");
}
for (size_t i = 0; i < peg->bytecode_len; i++) for (size_t i = 0; i < peg->bytecode_len; i++)
bytecode[i] = (uint32_t) janet_unmarshal_int(ctx); bytecode[i] = (uint32_t) janet_unmarshal_int(ctx);
@ -1176,7 +1180,7 @@ static void peg_unmarshal(void *p, JanetMarshalContext *ctx) {
peg->bytecode = bytecode; peg->bytecode = bytecode;
peg->constants = constants; peg->constants = constants;
free(op_flags); free(op_flags);
return; return peg;
bad: bad:
free(op_flags); free(op_flags);

View File

@ -94,17 +94,20 @@ static int ta_buffer_gc(void *p, size_t s) {
static void ta_buffer_marshal(void *p, JanetMarshalContext *ctx) { static void ta_buffer_marshal(void *p, JanetMarshalContext *ctx) {
JanetTArrayBuffer *buf = (JanetTArrayBuffer *)p; JanetTArrayBuffer *buf = (JanetTArrayBuffer *)p;
janet_marshal_abstract(ctx, p);
janet_marshal_size(ctx, buf->size); janet_marshal_size(ctx, buf->size);
janet_marshal_int(ctx, buf->flags); janet_marshal_int(ctx, buf->flags);
janet_marshal_bytes(ctx, buf->data, buf->size); janet_marshal_bytes(ctx, buf->data, buf->size);
} }
static void ta_buffer_unmarshal(void *p, JanetMarshalContext *ctx) { static void *ta_buffer_unmarshal(JanetMarshalContext *ctx) {
JanetTArrayBuffer *buf = (JanetTArrayBuffer *)p; JanetTArrayBuffer *buf = janet_unmarshal_abstract(ctx, sizeof(JanetTArrayBuffer));
size_t size = janet_unmarshal_size(ctx); size_t size = janet_unmarshal_size(ctx);
int32_t flags = janet_unmarshal_int(ctx);
ta_buffer_init(buf, size); ta_buffer_init(buf, size);
buf->flags = janet_unmarshal_int(ctx); buf->flags = flags;
janet_unmarshal_bytes(ctx, buf->data, size); janet_unmarshal_bytes(ctx, buf->data, size);
return buf;
} }
static const JanetAbstractType ta_buffer_type = { static const JanetAbstractType ta_buffer_type = {
@ -128,6 +131,7 @@ static int ta_mark(void *p, size_t s) {
static void ta_view_marshal(void *p, JanetMarshalContext *ctx) { static void ta_view_marshal(void *p, JanetMarshalContext *ctx) {
JanetTArrayView *view = (JanetTArrayView *)p; JanetTArrayView *view = (JanetTArrayView *)p;
size_t offset = (view->buffer->data - view->as.u8); size_t offset = (view->buffer->data - view->as.u8);
janet_marshal_abstract(ctx, p);
janet_marshal_size(ctx, view->size); janet_marshal_size(ctx, view->size);
janet_marshal_size(ctx, view->stride); janet_marshal_size(ctx, view->stride);
janet_marshal_int(ctx, view->type); janet_marshal_int(ctx, view->type);
@ -135,11 +139,11 @@ static void ta_view_marshal(void *p, JanetMarshalContext *ctx) {
janet_marshal_janet(ctx, janet_wrap_abstract(view->buffer)); janet_marshal_janet(ctx, janet_wrap_abstract(view->buffer));
} }
static void ta_view_unmarshal(void *p, JanetMarshalContext *ctx) { static void *ta_view_unmarshal(JanetMarshalContext *ctx) {
JanetTArrayView *view = (JanetTArrayView *)p;
size_t offset; size_t offset;
int32_t atype; int32_t atype;
Janet buffer; Janet buffer;
JanetTArrayView *view = janet_unmarshal_abstract(ctx, sizeof(JanetTArrayView));
view->size = janet_unmarshal_size(ctx); view->size = janet_unmarshal_size(ctx);
view->stride = janet_unmarshal_size(ctx); view->stride = janet_unmarshal_size(ctx);
atype = janet_unmarshal_int(ctx); atype = janet_unmarshal_int(ctx);
@ -157,6 +161,7 @@ static void ta_view_unmarshal(void *p, JanetMarshalContext *ctx) {
if (view->buffer->size < buf_need_size) if (view->buffer->size < buf_need_size)
janet_panic("bad typed array offset in marshalled data"); janet_panic("bad typed array offset in marshalled data");
view->as.u8 = view->buffer->data + offset; view->as.u8 = view->buffer->data + offset;
return view;
} }
static JanetMethod tarray_view_methods[6]; static JanetMethod tarray_view_methods[6];

View File

@ -886,6 +886,7 @@ typedef struct {
void *u_state; void *u_state;
int flags; int flags;
const uint8_t *data; const uint8_t *data;
const JanetAbstractType *at;
} JanetMarshalContext; } JanetMarshalContext;
/* Defines an abstract type */ /* Defines an abstract type */
@ -896,7 +897,7 @@ struct JanetAbstractType {
Janet(*get)(void *data, Janet key); Janet(*get)(void *data, Janet key);
void (*put)(void *data, Janet key, Janet value); void (*put)(void *data, Janet key, Janet value);
void (*marshal)(void *p, JanetMarshalContext *ctx); void (*marshal)(void *p, JanetMarshalContext *ctx);
void (*unmarshal)(void *p, JanetMarshalContext *ctx); void *(*unmarshal)(JanetMarshalContext *ctx);
void (*tostring)(void *p, JanetBuffer *buffer); void (*tostring)(void *p, JanetBuffer *buffer);
}; };
@ -1422,13 +1423,16 @@ JANET_API void janet_marshal_int64(JanetMarshalContext *ctx, int64_t value);
JANET_API void janet_marshal_byte(JanetMarshalContext *ctx, uint8_t value); JANET_API void janet_marshal_byte(JanetMarshalContext *ctx, uint8_t value);
JANET_API void janet_marshal_bytes(JanetMarshalContext *ctx, const uint8_t *bytes, size_t len); JANET_API void janet_marshal_bytes(JanetMarshalContext *ctx, const uint8_t *bytes, size_t len);
JANET_API void janet_marshal_janet(JanetMarshalContext *ctx, Janet x); JANET_API void janet_marshal_janet(JanetMarshalContext *ctx, Janet x);
JANET_API void janet_marshal_abstract(JanetMarshalContext *ctx, void *abstract);
JANET_API void janet_unmarshal_ensure(JanetMarshalContext *ctx, size_t size);
JANET_API size_t janet_unmarshal_size(JanetMarshalContext *ctx); JANET_API size_t janet_unmarshal_size(JanetMarshalContext *ctx);
JANET_API int32_t janet_unmarshal_int(JanetMarshalContext *ctx); JANET_API int32_t janet_unmarshal_int(JanetMarshalContext *ctx);
JANET_API int64_t janet_unmarshal_int64(JanetMarshalContext *ctx); JANET_API int64_t janet_unmarshal_int64(JanetMarshalContext *ctx);
JANET_API uint8_t janet_unmarshal_byte(JanetMarshalContext *ctx); JANET_API uint8_t janet_unmarshal_byte(JanetMarshalContext *ctx);
JANET_API void janet_unmarshal_bytes(JanetMarshalContext *ctx, uint8_t *dest, size_t len); JANET_API void janet_unmarshal_bytes(JanetMarshalContext *ctx, uint8_t *dest, size_t len);
JANET_API Janet janet_unmarshal_janet(JanetMarshalContext *ctx); JANET_API Janet janet_unmarshal_janet(JanetMarshalContext *ctx);
JANET_API void *janet_unmarshal_abstract(JanetMarshalContext *ctx, size_t size);
JANET_API void janet_register_abstract_type(const JanetAbstractType *at); JANET_API void janet_register_abstract_type(const JanetAbstractType *at);
JANET_API const JanetAbstractType *janet_get_abstract_type(Janet key); JANET_API const JanetAbstractType *janet_get_abstract_type(Janet key);