Commit 66a5bf68 authored by Damien George's avatar Damien George
Browse files

Merge pull request #142 from chipaca/containment

Implemented support for `in` and `not in` operators.
parents 0f59203e f5a0a7d2
...@@ -57,6 +57,12 @@ static mp_obj_t dict_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { ...@@ -57,6 +57,12 @@ static mp_obj_t dict_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
return elem->value; return elem->value;
} }
} }
case RT_COMPARE_OP_IN:
case RT_COMPARE_OP_NOT_IN:
{
mp_map_elem_t *elem = mp_map_lookup(&o->map, rhs_in, MP_MAP_LOOKUP);
return MP_BOOL((op == RT_COMPARE_OP_IN) ^ (elem == NULL));
}
default: default:
// op not supported // op not supported
return NULL; return NULL;
...@@ -362,10 +368,20 @@ static void dict_view_print(void (*print)(void *env, const char *fmt, ...), void ...@@ -362,10 +368,20 @@ static void dict_view_print(void (*print)(void *env, const char *fmt, ...), void
print(env, "])"); print(env, "])");
} }
static mp_obj_t dict_view_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
/* only supported for the 'keys' kind until sets and dicts are refactored */
mp_obj_dict_view_t *o = lhs_in;
if (o->kind != MP_DICT_VIEW_KEYS) return NULL;
if (op != RT_COMPARE_OP_IN && op != RT_COMPARE_OP_NOT_IN) return NULL;
return dict_binary_op(op, o->dict, rhs_in);
}
static const mp_obj_type_t dict_view_type = { static const mp_obj_type_t dict_view_type = {
{ &mp_const_type }, { &mp_const_type },
"dict_view", "dict_view",
.print = dict_view_print, .print = dict_view_print,
.binary_op = dict_view_binary_op,
.getiter = dict_view_getiter, .getiter = dict_view_getiter,
}; };
......
...@@ -45,6 +45,7 @@ void set_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj ...@@ -45,6 +45,7 @@ void set_print(void (*print)(void *env, const char *fmt, ...), void *env, mp_obj
print(env, "}"); print(env, "}");
} }
static mp_obj_t set_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args) { static mp_obj_t set_make_new(mp_obj_t type_in, int n_args, const mp_obj_t *args) {
switch (n_args) { switch (n_args) {
case 0: case 0:
...@@ -405,6 +406,13 @@ static mp_obj_t set_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { ...@@ -405,6 +406,13 @@ static mp_obj_t set_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
return set_issuperset(lhs, rhs); return set_issuperset(lhs, rhs);
case RT_COMPARE_OP_NOT_EQUAL: case RT_COMPARE_OP_NOT_EQUAL:
return MP_BOOL(set_equal(lhs, rhs) == mp_const_false); return MP_BOOL(set_equal(lhs, rhs) == mp_const_false);
case RT_COMPARE_OP_IN:
case RT_COMPARE_OP_NOT_IN:
{
mp_obj_set_t *o = lhs;
mp_obj_t elem = mp_set_lookup(&o->set, rhs, MP_MAP_LOOKUP);
return MP_BOOL((op == RT_COMPARE_OP_IN) ^ (elem == NULL));
}
default: default:
// op not supported // op not supported
return NULL; return NULL;
......
...@@ -85,6 +85,15 @@ mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { ...@@ -85,6 +85,15 @@ mp_obj_t str_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
return mp_obj_new_str(qstr_from_str_take(val, alloc_len)); return mp_obj_new_str(qstr_from_str_take(val, alloc_len));
} }
break; break;
case RT_COMPARE_OP_IN:
case RT_COMPARE_OP_NOT_IN:
/* NOTE `a in b` is `b.__contains__(a)` */
if (MP_OBJ_IS_TYPE(rhs_in, &str_type)) {
const char *rhs_str = qstr_str(((mp_obj_str_t*)rhs_in)->qstr);
/* FIXME \0 in strs */
return MP_BOOL((op == RT_COMPARE_OP_IN) ^ (strstr(lhs_str, rhs_str) == NULL));
}
break;
} }
return MP_OBJ_NULL; // op not supported return MP_OBJ_NULL; // op not supported
......
...@@ -568,7 +568,41 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { ...@@ -568,7 +568,41 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
} else if (MP_OBJ_IS_TYPE(rhs, &complex_type)) { } else if (MP_OBJ_IS_TYPE(rhs, &complex_type)) {
return mp_obj_complex_binary_op(op, lhs_val, 0, rhs); return mp_obj_complex_binary_op(op, lhs_val, 0, rhs);
} }
} else { }
/* deal with `in` and `not in`
*
* NOTE `a in b` is `b.__contains__(a)`, hence why the generic dispatch
* needs to go below
*/
if (op == RT_COMPARE_OP_IN || op == RT_COMPARE_OP_NOT_IN) {
if (!MP_OBJ_IS_SMALL_INT(rhs)) {
mp_obj_base_t *o = rhs;
if (o->type->binary_op != NULL) {
mp_obj_t res = o->type->binary_op(op, rhs, lhs);
if (res != NULL) {
return res;
}
}
if (o->type->getiter != NULL) {
/* second attempt, walk the iterator */
mp_obj_t next = NULL;
mp_obj_t iter = rt_getiter(rhs);
while ((next = rt_iternext(iter)) != mp_const_stop_iteration) {
if (mp_obj_equal(next, lhs)) {
return MP_BOOL(op == RT_COMPARE_OP_IN);
}
}
return MP_BOOL(op != RT_COMPARE_OP_IN);
}
}
nlr_jump(mp_obj_new_exception_msg_varg(
MP_QSTR_TypeError, "'%s' object is not iterable",
mp_obj_get_type_str(rhs)));
return mp_const_none;
}
if (MP_OBJ_IS_OBJ(lhs)) { if (MP_OBJ_IS_OBJ(lhs)) {
mp_obj_base_t *o = lhs; mp_obj_base_t *o = lhs;
if (o->type->binary_op != NULL) { if (o->type->binary_op != NULL) {
...@@ -577,13 +611,14 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { ...@@ -577,13 +611,14 @@ mp_obj_t rt_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
return result; return result;
} }
} }
} // TODO implement dispatch for reverse binary ops
} }
// TODO specify in error message what the operator is // TODO specify in error message what the operator is
nlr_jump(mp_obj_new_exception_msg_varg(MP_QSTR_TypeError, nlr_jump(mp_obj_new_exception_msg_varg(MP_QSTR_TypeError,
"unsupported operand types for binary operator: '%s', '%s'", "unsupported operand types for binary operator: '%s', '%s'",
mp_obj_get_type_str(lhs), mp_obj_get_type_str(rhs))); mp_obj_get_type_str(lhs), mp_obj_get_type_str(rhs)));
return mp_const_none;
} }
mp_obj_t rt_make_function_from_id(int unique_code_id) { mp_obj_t rt_make_function_from_id(int unique_code_id) {
......
...@@ -42,4 +42,7 @@ echo "$numpassed tests passed" ...@@ -42,4 +42,7 @@ echo "$numpassed tests passed"
if [[ $numfailed != 0 ]] if [[ $numfailed != 0 ]]
then then
echo "$numfailed tests failed -$namefailed" echo "$numfailed tests failed -$namefailed"
exit 1
else
exit 0
fi fi
for i in 1, 2:
for o in {1:2}, {1}, {1:2}.keys():
print("{} in {}: {}".format(i, o, i in o))
print("{} not in {}: {}".format(i, o, i not in o))
haystack = "supercalifragilistc"
for needle in (haystack[i:] for i in range(len(haystack))):
print(needle, "in", haystack, "::", needle in haystack)
print(needle, "not in", haystack, "::", needle not in haystack)
print(haystack, "in", needle, "::", haystack in needle)
print(haystack, "not in", needle, "::", haystack not in needle)
for needle in (haystack[:i+1] for i in range(len(haystack))):
print(needle, "in", haystack, "::", needle in haystack)
print(needle, "not in", haystack, "::", needle not in haystack)
print(haystack, "in", needle, "::", haystack in needle)
print(haystack, "not in", needle, "::", haystack not in needle)
# until here, the tests would work without the 'second attempt' iteration thing.
for i in 1, 2:
for o in [], [1], [1, 2]:
print("{} in {}: {}".format(i, o, i in o))
print("{} not in {}: {}".format(i, o, i not in o))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment