1
0
mirror of https://github.com/janet-lang/janet synced 2025-01-27 07:34:44 +00:00

Update threads.c to avoid a deadlock.

This commit is contained in:
Calvin Rose 2019-12-06 01:46:23 -06:00
parent dbcceefc20
commit c804ae9f7c
3 changed files with 235 additions and 161 deletions

View File

@ -1,10 +1,11 @@
(defn worker-main
"Sends 11 messages back to parent"
[parent]
(def name (:receive parent))
(def interval (:receive parent))
(for i 0 10
(os/sleep interval)
(printf "thread %s wakeup no. %d" name i))
(:send parent (string/format "thread %s wakeup no. %d" name i)))
(:send parent name))
(defn make-worker
@ -14,10 +15,45 @@
(:send name)
(:send interval)))
(def bob (make-worker "bob" 0.2))
(def joe (make-worker "joe" 0.3))
(def sam (make-worker "sam" 0.5))
(def bob (make-worker "bob" 0.02))
(def joe (make-worker "joe" 0.03))
(def sam (make-worker "sam" 0.05))
(:close joe)
(try (:receive joe) ([err] (print "Got expected error: " err)))
# Receive out of order
(for i 0 3
(print "worker " (thread/receive [bob sam joe]) " finished!"))
(for i 0 22
(print (thread/receive [bob sam])))
#
# Recursive Thread Tree - should pause for a bit, and then print a cool zigzag.
#
(def rng (math/rng (os/cryptorand 16)))
(defn choose [& xs]
(in xs (:int rng (length xs))))
(defn worker-tree
[parent]
(def name (:receive parent))
(def depth (:receive parent))
(if (< depth 5)
(do
(defn subtree []
(-> (thread/new)
(:send worker-tree)
(:send (string name "/" (choose "bob" "marley" "harry" "suki" "anna" "yu")))
(:send (inc depth))))
(let [l (subtree)
r (subtree)
lrep (thread/receive l)
rrep (thread/receive r)]
(:send parent [name ;lrep ;rrep])))
(do
(:send parent [name]))))
(def lines (:receive (-> (thread/new) (:send worker-tree) (:send "adam") (:send 0))))
(map print lines)

View File

