Commit be790f94 authored by John R. Lenton's avatar John R. Lenton
Browse files

Implemented set binary ops.

parent 0de386bf
......@@ -9,6 +9,7 @@
#include "mpqstr.h"
#include "obj.h"
#include "runtime.h"
#include "runtime0.h"
#include "map.h"
typedef struct _mp_obj_set_t {
......@@ -227,7 +228,7 @@ static mp_obj_t set_isdisjoint(mp_obj_t self_in, mp_obj_t other) {
}
static MP_DEFINE_CONST_FUN_OBJ_2(set_isdisjoint_obj, set_isdisjoint);
static mp_obj_t set_issubset(mp_obj_t self_in, mp_obj_t other_in) {
static mp_obj_t set_issubset_internal(mp_obj_t self_in, mp_obj_t other_in, bool proper) {
mp_obj_set_t *self;
bool cleanup_self = false;
if (MP_OBJ_IS_TYPE(self_in, &set_type)) {
......@@ -245,13 +246,17 @@ static mp_obj_t set_issubset(mp_obj_t self_in, mp_obj_t other_in) {
other = set_make_new(NULL, 1, &other_in);
cleanup_other = true;
}
mp_obj_t iter = set_getiter(self);
mp_obj_t next;
mp_obj_t out = mp_const_true;
while ((next = set_it_iternext(iter)) != mp_const_stop_iteration) {
if (!mp_set_lookup(&other->set, next, MP_MAP_LOOKUP)) {
out = mp_const_false;
break;
bool out = true;
if (proper && self->set.used == other->set.used) {
out = false;
} else {
mp_obj_t iter = set_getiter(self);
mp_obj_t next;
while ((next = set_it_iternext(iter)) != mp_const_stop_iteration) {
if (!mp_set_lookup(&other->set, next, MP_MAP_LOOKUP)) {
out = false;
break;
}
}
}
if (cleanup_self) {
......@@ -260,15 +265,39 @@ static mp_obj_t set_issubset(mp_obj_t self_in, mp_obj_t other_in) {
if (cleanup_other) {
set_clear(other);
}
return out;
return MP_BOOL(out);
}
static mp_obj_t set_issubset(mp_obj_t self_in, mp_obj_t other_in) {
return set_issubset_internal(self_in, other_in, false);
}
static MP_DEFINE_CONST_FUN_OBJ_2(set_issubset_obj, set_issubset);
static mp_obj_t set_issubset_proper(mp_obj_t self_in, mp_obj_t other_in) {
return set_issubset_internal(self_in, other_in, true);
}
static mp_obj_t set_issuperset(mp_obj_t self_in, mp_obj_t other_in) {
return set_issubset(other_in, self_in);
return set_issubset_internal(other_in, self_in, false);
}
static MP_DEFINE_CONST_FUN_OBJ_2(set_issuperset_obj, set_issuperset);
static mp_obj_t set_issuperset_proper(mp_obj_t self_in, mp_obj_t other_in) {
return set_issubset_internal(other_in, self_in, true);
}
static mp_obj_t set_equal(mp_obj_t self_in, mp_obj_t other_in) {
assert(MP_OBJ_IS_TYPE(self_in, &set_type));
mp_obj_set_t *self = self_in;
if (!MP_OBJ_IS_TYPE(other_in, &set_type)) {
return mp_const_false;
}
mp_obj_set_t *other = other_in;
if (self->set.used != other->set.used) {
return mp_const_false;
}
return set_issubset(self_in, other_in);
}
static mp_obj_t set_pop(mp_obj_t self_in) {
assert(MP_OBJ_IS_TYPE(self_in, &set_type));
mp_obj_set_t *self = self_in;
......@@ -341,6 +370,42 @@ static mp_obj_t set_union(mp_obj_t self_in, mp_obj_t other_in) {
static MP_DEFINE_CONST_FUN_OBJ_2(set_union_obj, set_union);
static mp_obj_t set_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
mp_obj_t args[] = {lhs, rhs};
switch (op) {
case RT_BINARY_OP_OR:
return set_union(lhs, rhs);
case RT_BINARY_OP_XOR:
return set_symmetric_difference(lhs, rhs);
case RT_BINARY_OP_AND:
return set_intersect(lhs, rhs);
case RT_BINARY_OP_SUBTRACT:
return set_diff(2, args);
case RT_BINARY_OP_INPLACE_OR:
return set_union(lhs, rhs);
case RT_BINARY_OP_INPLACE_XOR:
return set_symmetric_difference(lhs, rhs);
case RT_BINARY_OP_INPLACE_AND:
return set_intersect(lhs, rhs);
case RT_BINARY_OP_INPLACE_SUBTRACT:
return set_diff(2, args);
case RT_COMPARE_OP_LESS:
return set_issubset_proper(lhs, rhs);
case RT_COMPARE_OP_MORE:
return set_issuperset_proper(lhs, rhs);
case RT_COMPARE_OP_EQUAL:
return set_equal(lhs, rhs);
case RT_COMPARE_OP_LESS_EQUAL:
return set_issubset(lhs, rhs);
case RT_COMPARE_OP_MORE_EQUAL:
return set_issuperset(lhs, rhs);
case RT_COMPARE_OP_NOT_EQUAL:
return MP_BOOL(set_equal(lhs, rhs) == mp_const_false);
default:
// op not supported
return NULL;
}
}
/******************************************************************************/
/* set constructors & public C API */
......@@ -372,6 +437,7 @@ const mp_obj_type_t set_type = {
"set",
.print = set_print,
.make_new = set_make_new,
.binary_op = set_binary_op,
.getiter = set_getiter,
.methods = set_type_methods,
};
......
def r(s):
l = list(s)
l.sort()
print(l)
s = {1, 2}
t = {2, 3}
r(s | t)
r(s ^ t)
r(s & t)
r(s - t)
u = s.copy()
u |= t
r(u)
u = s.copy()
u ^= t
r(u)
u = s.copy()
u &= t
r(u)
u = s.copy()
u -= t
r(u)
print(s == t)
print(s != t)
print(s > t)
print(s >= t)
print(s < t)
print(s <= t)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment