From 06c755c98a54a45bb2e85ec53c62f9b4ad9a4435 Mon Sep 17 00:00:00 2001 From: Calvin Rose Date: Fri, 3 Aug 2018 13:41:44 -0400 Subject: [PATCH] Be stricter with function arity. --- src/core/core.dst | 49 ++++++++++++++++++++++++--------------------- src/core/fiber.c | 29 +++++++++++++++++++++++++-- src/core/fiber.h | 4 ++-- src/core/specials.c | 9 ++++++++- src/core/vm.c | 22 +++++++++++++++++--- test/suite0.dst | 2 +- test/suite1.dst | 6 +++--- 7 files changed, 86 insertions(+), 35 deletions(-) diff --git a/src/core/core.dst b/src/core/core.dst index 66e8a556..c6034aaf 100644 --- a/src/core/core.dst +++ b/src/core/core.dst @@ -268,7 +268,7 @@ [head & body] (def len (length head)) (defn doone - [i preds] + @[i preds] (default preds @['and]) (if (>= i len) (tuple.prepend body 'do) @@ -338,7 +338,7 @@ subloop (tuple ':= $i (tuple + 1 $i))))) (error (string "unexpected loop verb: " verb))))))) - (doone 0)) + (doone 0 nil)) (defmacro for "Similar to loop, but accumulates the loop body into an array and returns that." @@ -364,13 +364,13 @@ (defmacro coro "A wrapper for making fibers. Same as (fiber (fn [] ...body))." [& body] - (tuple fiber.new (apply tuple 'fn [] body))) + (tuple fiber.new (apply tuple 'fn @[] body))) (defmacro if-let "Takes the first one or two forms in a vector and if both are true binds all the forms with let and evaluates the first expression else evaluates the second" - [bindings tru fal] + @[bindings tru fal] (def len (length bindings)) (if (zero? len) (error "expected at least 1 binding")) (if (odd? len) (error "expected an even number of bindings")) @@ -477,12 +477,12 @@ (sort-help a (+ piv 1) hi by)) a) - (fn [a by] + (fn @[a by] (sort-help a 0 (- (length a) 1) (or by order<))))) (defn sorted "Returns the sorted version of an indexed data structure." - [ind by] + @[ind by] (def sa (sort (apply1 array ind) by)) (if (= :tuple (type ind)) (apply1 tuple sa) @@ -491,7 +491,7 @@ (defn reduce "Reduce, also know as fold-left in many languages, transforms an indexed type (array, tuple) with a function to produce a value." - [f init ind] + @[f init ind] (var res init) (loop [x :in ind] (:= res (f res x))) @@ -545,7 +545,7 @@ "Map a function over every element in an array or tuple and use array to concatenate the results. Returns the same type as the input sequence." - [f ind t] + @[f ind t] (def res @[]) (loop [x :in ind] (array.concat res (f x))) @@ -556,7 +556,7 @@ (defn filter "Given a predicate, take only elements from an array or tuple for which (pred element) is truthy. Returns the same type as the input sequence." - [pred ind t] + @[pred ind t] (def res @[]) (loop [item :in ind] (if (pred item) @@ -673,12 +673,12 @@ (if (zero? (length more)) f (fn [& r] (apply1 f (array.concat @[] more r))))) -(defn every? [pred seq] +(defn every? [pred ind] (var res true) (var i 0) - (def len (length seq)) + (def len (length ind)) (while (< i len) - (def item (get seq i)) + (def item (get ind i)) (if (pred item) (++ i) (do (:= res false) (:= i len)))) @@ -709,7 +709,7 @@ (defn zipcoll "Creates an table or tuple from two arrays/tuples. If a third argument of :struct is given result is struct else is table." - [keys vals t] + @[keys vals t] (def res @{}) (def lk (length keys)) (def lv (length vals)) @@ -987,7 +987,8 @@ ### ### -(defn make-env [parent] +(defn make-env + @[parent] (def parent (if parent parent _env)) (def newenv (table.setproto @{} parent)) (put newenv '_env @{:value newenv :private true}) @@ -1005,7 +1006,7 @@ This function can be used to implement a repl very easily, simply pass a function that reads line from stdin to chunks, and print to onvalue." - [env chunks onvalue onerr where] + @[env chunks onvalue onerr where] # Are we done yet? (var going true) @@ -1047,7 +1048,7 @@ (var good true) (def f (fiber.new - (fn [] + (fn @[] (def res (compile source env where)) (if (= (type res) :function) (res) @@ -1121,7 +1122,7 @@ environment is needed, use run-context." [str] (var state (string str)) - (defn chunks [buf] + (defn chunks [buf _] (def ret state) (:= state nil) (if ret @@ -1191,7 +1192,7 @@ (def cache @{}) (def loading @{}) - (fn require [path args] + (fn require @[path args] (when (get loading path) (error (string "circular dependency: module " path " is loading"))) (def {:exit exit-on-error} (or args {})) @@ -1206,10 +1207,10 @@ (if f (do # Normal dst module - (defn chunks [buf] (file.read f 1024 buf)) + (defn chunks [buf _] (file.read f 1024 buf)) (run-context newenv chunks identity (if exit-on-error - (fn [a b c d] (default-error-handler a b c d) (os.exit 1)) + (fn @[a b c d] (default-error-handler a b c d) (os.exit 1)) default-error-handler) path) (file.close f)) @@ -1239,11 +1240,12 @@ (put env (symbol prefix k) newv)) (:= k (next newenv k)))) -(defmacro import [path & args] +(defmacro import "Import a module. First requires the module, and then merges its symbols into the current environment, prepending a given prefix as needed. (use the :as or :prefix option to set a prefix). If no prefix is provided, use the name of the module as a prefix." + [path & args] (def argm (map (fn [x] (if (and (symbol? x) (= (get x 0) 58)) x @@ -1251,9 +1253,10 @@ args)) (apply tuple import* '_env (string path) argm)) -(defn repl [getchunk onvalue onerr] +(defn repl "Run a repl. The first parameter is an optional function to call to get a chunk of source code. Should return nil for end of file." + @[getchunk onvalue onerr] (def newenv (make-env)) (default getchunk (fn [buf] (file.read stdin :line buf))) @@ -1265,7 +1268,7 @@ (defn all-symbols "Get all symbols available in the current environment." - [env] + @[env] (default env *env*) (def envs @[]) (do (var e env) (while e (array.push envs e) (:= e (table.getproto e)))) diff --git a/src/core/fiber.c b/src/core/fiber.c index 2db04821..3b2dc223 100644 --- a/src/core/fiber.c +++ b/src/core/fiber.c @@ -107,7 +107,7 @@ void dst_fiber_pushn(DstFiber *fiber, const Dst *arr, int32_t n) { } /* Push a stack frame to a fiber */ -void dst_fiber_funcframe(DstFiber *fiber, DstFunction *func) { +int dst_fiber_funcframe(DstFiber *fiber, DstFunction *func) { DstStackFrame *newframe; int32_t i; @@ -116,6 +116,13 @@ void dst_fiber_funcframe(DstFiber *fiber, DstFunction *func) { int32_t nextframe = fiber->stackstart; int32_t nextstacktop = nextframe + func->def->slotcount + DST_FRAME_SIZE; + /* Check strict arity */ + if (func->def->flags & DST_FUNCDEF_FLAG_FIXARITY) { + if (func->def->arity != (fiber->stacktop - fiber->stackstart)) { + return 1; + } + } + if (fiber->capacity < nextstacktop) { dst_fiber_setcapacity(fiber, 2 * nextstacktop); } @@ -146,6 +153,9 @@ void dst_fiber_funcframe(DstFiber *fiber, DstFunction *func) { oldtop - tuplehead)); } } + + /* Good return */ + return 0; } /* If a frame has a closure environment, detach it from @@ -165,12 +175,19 @@ static void dst_env_detach(DstFuncEnv *env) { } /* Create a tail frame for a function */ -void dst_fiber_funcframe_tail(DstFiber *fiber, DstFunction *func) { +int dst_fiber_funcframe_tail(DstFiber *fiber, DstFunction *func) { int32_t i; int32_t nextframetop = fiber->frame + func->def->slotcount; int32_t nextstacktop = nextframetop + DST_FRAME_SIZE; int32_t stacksize; + /* Check strict arity */ + if (func->def->flags & DST_FUNCDEF_FLAG_FIXARITY) { + if (func->def->arity != (fiber->stacktop - fiber->stackstart)) { + return 1; + } + } + if (fiber->capacity < nextstacktop) { dst_fiber_setcapacity(fiber, 2 * nextstacktop); } @@ -213,6 +230,9 @@ void dst_fiber_funcframe_tail(DstFiber *fiber, DstFunction *func) { dst_fiber_frame(fiber)->func = func; dst_fiber_frame(fiber)->pc = func->def->bytecode; dst_fiber_frame(fiber)->flags |= DST_STACKFRAME_TAILCALL; + + /* Good return */ + return 0; } /* Push a stack frame to a fiber for a c function */ @@ -263,6 +283,11 @@ static int cfun_new(DstArgs args) { DST_MINARITY(args, 1); DST_MAXARITY(args, 2); DST_ARG_FUNCTION(func, args, 0); + if (func->def->flags & DST_FUNCDEF_FLAG_FIXARITY) { + if (func->def->arity != 1) { + DST_THROW(args, "expected unit arity function in fiber constructor"); + } + } fiber = dst_fiber(func, 64); if (args.n == 2) { const uint8_t *flags; diff --git a/src/core/fiber.h b/src/core/fiber.h index dcc17105..be0e8c02 100644 --- a/src/core/fiber.h +++ b/src/core/fiber.h @@ -40,8 +40,8 @@ void dst_fiber_push(DstFiber *fiber, Dst x); void dst_fiber_push2(DstFiber *fiber, Dst x, Dst y); void dst_fiber_push3(DstFiber *fiber, Dst x, Dst y, Dst z); void dst_fiber_pushn(DstFiber *fiber, const Dst *arr, int32_t n); -void dst_fiber_funcframe(DstFiber *fiber, DstFunction *func); -void dst_fiber_funcframe_tail(DstFiber *fiber, DstFunction *func); +int dst_fiber_funcframe(DstFiber *fiber, DstFunction *func); +int dst_fiber_funcframe_tail(DstFiber *fiber, DstFunction *func); void dst_fiber_cframe(DstFiber *fiber, DstCFunction cfun); void dst_fiber_popframe(DstFiber *fiber); diff --git a/src/core/specials.c b/src/core/specials.c index 86cdb190..39d50a2c 100644 --- a/src/core/specials.c +++ b/src/core/specials.c @@ -549,7 +549,14 @@ static DstSlot dstc_fn(DstFopts opts, int32_t argn, const Dst *argv) { /* Build function */ def = dstc_pop_funcdef(c); def->arity = arity; - if (varargs) def->flags |= DST_FUNCDEF_FLAG_VARARG; + + /* Tuples indicated fixed arity, arrays indicate flexible arity */ + /* TODO - revisit this */ + if (varargs) + def->flags |= DST_FUNCDEF_FLAG_VARARG; + else if (dst_checktype(paramv, DST_TUPLE)) + def->flags |= DST_FUNCDEF_FLAG_FIXARITY; + if (selfref) def->name = dst_unwrap_symbol(head); defindex = dstc_addfuncdef(c, def); diff --git a/src/core/vm.c b/src/core/vm.c index f152e4b6..4999a4d3 100644 --- a/src/core/vm.c +++ b/src/core/vm.c @@ -768,7 +768,8 @@ static void *op_lookup[255] = { if (dst_checktype(callee, DST_FUNCTION)) { func = dst_unwrap_function(callee); dst_stack_frame(stack)->pc = pc; - dst_fiber_funcframe(fiber, func); + if (dst_fiber_funcframe(fiber, func)) + goto vm_arity_error; stack = fiber->data + fiber->frame; pc = func->def->bytecode; vm_checkgc_next(); @@ -794,7 +795,8 @@ static void *op_lookup[255] = { Dst callee = stack[oparg(1, 0xFFFFFF)]; if (dst_checktype(callee, DST_FUNCTION)) { func = dst_unwrap_function(callee); - dst_fiber_funcframe_tail(fiber, func); + if (dst_fiber_funcframe_tail(fiber, func)) + goto vm_arity_error; stack = fiber->data + fiber->frame; pc = func->def->bytecode; vm_checkgc_next(); @@ -1190,6 +1192,17 @@ static void *op_lookup[255] = { goto vm_reset; } + /* Handle function calls with bad arity */ + vm_arity_error: + { + retreg = dst_wrap_string(dst_formatc("calling %V got %d arguments, expected %d", + dst_wrap_function(func), + fiber->stacktop - fiber->stackstart, + func->def->arity)); + signal = DST_SIGNAL_ERROR; + goto vm_exit; + } + /* Resume a child fiber */ vm_resume_child: { @@ -1293,7 +1306,10 @@ DstSignal dst_call( *f = fiber; for (i = 0; i < argn; i++) dst_fiber_push(fiber, argv[i]); - dst_fiber_funcframe(fiber, fiber->root); + if (dst_fiber_funcframe(fiber, fiber->root)) { + *out = dst_cstringv("arity mismatch"); + return DST_SIGNAL_ERROR; + } /* Prevent push an extra value on the stack */ dst_fiber_set_status(fiber, DST_STATUS_PENDING); return dst_continue(fiber, dst_wrap_nil(), out); diff --git a/test/suite0.dst b/test/suite0.dst index f14abf53..d5ae22a5 100644 --- a/test/suite0.dst +++ b/test/suite0.dst @@ -155,7 +155,7 @@ # yield tests -(def t (fiber.new (fn [] (yield 1) (yield 2) 3))) +(def t (fiber.new (fn @[] (yield 1) (yield 2) 3))) (assert (= 1 (resume t)) "initial transfer to new fiber") (assert (= 2 (resume t)) "second transfer to fiber") diff --git a/test/suite1.dst b/test/suite1.dst index 429ac3c3..b6d32668 100644 --- a/test/suite1.dst +++ b/test/suite1.dst @@ -43,7 +43,7 @@ (defn assert-many [f n e] (var good true) (loop [i :range [0 n]] - (if (not (f i)) + (if (not (f)) (:= good false))) (assert good e)) @@ -76,9 +76,9 @@ # More fiber semantics (var myvar 0) -(defn fiberstuff [] +(defn fiberstuff @[] (++ myvar) - (def f (fiber.new (fn [] (++ myvar) (debug) (++ myvar)))) + (def f (fiber.new (fn @[] (++ myvar) (debug) (++ myvar)))) (resume f) (++ myvar))