@ -31,17 +31,18 @@
#include <setjmp.h>
JANET_THREAD_LOCAL pthread_cond_t janet_vm_thread_cond;
JANET_THREAD_LOCAL pthread_mutex_t janet_vm_thread_lock;
static JANET_THREAD_LOCAL JanetThreadSelector janet_vm_thread_selector;
void janet_threads_init(void) {
pthread_cond_init(&janet_vm_thread_cond, NULL);
pthread_mutex_init(&janet_vm_thread_lock, NULL);
pthread_mutex_init(&janet_vm_thread_selector.mutex, NULL);
pthread_cond_init(&janet_vm_thread_selector.cond, NULL);
janet_vm_thread_selector.channel = NULL;
}
void janet_threads_deinit(void) {
pthread_cond_destroy(&janet_vm_thread_cond);
pthread_mutex_destroy(&janet_vm_thread_lock);
pthread_mutex_destroy(&janet_vm_thread_selector.mutex);
pthread_cond_destroy(&janet_vm_thread_selector.cond);
janet_vm_thread_selector.channel = NULL;
}
static JanetTable *janet_get_core_table(const char *name) {
@ -56,9 +57,10 @@ static JanetTable *janet_get_core_table(const char *name) {
static void janet_channel_init(JanetChannel *channel) {
janet_buffer_init(&channel->buf, 0);
pthread_mutex_init(&channel->lock, NULL);
channel->rx_cond = NULL;
channel->rx_lock = NULL;
channel->selector = NULL;
channel->refCount = 2;
channel->encode = NULL;
channel->decode = NULL;
}
/* Return 1 if channel memory should be freed, otherwise 0 */
@ -71,120 +73,10 @@ static int janet_channel_deref(JanetChannel *channel) {
} else {
channel->refCount--;
pthread_mutex_unlock(&channel->lock);
/* Wake up other side if they are blocked, otherwise
* they will block forever. */
if (NULL != channel->rx_cond) {
pthread_cond_signal(channel->rx_cond);
}
return 0;
}
}
/* Returns 1 if could not send. Does not block or panic. Bytes should be a janet value that
* has been marshalled. */
static int janet_channel_send(JanetChannel *channel, Janet msg, JanetTable *dict) {
pthread_mutex_lock(&channel->lock);
/* Check for closed channel */
if (channel->refCount <= 1) {
pthread_mutex_unlock(&channel->lock);
return 1;
}
/* Hack to capture all panics from marshalling. This works because
* we know janet_marshal won't mess with other essential global state. */
jmp_buf buf;
jmp_buf *old_buf = janet_vm_jmp_buf;
janet_vm_jmp_buf = &buf;
int32_t oldcount = channel->buf.count;
int ret = 0;
if (setjmp(buf)) {
ret = 1;
channel->buf.count = oldcount;
} else {
janet_marshal(&channel->buf, msg, dict, 0);
/* Was empty, signal to cond */
if (oldcount == 0 && (NULL != channel->rx_cond)) {
pthread_cond_signal(channel->rx_cond);
}
}
/* Cleanup */
janet_vm_jmp_buf = old_buf;
pthread_mutex_unlock(&channel->lock);
return ret;
}
/* Returns 1 if nothing in queue or failed to get item. Does not block or panic. Uses dict to read bytes from
* the channel and unmarshal them. */
static int janet_channel_receive(JanetChannel *channel, Janet *msg_out,
JanetTable *dict, int nowait) {
pthread_mutex_lock(&channel->lock);
/* If queue is empty, block for now. */
while (channel->buf.count == 0) {
/* Check for closed channel (1 ref left means other side quit) */
if (nowait || channel->refCount <= 1) {
pthread_mutex_unlock(&channel->lock);
return 1;
}
/* Since each thread sets its own rx_cond, we know it's not NULL */
pthread_cond_wait(channel->rx_cond, &channel->lock);
}
/* Hack to capture all panics from marshalling. This works because
* we know janet_marshal won't mess with other essential global state. */
jmp_buf buf;
jmp_buf *old_buf = janet_vm_jmp_buf;
janet_vm_jmp_buf = &buf;
/* Handle errors */
int ret = 0;
if (setjmp(buf)) {
/* Clear the channel on errors */
channel->buf.count = 0;
ret = 1;
} else {
/* Read from beginning of channel */
const uint8_t *nextItem = NULL;
Janet item = janet_unmarshal(channel->buf.data, channel->buf.count, 0, dict, &nextItem);
/* Update memory and put result into *msg_out */
int32_t chunkCount = nextItem - channel->buf.data;
memmove(channel->buf.data, nextItem, channel->buf.count - chunkCount);
channel->buf.count -= chunkCount;
*msg_out = item;
}
/* Cleanup */
janet_vm_jmp_buf = old_buf;
pthread_mutex_unlock(&channel->lock);
return ret;
}
static int janet_channel_select(int32_t n, JanetThread **threads,
Janet *msg_out) {
for (;;) {
/* First, loop over channels for any that have any messages, but
* don't acquire any locks. Any incorrect behavior here will not mess
* anything up*/
for (int32_t i = 0; i < n; i++) {
JanetThread *thread = threads[i];
JanetChannel *channel = thread->rx;
if (channel != NULL && channel->buf.count) {
int status = janet_channel_receive(channel, msg_out, thread->decode, 1);
if (!status) return 0;
}
}
/* If no messages waiting, wait for signal */
pthread_cond_wait(&janet_vm_thread_cond, &janet_vm_thread_lock);
}
}
static void janet_close_thread(JanetThread *thread) {
if (NULL != thread->rx) {
JanetChannel *rx = thread->rx;
@ -210,11 +102,160 @@ static int thread_gc(void *p, size_t size) {
static int thread_mark(void *p, size_t size) {
JanetThread *thread = (JanetThread *)p;
(void) size;
if (NULL != thread->encode) {
janet_mark(janet_wrap_table(thread->encode));
JanetChannel *rx = thread->rx;
JanetChannel *tx = thread->tx;
if (tx && tx->encode) {
janet_mark(janet_wrap_table(tx->encode));
}
if (NULL != thread->decode) {
janet_mark(janet_wrap_table(thread->decode));
if (rx && rx->encode) {
janet_mark(janet_wrap_table(rx->decode));
}
return 0;
}
/* Returns 1 if could not send, but do not panic or block (for long). */
static int janet_channel_send(JanetChannel *tx, Janet msg) {
JanetThreadSelector *selector = tx->selector;
/* Check for closed channel */
if (tx->refCount <= 1) return 1;
/* Hack to capture all panics from marshalling. This works because
* we know janet_marshal won't mess with other essential global state. */
jmp_buf buf;
jmp_buf *old_buf = janet_vm_jmp_buf;
janet_vm_jmp_buf = &buf;
int32_t oldcount = tx->buf.count;
int ret = 0;
if (setjmp(buf)) {
ret = 1;
tx->buf.count = oldcount;
} else {
janet_marshal(&tx->buf, msg, tx->encode, 0);
if (selector) {
pthread_mutex_lock(&selector->mutex);
if (!selector->channel) {
selector->channel = tx;
pthread_cond_signal(&selector->cond);
}
pthread_mutex_unlock(&selector->mutex);
}
}
/* Cleanup */
janet_vm_jmp_buf = old_buf;
return ret;
}
/* Returns 0 on successful message.
* Returns 1 if nothing in queue or failed to get item. In this case,
* also sets the channel's selector value.
* Returns 2 if channel closed.
* Does not block (for long) or panic, and sets the channel's selector
* . */
static int janet_channel_receive(JanetChannel *rx, Janet *msg_out) {
/* Check for no messages */
while (rx->buf.count == 0) {
int is_dead = rx->refCount <= 1;
rx->selector = &janet_vm_thread_selector;
return is_dead ? 2 : 1;
}
/* Hack to capture all panics from marshalling. This works because
* we know janet_marshal won't mess with other essential global state. */
jmp_buf buf;
jmp_buf *old_buf = janet_vm_jmp_buf;
janet_vm_jmp_buf = &buf;
/* Handle errors */
int ret = 0;
if (setjmp(buf)) {
rx->buf.count = 0;
rx->selector = &janet_vm_thread_selector;
ret = 1;
} else {
/* Read from beginning of channel */
const uint8_t *nextItem = NULL;
Janet item = janet_unmarshal(rx->buf.data, rx->buf.count,
0, rx->decode, &nextItem);
/* Update memory and put result into *msg_out */
int32_t chunkCount = nextItem - rx->buf.data;
memmove(rx->buf.data, nextItem, rx->buf.count - chunkCount);
rx->buf.count -= chunkCount;
*msg_out = item;
/* Got message, unset selector */
rx->selector = NULL;
}
/* Cleanup */
janet_vm_jmp_buf = old_buf;
return ret;
}
/* Get a message from one of the channels given. */
static int janet_channel_select(int32_t n, JanetChannel **rxs, Janet *msg_out) {
int32_t maxChannel = -1;
for (;;) {
janet_vm_thread_selector.channel = NULL;
/* Try each channel, first without acquiring locks and looking
* only for existing messages, then with acquiring
* locks, which will not miss messages. */
for (int trylock = 1; trylock >= 0; trylock--) {
for (int32_t i = 0; i < n; i++) {
JanetChannel *rx = rxs[i];
if (trylock) {
if (rx->buf.count == 0 || pthread_mutex_trylock(&rx->lock)) continue;
} else {
pthread_mutex_lock(&rxs[i]->lock);
}
int status = janet_channel_receive(rxs[i], msg_out);
pthread_mutex_unlock(&rxs[i]->lock);
if (status == 0) goto gotMessage;
maxChannel = maxChannel > i ? maxChannel : i;
if (status == 2) {
/* channel closed and will receive no more messages, drop it */
rxs[i] = rxs[--n];
--i;
}
}
}
/* All channels closed */
if (n == 0) return 1;
pthread_mutex_lock(&janet_vm_thread_selector.mutex);
{
/* Wait until we have a channel */
if (NULL == janet_vm_thread_selector.channel) {
pthread_cond_wait(
&janet_vm_thread_selector.cond,
&janet_vm_thread_selector.mutex);
}
/* Got channel, swap it with first channel, and
* then go back to receiving messages. */
JanetChannel *rx = janet_vm_thread_selector.channel;
int32_t index = 0;
while (rxs[index] != rx) index++;
rxs[index] = rxs[0];
rxs[0] = rx;
}
pthread_mutex_unlock(&janet_vm_thread_selector.mutex);
}
gotMessage:
/* got message, unset selectors and return */
for (int32_t j = 0; j <= maxChannel && j < n; j++) {
pthread_mutex_lock(&rxs[j]->lock);
rxs[j]->selector = NULL;
pthread_mutex_unlock(&rxs[j]->lock);
}
return 0;
}
@ -236,8 +277,8 @@ static JanetThread *janet_make_thread(JanetChannel *rx, JanetChannel *tx, JanetT
JanetThread *thread = janet_abstract(&Thread_AT, sizeof(JanetThread));
thread->rx = rx;
thread->tx = tx;
thread->encode = encode;
thread->decode = decode;
rx->decode = decode;
tx->encode = encode;
return thread;
}
@ -256,14 +297,12 @@ static int thread_worker(JanetChannel *tx) {
/* Create self thread */
JanetChannel *rx = tx + 1;
rx->rx_cond = &janet_vm_thread_cond;
rx->rx_lock = &janet_vm_thread_lock;
JanetThread *thread = janet_make_thread(rx, tx, encode, decode);
Janet threadv = janet_wrap_abstract(thread);
/* Unmarshal the function */
Janet funcv;
int status = janet_channel_receive(rx, &funcv, decode, 0);
int status = janet_channel_select(1, &rx, &funcv);
if (status) goto error;
if (!janet_checktype(funcv, JANET_FUNCTION)) goto error;
JanetFunction *func = janet_unwrap_function(funcv);
@ -328,8 +367,6 @@ static Janet cfun_thread_new(int32_t argc, Janet *argv) {
JanetChannel *tx = rx + 1;
janet_channel_init(rx);
janet_channel_init(tx);
rx->rx_cond = &janet_vm_thread_cond;
rx->rx_lock = &janet_vm_thread_lock;
JanetThread *thread = janet_make_thread(rx, tx, encode, decode);
if (janet_thread_start_child(thread))
janet_panic("could not start thread");
@ -338,9 +375,11 @@ static Janet cfun_thread_new(int32_t argc, Janet *argv) {
static Janet cfun_thread_send(int32_t argc, Janet *argv) {
janet_fixarity(argc, 2);
JanetThread *thread = janet_getthread(argv, 0);
if (NULL == thread->tx) janet_panic("channel has closed");
int status = janet_channel_send(thread->tx, argv[1], thread->encode);
JanetChannel *tx = janet_getthread(argv, 0)->tx;
if (NULL == tx) janet_panic("channel has closed");
pthread_mutex_lock(&tx->lock);
int status = janet_channel_send(tx, argv[1]);
pthread_mutex_unlock(&tx->lock);
if (status) {
janet_panicf("failed to send message %v", argv[1]);
}
@ -354,30 +393,24 @@ static Janet cfun_thread_receive(int32_t argc, Janet *argv) {
int32_t count;
const Janet *items;
if (janet_indexed_view(argv[0], &items, &count)) {
if (count == 0) {
janet_panic("expected at least 1 thread");
}
if (count == 1) {
JanetThread *thread = janet_getthread(items, 0);
if (NULL == thread->rx) janet_panic("channel has closed");
status = janet_channel_receive(thread->rx, &out, thread->decode, 0);
} else {
/* Select */
int32_t realcount = 0;
JanetThread **threads = janet_smalloc(sizeof(JanetThread *) * count);
/* Select on multiple threads */
if (count == 0) janet_panic("expected at least 1 thread");
int32_t realcount = 0;
JanetChannel *rxs_stack[10] = {NULL};
JanetChannel **rxs = (count > 10)
? janet_smalloc(count * sizeof(JanetChannel *))
: rxs_stack;
for (int32_t i = 0; i < count; i++) {
JanetThread *thread = janet_getthread(items, i);
if (thread->rx != NULL) threads[realcount++] = thread;
}
status = janet_channel_select(realcount, threads, &out);
janet_sfree(threads);
if (thread->rx != NULL) rxs[realcount++] = thread->rx;
}
status = janet_channel_select(realcount, rxs, &out);
if (rxs != rxs_stack) janet_sfree(rxs);
} else {
/* Get from one thread */
JanetThread *thread = janet_getthread(argv, 0);
if (NULL == thread->rx) janet_panic("channel has closed");
status = janet_channel_receive(thread->rx, &out, thread->decode, 0);
status = janet_channel_select(1, &thread->rx, &out);
}
if (status) {
janet_panic("failed to receive message");

View File

@ -947,18 +947,23 @@ struct JanetRNG {
#include <pthread.h>
typedef struct JanetThread JanetThread;
typedef struct JanetChannel JanetChannel;
typedef struct JanetThreadSelector JanetThreadSelector;
struct JanetThreadSelector {
pthread_mutex_t mutex;
pthread_cond_t cond;
JanetChannel *channel;
};
struct JanetChannel {
pthread_mutex_t lock;
pthread_cond_t *rx_cond;
pthread_mutex_t *rx_lock;
JanetBuffer buf;
int refCount;
JanetThreadSelector *selector;
JanetTable *encode; /* only touched by writers */
JanetTable *decode; /* only touched by readers */
};
struct JanetThread {
JanetChannel *rx;
JanetChannel *tx;
JanetTable *encode;
JanetTable *decode;
};
#endif