Add RNG functionality to the math/ module.

The new RNG wraps up state for random number generation, so
one can have many rngs and even marshal and unmarshal them.
Adds math/rng, math/rng-uniform, and math/rng-int.

Also introduce `in` and change semantics for
indexing out of range. This commit enforces stricter
invariants on keys when indexing via a function call
on the data structure, or the new `in` function.

The `get` function is now more lax about keys, and will
not throw an error when a bad key is used for a data structure, instead
returning the default value.
This commit is contained in:
Calvin Rose 2019-11-08 17:35:27 -06:00
parent 58e3e63a89
commit aee1687215
12 changed files with 384 additions and 176 deletions

View File

@ -2,6 +2,13 @@
All notable changes to this project will be documented in this file.
## Unreleased
- Add `math/rng`, `math/rng-int`, and `math/rng-uniform`.
- Add `in` function to index in a stricter manner. Opposingly, `get` will
now not throw errors on bad keys.
- Indexed types and byte sequences will now error when indexed out of range or
with bad keys.
- Add rng functions to Janet. This also replaces the RNG behind `math/random`
and `math/seedrandom` with a consistent, platform independent RNG.
- Add `with-vars` macro.
- Add the `quickbin` command to jpm.
- Create shell.c when making the amlagamated source. This can be compiled with

View File

@ -25,14 +25,14 @@
(array/push modifiers ith))
(if (< i len) (recur (+ i 1)))))))
(def start (fstart 0))
(def args (get more start))
(def args (in more start))
# Add function signature to docstring
(var index 0)
(def arglen (length args))
(def buf (buffer "(" name))
(while (< index arglen)
(buffer/push-string buf " ")
(buffer/format buf "%p" (get args index))
(buffer/format buf "%p" (in args index))
(set index (+ index 1)))
(array/push modifiers (string buf ")\n\n" docstr))
# Build return value
@ -116,7 +116,7 @@
:table true
:buffer true
:struct true})
(fn idempotent? [x] (not (get non-atomic-types (type x))))))
(fn idempotent? [x] (not (in non-atomic-types (type x))))))
# C style macros and functions for imperative sugar. No bitwise though.
(defn inc "Returns x + 1." [x] (+ x 1))
@ -163,9 +163,9 @@
(defn aux [i]
(def restlen (- (length pairs) i))
(if (= restlen 0) nil
(if (= restlen 1) (get pairs i)
(tuple 'if (get pairs i)
(get pairs (+ i 1))
(if (= restlen 1) (in pairs i)
(tuple 'if (in pairs i)
(in pairs (+ i 1))
(aux (+ i 2))))))
(aux 0))
@ -179,9 +179,9 @@
(defn aux [i]
(def restlen (- (length pairs) i))
(if (= restlen 0) nil
(if (= restlen 1) (get pairs i)
(tuple 'if (tuple = sym (get pairs i))
(get pairs (+ i 1))
(if (= restlen 1) (in pairs i)
(tuple 'if (tuple = sym (in pairs i))
(in pairs (+ i 1))
(aux (+ i 2))))))
(if atm
(aux 0)
@ -231,8 +231,8 @@
(while (> i 0)
(-- i)
(set ret (if (= ret true)
(get forms i)
(tuple 'if (get forms i) ret))))
(in forms i)
(tuple 'if (in forms i) ret))))
ret)
(defmacro or
@ -244,7 +244,7 @@
(var i len)
(while (> i 0)
(-- i)
(def fi (get forms i))
(def fi (in forms i))
(set ret (if (idempotent? fi)
(tuple 'if fi fi ret)
(do
@ -260,7 +260,7 @@
(def len (length syms))
(def accum @[])
(while (< i len)
(array/push accum (get syms i) [gensym])
(array/push accum (in syms i) [gensym])
(++ i))
~(let (,;accum) ,;body))
@ -299,7 +299,7 @@
,(unless (= ds in) ~(def ,ds ,in))
(def ,len (,length ,ds))
(while (,< ,i ,len)
(def ,binding (get ,ds ,i))
(def ,binding (in ,ds ,i))
,;body
(++ ,i)))))
@ -311,7 +311,7 @@
,(unless (= ds in) ~(def ,ds ,in))
(var ,k (,next ,ds nil))
(while ,k
(def ,binding ,(if pair? ~(tuple ,k (get ,ds ,k)) k))
(def ,binding ,(if pair? ~(tuple ,k (in ,ds ,k)) k))
,;body
(set ,k (,next ,ds ,k))))))
@ -327,48 +327,47 @@
(defn- loop1
[body head i]
# Terminate recursion
(when (<= (length head) i)
(break ~(do ,;body)))
(def {i binding
(+ i 1) verb
(+ i 2) object} head)
(+ i 1) verb} head)
(cond
# 2 term expression
(when (keyword? binding)
(break
(let [rest (loop1 body head (+ i 2))]
(case binding
:until ~(do (if ,verb (break) nil) ,rest)
:while ~(do (if ,verb nil (break)) ,rest)
:let ~(let ,verb (do ,rest))
:after ~(do ,rest ,verb nil)
:before ~(do ,verb ,rest nil)
:repeat (with-syms [iter]
~(do (var ,iter ,verb) (while (> ,iter 0) ,rest (-- ,iter))))
:when ~(when ,verb ,rest)
(error (string "unexpected loop modifier " binding))))))
# Terminate recursion
(<= (length head) i)
~(do ,;body)
# 2 term expression
(keyword? binding)
(let [rest (loop1 body head (+ i 2))]
(case binding
:until ~(do (if ,verb (break) nil) ,rest)
:while ~(do (if ,verb nil (break)) ,rest)
:let ~(let ,verb (do ,rest))
:after ~(do ,rest ,verb nil)
:before ~(do ,verb ,rest nil)
:repeat (with-syms [iter]
~(do (var ,iter ,verb) (while (> ,iter 0) ,rest (-- ,iter))))
:when ~(when ,verb ,rest)
(error (string "unexpected loop modifier " binding))))
# 3 term expression
(let [rest (loop1 body head (+ i 3))]
(case verb
:range (let [[start stop step] object]
(for-template binding start stop (or step 1) < + [rest]))
:keys (keys-template binding object false [rest])
:pairs (keys-template binding object true [rest])
:down (let [[start stop step] object]
(for-template binding start stop (or step 1) > - [rest]))
:in (each-template binding object [rest])
:iterate (iterate-template binding object rest)
:generate (with-syms [f s]
~(let [,f ,object]
(while true
(def ,binding (,resume ,f))
(if (= :dead (,fiber/status ,f)) (break))
,rest)))
(error (string "unexpected loop verb " verb))))))
# 3 term expression
(def {(+ i 2) object} head)
(let [rest (loop1 body head (+ i 3))]
(case verb
:range (let [[start stop step] object]
(for-template binding start stop (or step 1) < + [rest]))
:keys (keys-template binding object false [rest])
:pairs (keys-template binding object true [rest])
:down (let [[start stop step] object]
(for-template binding start stop (or step 1) > - [rest]))
:in (each-template binding object [rest])
:iterate (iterate-template binding object rest)
:generate (with-syms [f s]
~(let [,f ,object]
(while true
(def ,binding (,resume ,f))
(if (= :dead (,fiber/status ,f)) (break))
,rest)))
(error (string "unexpected loop verb " verb)))))
(defmacro for
"Do a c style for loop for side effects. Returns nil."
@ -466,11 +465,11 @@
(if (zero? len) (error "expected at least 1 binding"))
(if (odd? len) (error "expected an even number of bindings"))
(defn aux [i]
(def bl (get bindings i))
(def br (get bindings (+ 1 i)))
(if (>= i len)
tru
(do
(def bl (in bindings i))
(def br (in bindings (+ 1 i)))
(def atm (idempotent? bl))
(def sym (if atm bl (gensym)))
(if atm
@ -499,7 +498,7 @@
[& functions]
(case (length functions)
0 nil
1 (get functions 0)
1 (in functions 0)
2 (let [[f g] functions] (fn [& x] (f (g ;x))))
3 (let [[f g h] functions] (fn [& x] (f (g (h ;x)))))
4 (let [[f g h i] functions] (fn [& x] (f (g (h (i ;x))))))
@ -547,12 +546,12 @@
(defn first
"Get the first element from an indexed data structure."
[xs]
(get xs 0))
(in xs 0))
(defn last
"Get the last element from an indexed data structure."
[xs]
(get xs (- (length xs) 1)))
(in xs (- (length xs) 1)))
###
###
@ -566,16 +565,16 @@
(defn part
[a lo hi by]
(def pivot (get a hi))
(def pivot (in a hi))
(var i lo)
(for j lo hi
(def aj (get a j))
(def aj (in a j))
(when (by aj pivot)
(def ai (get a i))
(def ai (in a i))
(set (a i) aj)
(set (a j) ai)
(++ i)))
(set (a hi) (get a i))
(set (a hi) (in a i))
(set (a i) pivot)
i)
@ -609,20 +608,20 @@
[f & inds]
(def ninds (length inds))
(if (= 0 ninds) (error "expected at least 1 indexed collection"))
(var limit (length (get inds 0)))
(var limit (length (in inds 0)))
(for i 0 ninds
(def l (length (get inds i)))
(def l (length (in inds i)))
(if (< l limit) (set limit l)))
(def [i1 i2 i3 i4] inds)
(def res (array/new limit))
(case ninds
1 (for i 0 limit (set (res i) (f (get i1 i))))
2 (for i 0 limit (set (res i) (f (get i1 i) (get i2 i))))
3 (for i 0 limit (set (res i) (f (get i1 i) (get i2 i) (get i3 i))))
4 (for i 0 limit (set (res i) (f (get i1 i) (get i2 i) (get i3 i) (get i4 i))))
1 (for i 0 limit (set (res i) (f (in i1 i))))
2 (for i 0 limit (set (res i) (f (in i1 i) (in i2 i))))
3 (for i 0 limit (set (res i) (f (in i1 i) (in i2 i) (in i3 i))))
4 (for i 0 limit (set (res i) (f (in i1 i) (in i2 i) (in i3 i) (in i4 i))))
(for i 0 limit
(def args (array/new ninds))
(for j 0 ninds (set (args j) (get (get inds j) i)))
(for j 0 ninds (set (args j) (in (in inds j) i)))
(set (res i) (f ;args))))
res)
@ -695,7 +694,7 @@
(var i 0)
(var going true)
(while (if (< i len) going)
(def item (get ind i))
(def item (in ind i))
(if (pred item) (set going false) (++ i)))
(if going nil i))
@ -705,7 +704,7 @@
and a not found. Consider find-index if this is an issue."
[pred ind]
(def i (find-index pred ind))
(if (= i nil) nil (get ind i)))
(if (= i nil) nil (in ind i)))
(defn take
"Take first n elements in an indexed type. Returns new indexed instance."
@ -783,7 +782,7 @@
[x & forms]
(defn fop [last n]
(def [h t] (if (= :tuple (type n))
(tuple (get n 0) (array/slice n 1))
(tuple (in n 0) (array/slice n 1))
(tuple n @[])))
(def parts (array/concat @[h last] t))
(tuple/slice parts 0))
@ -796,7 +795,7 @@
[x & forms]
(defn fop [last n]
(def [h t] (if (= :tuple (type n))
(tuple (get n 0) (array/slice n 1))
(tuple (in n 0) (array/slice n 1))
(tuple n @[])))
(def parts (array/concat @[h] t @[last]))
(tuple/slice parts 0))
@ -811,7 +810,7 @@
[x & forms]
(defn fop [last n]
(def [h t] (if (= :tuple (type n))
(tuple (get n 0) (array/slice n 1))
(tuple (in n 0) (array/slice n 1))
(tuple n @[])))
(def sym (gensym))
(def parts (array/concat @[h sym] t))
@ -827,7 +826,7 @@
[x & forms]
(defn fop [last n]
(def [h t] (if (= :tuple (type n))
(tuple (get n 0) (array/slice n 1))
(tuple (in n 0) (array/slice n 1))
(tuple n @[])))
(def sym (gensym))
(def parts (array/concat @[h] t @[sym]))
@ -843,7 +842,7 @@
(defn walk-dict [f form]
(def ret @{})
(loop [k :keys form]
(put ret (f k) (f (get form k))))
(put ret (f k) (f (in form k))))
ret)
(defn walk
@ -950,7 +949,7 @@
(var n (- len 1))
(def reversed (array/new len))
(while (>= n 0)
(array/push reversed (get t n))
(array/push reversed (in t n))
(-- n))
reversed)
@ -961,7 +960,7 @@
[ds]
(def ret @{})
(loop [k :keys ds]
(put ret (get ds k) k))
(put ret (in ds k) k))
ret)
(defn zipcoll
@ -973,7 +972,7 @@
(def lv (length vals))
(def len (if (< lk lv) lk lv))
(for i 0 len
(put res (get keys i) (get vals i)))
(put res (in keys i) (in vals i)))
res)
(defn get-in
@ -1043,7 +1042,7 @@
[tab & colls]
(loop [c :in colls
key :keys c]
(set (tab key) (get c key)))
(set (tab key) (in c key)))
tab)
(defn merge
@ -1054,7 +1053,7 @@
(def container @{})
(loop [c :in colls
key :keys c]
(set (container key) (get c key)))
(set (container key) (in c key)))
container)
(defn keys
@ -1073,7 +1072,7 @@
(def arr (array/new (length x)))
(var k (next x nil))
(while (not= nil k)
(array/push arr (get x k))
(array/push arr (in x k))
(set k (next x k)))
arr)
@ -1083,7 +1082,7 @@
(def arr (array/new (length x)))
(var k (next x nil))
(while (not= nil k)
(array/push arr (tuple k (get x k)))
(array/push arr (tuple k (in x k)))
(set k (next x k)))
arr)
@ -1092,7 +1091,7 @@
[ind]
(def freqs @{})
(each x ind
(def n (get freqs x))
(def n (in freqs x))
(set (freqs x) (if n (+ 1 n) 1)))
freqs)
@ -1106,7 +1105,7 @@
(def len (min ;(map length cols)))
(loop [i :range [0 len]
ci :range [0 ncol]]
(array/push res (get (get cols ci) i))))
(array/push res (in (in cols ci) i))))
res)
(defn distinct
@ -1114,7 +1113,7 @@
[xs]
(def ret @[])
(def seen @{})
(each x xs (if (get seen x) nil (do (put seen x true) (array/push ret x))))
(each x xs (if (in seen x) nil (do (put seen x true) (array/push ret x))))
ret)
(defn flatten-into
@ -1138,7 +1137,7 @@
like @[k v k v ...]. Returns a new array."
[dict]
(def ret (array/new (* 2 (length dict))))
(loop [k :keys dict] (array/push ret k (get dict k)))
(loop [k :keys dict] (array/push ret k (in dict k)))
ret)
(defn interpose
@ -1147,10 +1146,10 @@
[sep ind]
(def len (length ind))
(def ret (array/new (- (* 2 len) 1)))
(if (> len 0) (put ret 0 (get ind 0)))
(if (> len 0) (put ret 0 (in ind 0)))
(var i 1)
(while (< i len)
(array/push ret sep (get ind i))
(array/push ret sep (in ind i))
(++ i))
ret)
@ -1233,7 +1232,7 @@
(cond
(symbol? pattern)
(if (get seen pattern)
(if (in seen pattern)
~(if (= ,pattern ,expr) ,(onmatch) ,sentinel)
(do
(put seen pattern true)
@ -1244,7 +1243,7 @@
# Unification with external values
~(if (= ,(pattern 1) ,expr) ,(onmatch) ,sentinel)
(match-1
(get pattern 0) expr
(in pattern 0) expr
(fn []
~(if (and ,;(tuple/slice pattern 1)) ,(onmatch) ,sentinel)) seen))
@ -1259,7 +1258,7 @@
(++ i)
(if (= i len)
(onmatch)
(match-1 (get pattern i) (tuple get $arr i) aux seen))))
(match-1 (in pattern i) (tuple in $arr i) aux seen))))
,sentinel)))
(dictionary? pattern)
@ -1272,7 +1271,7 @@
(set key (next pattern key))
(if (= key nil)
(onmatch)
(match-1 (get pattern key) (tuple get $dict key) aux seen))))
(match-1 (in pattern key) (tuple in $dict key) aux seen))))
,sentinel)))
:else ~(if (= ,pattern ,expr) ,(onmatch) ,sentinel)))
@ -1293,9 +1292,9 @@
(def len-1 (dec len))
((fn aux [i]
(cond
(= i len-1) (get cases i)
(= i len-1) (in cases i)
(< i len-1) (with-syms [$res]
~(if (= ,sentinel (def ,$res ,(match-1 (get cases i) $x (fn [] (get cases (inc i))) @{})))
~(if (= ,sentinel (def ,$res ,(match-1 (in cases i) $x (fn [] (in cases (inc i))) @{})))
,(aux (+ 2 i))
,$res)))) 0)))
@ -1357,7 +1356,7 @@
(def bind-type
(string " "
(cond
(x :ref) (string :var " (" (type (get (x :ref) 0)) ")")
(x :ref) (string :var " (" (type (in (x :ref) 0)) ")")
(x :macro) :macro
(type (x :value)))
"\n"))
@ -1397,7 +1396,7 @@
(def newt @{})
(var key (next t nil))
(while (not= nil key)
(put newt (recur key) (on-value (get t key)))
(put newt (recur key) (on-value (in t key)))
(set key (next t key)))
newt)
@ -1410,24 +1409,24 @@
(recur x)))
(defn expanddef [t]
(def last (get t (- (length t) 1)))
(def bound (get t 1))
(def last (in t (- (length t) 1)))
(def bound (in t 1))
(tuple/slice
(array/concat
@[(get t 0) (expand-bindings bound)]
@[(in t 0) (expand-bindings bound)]
(tuple/slice t 2 -2)
@[(recur last)])))
(defn expandall [t]
(def args (map recur (tuple/slice t 1)))
(tuple (get t 0) ;args))
(tuple (in t 0) ;args))
(defn expandfn [t]
(def t1 (get t 1))
(def t1 (in t 1))
(if (symbol? t1)
(do
(def args (map recur (tuple/slice t 3)))
(tuple 'fn t1 (get t 2) ;args))
(tuple 'fn t1 (in t 2) ;args))
(do
(def args (map recur (tuple/slice t 2)))
(tuple 'fn t1 ;args))))
@ -1436,15 +1435,15 @@
(defn qq [x]
(case (type x)
:tuple (do
(def x0 (get x 0))
(def x0 (in x 0))
(if (or (= 'unquote x0) (= 'unquote-splicing x0))
(tuple x0 (recur (get x 1)))
(tuple x0 (recur (in x 1)))
(tuple/slice (map qq x))))
:array (map qq x)
:table (table (map qq (kvs x)))
:struct (struct (map qq (kvs x)))
x))
(tuple (get t 0) (qq (get t 1))))
(tuple (in t 0) (qq (in t 1))))
(def specs
{'set expanddef
@ -1458,8 +1457,8 @@
'while expandall})
(defn dotup [t]
(def h (get t 0))
(def s (get specs h))
(def h (in t 0))
(def s (in specs h))
(def entry (or (dyn h) {}))
(def m (entry :value))
(def m? (entry :macro))
@ -1956,7 +1955,7 @@
[path & args]
(def [fullpath mod-kind] (module/find path))
(unless fullpath (error mod-kind))
(if-let [check (get module/cache fullpath)]
(if-let [check (in module/cache fullpath)]
check
(do
(def loader (module/loaders mod-kind))
@ -2123,25 +2122,25 @@ _fiber is bound to the suspended fiber
"q" (fn [&] (set *quiet* true) 1)
"k" (fn [&] (set *compile-only* true) (set *exit-on-error* false) 1)
"n" (fn [&] (set *colorize* false) 1)
"m" (fn [i &] (setdyn :syspath (get args (+ i 1))) 2)
"m" (fn [i &] (setdyn :syspath (in args (+ i 1))) 2)
"c" (fn [i &]
(def e (dofile (get args (+ i 1))))
(spit (get args (+ i 2)) (make-image e))
(def e (dofile (in args (+ i 1))))
(spit (in args (+ i 2)) (make-image e))
(set *no-file* false)
3)
"-" (fn [&] (set *handleopts* false) 1)
"l" (fn [i &]
(import* (get args (+ i 1))
(import* (in args (+ i 1))
:prefix "" :exit *exit-on-error*)
2)
"e" (fn [i &]
(set *no-file* false)
(eval-string (get args (+ i 1)))
(eval-string (in args (+ i 1)))
2)})
(defn- dohandler [n i &]
(def h (get handlers n))
(if h (h i) (do (print "unknown flag -" n) ((get handlers "h")))))
(def h (in handlers n))
(if h (h i) (do (print "unknown flag -" n) ((in handlers "h")))))
(def- safe-forms {'defn true 'defn- true 'defmacro true 'defmacro- true})
(def- importers {'import true 'import* true 'use true 'dofile true 'require true})
@ -2162,7 +2161,7 @@ _fiber is bound to the suspended fiber
(var i 0)
(def lenargs (length args))
(while (< i lenargs)
(def arg (get args i))
(def arg (in args i))
(if (and *handleopts* (= "-" (string/slice arg 0 1)))
(+= i (dohandler (string/slice arg 1 2) i))
(do

View File

@ -129,7 +129,7 @@ const char *janet_getcstring(const Janet *argv, int32_t n) {
int32_t janet_getinteger(const Janet *argv, int32_t n) {
Janet x = argv[n];
if (!janet_checkint(x)) {
janet_panicf("bad slot #%d, expected integer, got %v", n, x);
janet_panicf("bad slot #%d, expected 32 bit signed integer, got %v", n, x);
}
return janet_unwrap_integer(x);
}
@ -137,7 +137,7 @@ int32_t janet_getinteger(const Janet *argv, int32_t n) {
int64_t janet_getinteger64(const Janet *argv, int32_t n) {
Janet x = argv[n];
if (!janet_checkint64(x)) {
janet_panicf("bad slot #%d, expected 64 bit integer, got %v", n, x);
janet_panicf("bad slot #%d, expected 64 bit signed integer, got %v", n, x);
}
return (int64_t) janet_unwrap_number(x);
}

View File

@ -104,7 +104,7 @@ static JanetSlot do_debug(JanetFopts opts, JanetSlot *args) {
janetc_emit(opts.compiler, JOP_SIGNAL | (2 << 24));
return janetc_cslot(janet_wrap_nil());
}
static JanetSlot do_get(JanetFopts opts, JanetSlot *args) {
static JanetSlot do_in(JanetFopts opts, JanetSlot *args) {
return opreduce(opts, args, JOP_GET, janet_wrap_nil());
}
static JanetSlot do_put(JanetFopts opts, JanetSlot *args) {
@ -275,7 +275,7 @@ static const JanetFunOptimizer optimizers[] = {
{minarity2, do_apply},
{maxarity1, do_yield},
{fixarity2, do_resume},
{fixarity2, do_get},
{fixarity2, do_in},
{fixarity3, do_put},
{fixarity1, do_length},
{NULL, do_add},

View File

@ -34,7 +34,7 @@
#define JANET_FUN_APPLY 3
#define JANET_FUN_YIELD 4
#define JANET_FUN_RESUME 5
#define JANET_FUN_GET 6
#define JANET_FUN_IN 6
#define JANET_FUN_PUT 7
#define JANET_FUN_LENGTH 8
#define JANET_FUN_ADD 9

View File

@ -262,6 +262,61 @@ static Janet janet_core_setdyn(int32_t argc, Janet *argv) {
return argv[1];
}
static Janet janet_core_get(int32_t argc, Janet *argv) {
janet_arity(argc, 2, 3);
Janet ds = argv[0];
Janet key = argv[1];
Janet dflt = argc == 3 ? argv[2] : janet_wrap_nil();
JanetType t = janet_type(argv[0]);
switch (t) {
default:
return dflt;
case JANET_STRING:
case JANET_SYMBOL:
case JANET_KEYWORD: {
if (!janet_checkint(key)) return dflt;
int32_t index = janet_unwrap_integer(key);
if (index < 0) return dflt;
const uint8_t *str = janet_unwrap_string(ds);
if (index >= janet_string_length(str)) return dflt;
return janet_wrap_integer(str[index]);
}
case JANET_ABSTRACT: {
void *abst = janet_unwrap_abstract(ds);
JanetAbstractType *type = (JanetAbstractType *)janet_abstract_type(abst);
if (!type->get) return dflt;
return (type->get)(abst, key);
}
case JANET_ARRAY:
case JANET_TUPLE: {
if (!janet_checkint(key)) return dflt;
int32_t index = janet_unwrap_integer(key);
if (index < 0) return dflt;
if (t == JANET_ARRAY) {
JanetArray *a = janet_unwrap_array(ds);
if (index >= a->count) return dflt;
return a->data[index];
} else {
const Janet *t = janet_unwrap_tuple(ds);
if (index >= janet_tuple_length(t)) return dflt;
return t[index];
}
}
case JANET_TABLE: {
JanetTable *flag = NULL;
Janet ret = janet_table_get_ex(janet_unwrap_table(ds), key, &flag);
if (flag == NULL) return dflt;
return ret;
}
case JANET_STRUCT: {
const JanetKV *st = janet_unwrap_struct(ds);
Janet ret = janet_struct_get(st, key);
if (janet_checktype(ret, JANET_NIL)) return dflt;
return ret;
}
}
}
static Janet janet_core_native(int32_t argc, Janet *argv) {
JanetModule init;
janet_arity(argc, 1, 2);
@ -685,6 +740,14 @@ static const JanetReg corelib_cfuns[] = {
JDOC("(slice x &opt start end)\n\n"
"Extract a sub-range of an indexed data strutrue or byte sequence.")
},
{
"get", janet_core_get,
JDOC("(get ds key &opt dflt)\n\n"
"Get the value mapped to key in data structure ds, and return dflt or nil if not found. "
"Similar to get, but will not throw an error if the key is invalid for the data structure "
"unless the data structure is an abstract type. In that case, the abstract type getter may throw "
"an error.")
},
{NULL, NULL, NULL}
};
@ -960,8 +1023,8 @@ JanetTable *janet_core_env(JanetTable *replacements) {
"will be returned to the last yield in the case of a pending fiber, or the argument to "
"the dispatch function in the case of a new fiber. Returns either the return result of "
"the fiber's dispatch function, or the value from the next yield call in fiber."));
janet_quick_asm(env, JANET_FUN_GET,
"get", 3, 2, 3, 4, get_asm, sizeof(get_asm),
janet_quick_asm(env, JANET_FUN_IN,
"in", 3, 2, 3, 4, get_asm, sizeof(get_asm),
JDOC("(get ds key &opt dflt)\n\n"
"Get a value from any associative data structure. Arrays, tuples, tables, structs, strings, "
"symbols, and buffers are all associative and can be used with get. Order structures, name "

View File

@ -27,19 +27,131 @@
#include "util.h"
#endif
static Janet janet_rng_get(void *p, Janet key);
static void janet_rng_marshal(void *p, JanetMarshalContext *ctx) {
JanetRNG *rng = (JanetRNG *)p;
janet_marshal_int(ctx, (int32_t) rng->a);
janet_marshal_int(ctx, (int32_t) rng->b);
janet_marshal_int(ctx, (int32_t) rng->c);
janet_marshal_int(ctx, (int32_t) rng->d);
janet_marshal_int(ctx, (int32_t) rng->counter);
}
static void janet_rng_unmarshal(void *p, JanetMarshalContext *ctx) {
JanetRNG *rng = (JanetRNG *)p;
rng->a = (uint32_t) janet_unmarshal_int(ctx);
rng->b = (uint32_t) janet_unmarshal_int(ctx);
rng->c = (uint32_t) janet_unmarshal_int(ctx);
rng->d = (uint32_t) janet_unmarshal_int(ctx);
rng->counter = (uint32_t) janet_unmarshal_int(ctx);
}
static JanetAbstractType JanetRNG_type = {
"core/rng",
NULL,
NULL,
janet_rng_get,
NULL,
janet_rng_marshal,
janet_rng_unmarshal,
NULL
};
static JANET_THREAD_LOCAL JanetRNG janet_vm_rng = {0};
JanetRNG *janet_default_rng(void) {
return &janet_vm_rng;
}
void janet_rng_seed(JanetRNG *rng, uint32_t seed) {
rng->a = seed + 123573u;
rng->b = (seed + 43234283u) % 12391233u;
rng->c = 0x17af0931u;
rng->d = 0xFFFaaFFFu;
rng->counter = 0u;
}
uint32_t janet_rng_u32(JanetRNG *rng) {
/* Algorithm "xorwow" from p. 5 of Marsaglia, "Xorshift RNGs" */
uint32_t t = rng->d;
uint32_t const s = rng->a;
rng->d = rng->c;
rng->c = rng->b;
rng->b = s;
t ^= t >> 2;
t ^= t << 1;
t ^= s ^ (s << 4);
rng->a = t;
rng->counter += 362437;
return t + rng->counter;
}
double janet_rng_double(JanetRNG *rng) {
uint32_t hi = janet_rng_u32(rng);
uint32_t lo = janet_rng_u32(rng);
uint64_t big = (uint64_t)(lo) | (((uint64_t) hi) << 32);
return ldexp((big >> (64 - 52)), -52);
}
static Janet cfun_rng_make(int32_t argc, Janet *argv) {
janet_arity(argc, 0, 1);
uint32_t seed = (uint32_t)(argc == 1 ? janet_getinteger(argv, 0) : 0);
JanetRNG *rng = janet_abstract(&JanetRNG_type, sizeof(JanetRNG));
janet_rng_seed(rng, seed);
return janet_wrap_abstract(rng);
}
static Janet cfun_rng_uniform(int32_t argc, Janet *argv) {
janet_fixarity(argc, 1);
JanetRNG *rng = janet_getabstract(argv, 0, &JanetRNG_type);
return janet_wrap_number(janet_rng_double(rng));
}
static Janet cfun_rng_int(int32_t argc, Janet *argv) {
janet_arity(argc, 1, 2);
JanetRNG *rng = janet_getabstract(argv, 0, &JanetRNG_type);
if (argc == 1) {
uint32_t word = janet_rng_u32(rng) >> 1;
return janet_wrap_integer(word);
} else {
int32_t max = janet_getinteger(argv, 1);
if (max <= 0) return janet_wrap_number(0);
uint32_t modulo = (uint32_t) max;
uint32_t bad = UINT32_MAX % modulo;
uint32_t word;
do {
word = janet_rng_u32(rng);
} while (word > UINT32_MAX - bad);
word >>= 1;
return janet_wrap_integer(word % modulo);
}
}
static const JanetMethod rng_methods[] = {
{"uniform", cfun_rng_uniform},
{"int", cfun_rng_int},
{NULL, NULL}
};
static Janet janet_rng_get(void *p, Janet key) {
(void) p;
if (!janet_checktype(key, JANET_KEYWORD)) janet_panicf("expected keyword method");
return janet_getmethod(janet_unwrap_keyword(key), rng_methods);
}
/* Get a random number */
static Janet janet_rand(int32_t argc, Janet *argv) {
(void) argv;
janet_fixarity(argc, 0);
double r = (rand() % RAND_MAX) / ((double) RAND_MAX);
return janet_wrap_number(r);
return janet_wrap_number(janet_rng_double(&janet_vm_rng));
}
/* Seed the random number generator */
static Janet janet_srand(int32_t argc, Janet *argv) {
janet_fixarity(argc, 1);
int32_t x = janet_getinteger(argv, 0);
srand((unsigned) x);
janet_rng_seed(&janet_vm_rng, (uint32_t) x);
return janet_wrap_nil();
}
@ -108,7 +220,7 @@ static const JanetReg math_cfuns[] = {
{
"math/seedrandom", janet_srand,
JDOC("(math/seedrandom seed)\n\n"
"Set the seed for the random number generator. 'seed' should be an "
"Set the seed for the random number generator. 'seed' should be "
"an integer.")
},
{
@ -201,6 +313,24 @@ static const JanetReg math_cfuns[] = {
JDOC("(math/atan2 y x)\n\n"
"Return the arctangent of y/x. Works even when x is 0.")
},
{
"math/rng", cfun_rng_make,
JDOC("(math/rng &opt seed)\n\n"
"Creates a Psuedo-Random number generator, with an optional seed. "
"The seed should be an unsigned 32 bit integer. "
"Do not use this for cryptography. Returns a core/rng abstract type.")
},
{
"math/rng-uniform", cfun_rng_uniform,
JDOC("(math/rng-seed rng seed)\n\n"
"Extract a random number in the range [0, 1) from the RNG.")
},
{
"math/rng-int", cfun_rng_int,
JDOC("(math/rng-int rng &opt max)\n\n"
"Extract a random random integer in the range [0, max] from the RNG. If "
"no max is given, the default is 2^31 - 1.")
},
{NULL, NULL, NULL}
};

View File

@ -55,8 +55,13 @@ int janet_dobytes(JanetTable *env, const uint8_t *bytes, int32_t len, const char
done = 1;
}
} else {
janet_eprintf("compile error in %s: %s\n", sourcePath,
(const char *)cres.error);
if (cres.macrofiber) {
janet_eprintf("compile error in %s: ", sourcePath);
janet_stacktrace(cres.macrofiber, janet_wrap_string(cres.error));
} else {
janet_eprintf("compile error in %s: %s\n", sourcePath,
(const char *)cres.error);
}
errflags |= 0x02;
done = 1;
}

View File

@ -145,6 +145,16 @@ int janet_compare(Janet x, Janet y) {
return (janet_type(x) < janet_type(y)) ? -1 : 1;
}
static int32_t getter_checkint(Janet key, int32_t max) {
if (!janet_checkint(key)) goto bad;
int32_t ret = janet_unwrap_integer(key);
if (ret < 0) goto bad;
if (ret >= max) goto bad;
return ret;
bad:
janet_panicf("expected integer key in range [0, %d), got %v", max, key);
}
/* Gets a value and returns. Can panic. */
Janet janet_get(Janet ds, Janet key) {
Janet value;
@ -160,56 +170,28 @@ Janet janet_get(Janet ds, Janet key) {
break;
case JANET_ARRAY: {
JanetArray *array = janet_unwrap_array(ds);
int32_t index;
if (!janet_checkint(key))
janet_panic("expected integer key");
index = janet_unwrap_integer(key);
if (index < 0 || index >= array->count) {
value = janet_wrap_nil();
} else {
value = array->data[index];
}
int32_t index = getter_checkint(key, array->count);
value = array->data[index];
break;
}
case JANET_TUPLE: {
const Janet *tuple = janet_unwrap_tuple(ds);
int32_t index;
if (!janet_checkint(key))
janet_panic("expected integer key");
index = janet_unwrap_integer(key);
if (index < 0 || index >= janet_tuple_length(tuple)) {
value = janet_wrap_nil();
} else {
value = tuple[index];
}
int32_t len = janet_tuple_length(tuple);
value = tuple[getter_checkint(key, len)];
break;
}
case JANET_BUFFER: {
JanetBuffer *buffer = janet_unwrap_buffer(ds);
int32_t index;
if (!janet_checkint(key))
janet_panic("expected integer key");
index = janet_unwrap_integer(key);
if (index < 0 || index >= buffer->count) {
value = janet_wrap_nil();
} else {
value = janet_wrap_integer(buffer->data[index]);
}
int32_t index = getter_checkint(key, buffer->count);
value = janet_wrap_integer(buffer->data[index]);
break;
}
case JANET_STRING:
case JANET_SYMBOL:
case JANET_KEYWORD: {
const uint8_t *str = janet_unwrap_string(ds);
int32_t index;
if (!janet_checkint(key))
janet_panic("expected integer key");
index = janet_unwrap_integer(key);
if (index < 0 || index >= janet_string_length(str)) {
value = janet_wrap_nil();
} else {
value = janet_wrap_integer(str[index]);
}
int32_t index = getter_checkint(key, janet_string_length(str));
value = janet_wrap_integer(str[index]);
break;
}
case JANET_ABSTRACT: {
@ -356,7 +338,7 @@ void janet_putindex(Janet ds, int32_t index, Janet value) {
janet_buffer_ensure(buffer, index + 1, 2);
buffer->count = index + 1;
}
buffer->data[index] = janet_unwrap_integer(value);
buffer->data[index] = (uint8_t)(janet_unwrap_integer(value) & 0xFF);
break;
}
case JANET_TABLE: {
@ -382,11 +364,8 @@ void janet_put(Janet ds, Janet key, Janet value) {
janet_panicf("expected %T, got %v",
JANET_TFLAG_ARRAY | JANET_TFLAG_BUFFER | JANET_TFLAG_TABLE, ds);
case JANET_ARRAY: {
int32_t index;
JanetArray *array = janet_unwrap_array(ds);
if (!janet_checkint(key)) janet_panicf("expected integer key, got %v", key);
index = janet_unwrap_integer(key);
if (index < 0 || index == INT32_MAX) janet_panicf("bad integer key, got %v", key);
int32_t index = getter_checkint(key, INT32_MAX - 1);
if (index >= array->count) {
janet_array_setcount(array, index + 1);
}
@ -394,11 +373,8 @@ void janet_put(Janet ds, Janet key, Janet value) {
break;
}
case JANET_BUFFER: {
int32_t index;
JanetBuffer *buffer = janet_unwrap_buffer(ds);
if (!janet_checkint(key)) janet_panicf("expected integer key, got %v", key);
index = janet_unwrap_integer(key);
if (index < 0 || index == INT32_MAX) janet_panicf("bad integer key, got %v", key);
int32_t index = getter_checkint(key, INT32_MAX - 1);
if (!janet_checkint(value))
janet_panicf("can only put integers in buffers, got %v", value);
if (index >= buffer->count) {

View File

@ -1171,6 +1171,8 @@ int janet_init(void) {
/* Initialize registry */
janet_vm_registry = janet_table(0);
janet_gcroot(janet_wrap_table(janet_vm_registry));
/* Seed RNG */
janet_rng_seed(janet_default_rng(), 0);
return 0;
}

View File

@ -316,6 +316,7 @@ typedef struct JanetView JanetView;
typedef struct JanetByteView JanetByteView;
typedef struct JanetDictView JanetDictView;
typedef struct JanetRange JanetRange;
typedef struct JanetRNG JanetRNG;
typedef Janet(*JanetCFunction)(int32_t argc, Janet *argv);
/* Basic types for all Janet Values */
@ -927,6 +928,11 @@ struct JanetRange {
int32_t end;
};
struct JanetRNG {
uint32_t a, b, c, d;
uint32_t counter;
};
/***** END SECTION TYPES *****/
/***** START SECTION OPCODES *****/
@ -1103,6 +1109,11 @@ JANET_API void janet_debug_find(
JanetFuncDef **def_out, int32_t *pc_out,
const uint8_t *source, int32_t line, int32_t column);
/* RNG */
JANET_API JanetRNG *janet_default_rng(void);
JANET_API void janet_rng_seed(JanetRNG *rng, uint32_t seed);
JANET_API uint32_t janet_rng_u32(JanetRNG *rng);
/* Array functions */
JANET_API JanetArray *janet_array(int32_t capacity);
JANET_API JanetArray *janet_array_n(const Janet *elements, int32_t n);

View File

@ -193,4 +193,19 @@
# Trim empty string
(assert (= "" (string/trim " ")) "string/trim regression")
# RNGs
(defn test-rng
[rng]
(assert (all identity (seq [i :range [0 1000]]
(<= (math/rng-int rng i) i))) "math/rng-int test")
(assert (all identity (seq [i :range [0 1000]]
(def x (math/rng-uniform rng))
(and (>= x 0) (< x 1))))
"math/rng-uniform test"))
(def seedrng (math/rng 123))
(for i 0 75
(test-rng (math/rng (:int seedrng))))
(end-suite)