Commit 0cbc0722 authored by Paul Sokolovsky's avatar Paul Sokolovsky
Browse files

extmod/moduheapq: Adhoc changes to support ordering by utime.ticks_ms().

As required for further elaboration of uasyncio, like supporting baremetal
systems with wraparound timesources. This is not intended to be public
interface, and likely will be further refactored in the future.
parent 3c0da6a3
...@@ -28,9 +28,12 @@ ...@@ -28,9 +28,12 @@
#include "py/objlist.h" #include "py/objlist.h"
#include "py/runtime0.h" #include "py/runtime0.h"
#include "py/runtime.h" #include "py/runtime.h"
#include "py/smallint.h"
#if MICROPY_PY_UHEAPQ #if MICROPY_PY_UHEAPQ
#define MODULO MICROPY_PY_UTIME_TICKS_PERIOD
// the algorithm here is modelled on CPython's heapq.py // the algorithm here is modelled on CPython's heapq.py
STATIC mp_obj_list_t *get_heap(mp_obj_t heap_in) { STATIC mp_obj_list_t *get_heap(mp_obj_t heap_in) {
...@@ -40,12 +43,33 @@ STATIC mp_obj_list_t *get_heap(mp_obj_t heap_in) { ...@@ -40,12 +43,33 @@ STATIC mp_obj_list_t *get_heap(mp_obj_t heap_in) {
return MP_OBJ_TO_PTR(heap_in); return MP_OBJ_TO_PTR(heap_in);
} }
STATIC void heap_siftdown(mp_obj_list_t *heap, mp_uint_t start_pos, mp_uint_t pos) { STATIC bool time_less_than(mp_obj_t item, mp_obj_t parent) {
if (!MP_OBJ_IS_TYPE(item, &mp_type_tuple) || !MP_OBJ_IS_TYPE(parent, &mp_type_tuple)) {
mp_raise_TypeError("");
}
mp_obj_tuple_t *item_p = MP_OBJ_TO_PTR(item);
mp_obj_tuple_t *parent_p = MP_OBJ_TO_PTR(parent);
mp_uint_t item_tm = MP_OBJ_SMALL_INT_VALUE(item_p->items[0]);
mp_uint_t parent_tm = MP_OBJ_SMALL_INT_VALUE(parent_p->items[0]);
mp_uint_t res = parent_tm - item_tm;
if ((mp_int_t)res < 0) {
res += MODULO;
}
return res < (MODULO / 2);
}
STATIC void heap_siftdown(mp_obj_list_t *heap, mp_uint_t start_pos, mp_uint_t pos, bool timecmp) {
mp_obj_t item = heap->items[pos]; mp_obj_t item = heap->items[pos];
while (pos > start_pos) { while (pos > start_pos) {
mp_uint_t parent_pos = (pos - 1) >> 1; mp_uint_t parent_pos = (pos - 1) >> 1;
mp_obj_t parent = heap->items[parent_pos]; mp_obj_t parent = heap->items[parent_pos];
if (mp_binary_op(MP_BINARY_OP_LESS, item, parent) == mp_const_true) { bool lessthan;
if (MP_UNLIKELY(timecmp)) {
lessthan = time_less_than(item, parent);
} else {
lessthan = (mp_binary_op(MP_BINARY_OP_LESS, item, parent) == mp_const_true);
}
if (lessthan) {
heap->items[pos] = parent; heap->items[pos] = parent;
pos = parent_pos; pos = parent_pos;
} else { } else {
...@@ -55,32 +79,43 @@ STATIC void heap_siftdown(mp_obj_list_t *heap, mp_uint_t start_pos, mp_uint_t po ...@@ -55,32 +79,43 @@ STATIC void heap_siftdown(mp_obj_list_t *heap, mp_uint_t start_pos, mp_uint_t po
heap->items[pos] = item; heap->items[pos] = item;
} }
STATIC void heap_siftup(mp_obj_list_t *heap, mp_uint_t pos) { STATIC void heap_siftup(mp_obj_list_t *heap, mp_uint_t pos, bool timecmp) {
mp_uint_t start_pos = pos; mp_uint_t start_pos = pos;
mp_uint_t end_pos = heap->len; mp_uint_t end_pos = heap->len;
mp_obj_t item = heap->items[pos]; mp_obj_t item = heap->items[pos];
for (mp_uint_t child_pos = 2 * pos + 1; child_pos < end_pos; child_pos = 2 * pos + 1) { for (mp_uint_t child_pos = 2 * pos + 1; child_pos < end_pos; child_pos = 2 * pos + 1) {
// choose right child if it's <= left child // choose right child if it's <= left child
if (child_pos + 1 < end_pos && mp_binary_op(MP_BINARY_OP_LESS, heap->items[child_pos], heap->items[child_pos + 1]) == mp_const_false) { if (child_pos + 1 < end_pos) {
child_pos += 1; bool lessthan;
if (MP_UNLIKELY(timecmp)) {
lessthan = time_less_than(heap->items[child_pos], heap->items[child_pos + 1]);
} else {
lessthan = (mp_binary_op(MP_BINARY_OP_LESS, heap->items[child_pos], heap->items[child_pos + 1]) == mp_const_true);
}
if (!lessthan) {
child_pos += 1;
}
} }
// bubble up the smaller child // bubble up the smaller child
heap->items[pos] = heap->items[child_pos]; heap->items[pos] = heap->items[child_pos];
pos = child_pos; pos = child_pos;
} }
heap->items[pos] = item; heap->items[pos] = item;
heap_siftdown(heap, start_pos, pos); heap_siftdown(heap, start_pos, pos, timecmp);
} }
STATIC mp_obj_t mod_uheapq_heappush(mp_obj_t heap_in, mp_obj_t item) { STATIC mp_obj_t mod_uheapq_heappush(size_t n_args, const mp_obj_t *args) {
mp_obj_t heap_in = args[0];
mp_obj_list_t *heap = get_heap(heap_in); mp_obj_list_t *heap = get_heap(heap_in);
mp_obj_list_append(heap_in, item); mp_obj_list_append(heap_in, args[1]);
heap_siftdown(heap, 0, heap->len - 1); bool is_timeq = (n_args > 2 && args[2] == mp_const_true);
heap_siftdown(heap, 0, heap->len - 1, is_timeq);
return mp_const_none; return mp_const_none;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_uheapq_heappush_obj, mod_uheapq_heappush); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_uheapq_heappush_obj, 2, 3, mod_uheapq_heappush);
STATIC mp_obj_t mod_uheapq_heappop(mp_obj_t heap_in) { STATIC mp_obj_t mod_uheapq_heappop(size_t n_args, const mp_obj_t *args) {
mp_obj_t heap_in = args[0];
mp_obj_list_t *heap = get_heap(heap_in); mp_obj_list_t *heap = get_heap(heap_in);
if (heap->len == 0) { if (heap->len == 0) {
nlr_raise(mp_obj_new_exception_msg(&mp_type_IndexError, "empty heap")); nlr_raise(mp_obj_new_exception_msg(&mp_type_IndexError, "empty heap"));
...@@ -90,16 +125,17 @@ STATIC mp_obj_t mod_uheapq_heappop(mp_obj_t heap_in) { ...@@ -90,16 +125,17 @@ STATIC mp_obj_t mod_uheapq_heappop(mp_obj_t heap_in) {
heap->items[0] = heap->items[heap->len]; heap->items[0] = heap->items[heap->len];
heap->items[heap->len] = MP_OBJ_NULL; // so we don't retain a pointer heap->items[heap->len] = MP_OBJ_NULL; // so we don't retain a pointer
if (heap->len) { if (heap->len) {
heap_siftup(heap, 0); bool is_timeq = (n_args > 1 && args[1] == mp_const_true);
heap_siftup(heap, 0, is_timeq);
} }
return item; return item;
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_1(mod_uheapq_heappop_obj, mod_uheapq_heappop); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(mod_uheapq_heappop_obj, 1, 2, mod_uheapq_heappop);
STATIC mp_obj_t mod_uheapq_heapify(mp_obj_t heap_in) { STATIC mp_obj_t mod_uheapq_heapify(mp_obj_t heap_in) {
mp_obj_list_t *heap = get_heap(heap_in); mp_obj_list_t *heap = get_heap(heap_in);
for (mp_uint_t i = heap->len / 2; i > 0;) { for (mp_uint_t i = heap->len / 2; i > 0;) {
heap_siftup(heap, --i); heap_siftup(heap, --i, false);
} }
return mp_const_none; return mp_const_none;
} }
......
# Test adhoc extension to uheapq to support wraparound
# time (utime.ticks_ms() style) task queue.
from utime import ticks_add, ticks_diff
import uheapq as heapq
DEBUG = 0
MAX = ticks_add(0, -1)
MODULO_HALF = MAX // 2 + 1
if DEBUG:
def dprint(*v):
print(*v)
else:
def dprint(*v):
pass
# Try not to crash on invalid data
h = []
heapq.heappush(h, 1)
try:
heapq.heappush(h, 2, True)
assert False
except TypeError:
pass
heapq.heappush(h, 2)
try:
heapq.heappop(h, True)
assert False
except TypeError:
pass
def pop_all(h):
l = []
while h:
l.append(heapq.heappop(h, True))
dprint(l)
return l
def add(h, v):
heapq.heappush(h, (v, None), True)
h = []
add(h, 0)
add(h, MAX)
add(h, MAX - 1)
add(h, 101)
add(h, 100)
add(h, MAX - 2)
dprint(h)
l = pop_all(h)
for i in range(len(l) - 1):
diff = ticks_diff(l[i + 1][0], l[i][0])
assert diff > 0
def edge_case(edge, offset):
h = []
add(h, ticks_add(0, offset))
add(h, ticks_add(edge, offset))
dprint(h)
l = pop_all(h)
diff = ticks_diff(l[1][0], l[0][0])
dprint(diff, diff > 0)
return diff
dprint("===")
diff = edge_case(MODULO_HALF - 1, 0)
assert diff == MODULO_HALF - 1
assert edge_case(MODULO_HALF - 1, 100) == diff
assert edge_case(MODULO_HALF - 1, -100) == diff
# We expect diff to be always positive, per the definition of heappop() which should return
# the smallest value.
# This is the edge case where this invariant breaks, due to assymetry of two's-complement
# range - there's one more negative integer than positive, so heappushing values like below
# will then make ticks_diff() return the minimum negative value. We could make heappop
# return them in a different order, but ticks_diff() result would be the same. Conclusion:
# never add to a heap values where (a - b) == MODULO_HALF (and which are >= MODULO_HALF
# ticks apart in real time of course).
dprint("===")
diff = edge_case(MODULO_HALF, 0)
assert diff == -MODULO_HALF
assert edge_case(MODULO_HALF, 100) == diff
assert edge_case(MODULO_HALF, -100) == diff
dprint("===")
diff = edge_case(MODULO_HALF + 1, 0)
assert diff == MODULO_HALF - 1
assert edge_case(MODULO_HALF + 1, 100) == diff
assert edge_case(MODULO_HALF + 1, -100) == diff
print("OK")
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment