Commit d1e355ea authored by Damien George
Browse files

py: Fix check of small-int overflow when parsing ints.

Also unifies use of SMALL_INT_FITS macro across parser and runtime.
parent 813ed3bd
...@@ -249,7 +249,7 @@ STATIC mp_parse_node_t fold_constants(compiler_t *comp, mp_parse_node_t pn, mp_m ...@@ -249,7 +249,7 @@ STATIC mp_parse_node_t fold_constants(compiler_t *comp, mp_parse_node_t pn, mp_m
// shouldn't happen // shouldn't happen
assert(0); assert(0);
} }
if (MP_PARSE_FITS_SMALL_INT(arg0)) { if (MP_SMALL_INT_FITS(arg0)) {
//printf("%ld + %ld\n", arg0, arg1); //printf("%ld + %ld\n", arg0, arg1);
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, arg0); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, arg0);
} }
...@@ -264,7 +264,7 @@ STATIC mp_parse_node_t fold_constants(compiler_t *comp, mp_parse_node_t pn, mp_m ...@@ -264,7 +264,7 @@ STATIC mp_parse_node_t fold_constants(compiler_t *comp, mp_parse_node_t pn, mp_m
// int * int // int * int
if (!mp_small_int_mul_overflow(arg0, arg1)) { if (!mp_small_int_mul_overflow(arg0, arg1)) {
arg0 *= arg1; arg0 *= arg1;
if (MP_PARSE_FITS_SMALL_INT(arg0)) { if (MP_SMALL_INT_FITS(arg0)) {
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, arg0); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, arg0);
} }
} }
...@@ -337,7 +337,7 @@ STATIC mp_parse_node_t fold_constants(compiler_t *comp, mp_parse_node_t pn, mp_m ...@@ -337,7 +337,7 @@ STATIC mp_parse_node_t fold_constants(compiler_t *comp, mp_parse_node_t pn, mp_m
mp_load_method_maybe(elem->value, q_attr, dest); mp_load_method_maybe(elem->value, q_attr, dest);
if (MP_OBJ_IS_SMALL_INT(dest[0]) && dest[1] == NULL) { if (MP_OBJ_IS_SMALL_INT(dest[0]) && dest[1] == NULL) {
machine_int_t val = MP_OBJ_SMALL_INT_VALUE(dest[0]); machine_int_t val = MP_OBJ_SMALL_INT_VALUE(dest[0]);
if (MP_PARSE_FITS_SMALL_INT(val)) { if (MP_SMALL_INT_FITS(val)) {
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, val); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, val);
} }
} }
......
...@@ -72,10 +72,6 @@ typedef struct _mp_obj_base_t mp_obj_base_t; ...@@ -72,10 +72,6 @@ typedef struct _mp_obj_base_t mp_obj_base_t;
// These macros check for small int, qstr or object, and access small int and qstr values // These macros check for small int, qstr or object, and access small int and qstr values
// In SMALL_INT, next-to-highest bits is used as sign, so both must match for value in range
#define MP_SMALL_INT_MIN ((mp_small_int_t)(((machine_int_t)WORD_MSBIT_HIGH) >> 1))
#define MP_SMALL_INT_MAX ((mp_small_int_t)(~(MP_SMALL_INT_MIN)))
#define MP_OBJ_FITS_SMALL_INT(n) ((((n) ^ ((n) << 1)) & WORD_MSBIT_HIGH) == 0)
// these macros have now become inline functions; see below // these macros have now become inline functions; see below
//#define MP_OBJ_IS_SMALL_INT(o) ((((mp_small_int_t)(o)) & 1) != 0) //#define MP_OBJ_IS_SMALL_INT(o) ((((mp_small_int_t)(o)) & 1) != 0)
//#define MP_OBJ_IS_QSTR(o) ((((mp_small_int_t)(o)) & 3) == 2) //#define MP_OBJ_IS_QSTR(o) ((((mp_small_int_t)(o)) & 3) == 2)
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "qstr.h" #include "qstr.h"
#include "obj.h" #include "obj.h"
#include "parsenum.h" #include "parsenum.h"
#include "smallint.h"
#include "mpz.h" #include "mpz.h"
#include "objint.h" #include "objint.h"
#include "runtime0.h" #include "runtime0.h"
...@@ -251,7 +252,7 @@ mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value) { ...@@ -251,7 +252,7 @@ mp_obj_t mp_obj_new_int_from_uint(machine_uint_t value) {
} }
mp_obj_t mp_obj_new_int(machine_int_t value) { mp_obj_t mp_obj_new_int(machine_int_t value) {
if (MP_OBJ_FITS_SMALL_INT(value)) { if (MP_SMALL_INT_FITS(value)) {
return MP_OBJ_NEW_SMALL_INT(value); return MP_OBJ_NEW_SMALL_INT(value);
} }
nlr_raise(mp_obj_new_exception_msg(&mp_type_OverflowError, "small int overflow")); nlr_raise(mp_obj_new_exception_msg(&mp_type_OverflowError, "small int overflow"));
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "misc.h" #include "misc.h"
#include "qstr.h" #include "qstr.h"
#include "obj.h" #include "obj.h"
#include "smallint.h"
#include "mpz.h" #include "mpz.h"
#include "objint.h" #include "objint.h"
#include "runtime0.h" #include "runtime0.h"
...@@ -140,7 +141,7 @@ mp_obj_t mp_obj_int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { ...@@ -140,7 +141,7 @@ mp_obj_t mp_obj_int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
} }
mp_obj_t mp_obj_new_int(machine_int_t value) { mp_obj_t mp_obj_new_int(machine_int_t value) {
if (MP_OBJ_FITS_SMALL_INT(value)) { if (MP_SMALL_INT_FITS(value)) {
return MP_OBJ_NEW_SMALL_INT(value); return MP_OBJ_NEW_SMALL_INT(value);
} }
return mp_obj_new_int_from_ll(value); return mp_obj_new_int_from_ll(value);
......
...@@ -35,6 +35,7 @@ ...@@ -35,6 +35,7 @@
#include "qstr.h" #include "qstr.h"
#include "parsenumbase.h" #include "parsenumbase.h"
#include "obj.h" #include "obj.h"
#include "smallint.h"
#include "mpz.h" #include "mpz.h"
#include "objint.h" #include "objint.h"
#include "runtime0.h" #include "runtime0.h"
...@@ -239,7 +240,7 @@ mp_obj_t mp_obj_int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) { ...@@ -239,7 +240,7 @@ mp_obj_t mp_obj_int_binary_op(int op, mp_obj_t lhs_in, mp_obj_t rhs_in) {
} }
mp_obj_t mp_obj_new_int(machine_int_t value) { mp_obj_t mp_obj_new_int(machine_int_t value) {
if (MP_OBJ_FITS_SMALL_INT(value)) { if (MP_SMALL_INT_FITS(value)) {
return MP_OBJ_NEW_SMALL_INT(value); return MP_OBJ_NEW_SMALL_INT(value);
} }
return mp_obj_new_int_from_ll(value); return mp_obj_new_int_from_ll(value);
......
...@@ -36,6 +36,7 @@ ...@@ -36,6 +36,7 @@
#include "lexer.h" #include "lexer.h"
#include "parsenumbase.h" #include "parsenumbase.h"
#include "parse.h" #include "parse.h"
#include "smallint.h"
#define RULE_ACT_KIND_MASK (0xf0) #define RULE_ACT_KIND_MASK (0xf0)
#define RULE_ACT_ARG_MASK (0x0f) #define RULE_ACT_ARG_MASK (0x0f)
...@@ -311,13 +312,13 @@ STATIC void push_result_token(parser_t *parser, const mp_lexer_t *lex) { ...@@ -311,13 +312,13 @@ STATIC void push_result_token(parser_t *parser, const mp_lexer_t *lex) {
int i = mp_parse_num_base(str, len, &base); int i = mp_parse_num_base(str, len, &base);
bool overflow = false; bool overflow = false;
for (; i < len; i++) { for (; i < len; i++) {
machine_int_t old_val = int_val; int dig;
if (unichar_isdigit(str[i]) && str[i] - '0' < base) { if (unichar_isdigit(str[i]) && str[i] - '0' < base) {
int_val = base * int_val + str[i] - '0'; dig = str[i] - '0';
} else if (base == 16 && 'a' <= str[i] && str[i] <= 'f') { } else if (base == 16 && 'a' <= str[i] && str[i] <= 'f') {
int_val = base * int_val + str[i] - 'a' + 10; dig = str[i] - 'a' + 10;
} else if (base == 16 && 'A' <= str[i] && str[i] <= 'F') { } else if (base == 16 && 'A' <= str[i] && str[i] <= 'F') {
int_val = base * int_val + str[i] - 'A' + 10; dig = str[i] - 'A' + 10;
} else if (str[i] == '.' || str[i] == 'e' || str[i] == 'E' || str[i] == 'j' || str[i] == 'J') { } else if (str[i] == '.' || str[i] == 'e' || str[i] == 'E' || str[i] == 'j' || str[i] == 'J') {
dec = true; dec = true;
break; break;
...@@ -325,17 +326,18 @@ STATIC void push_result_token(parser_t *parser, const mp_lexer_t *lex) { ...@@ -325,17 +326,18 @@ STATIC void push_result_token(parser_t *parser, const mp_lexer_t *lex) {
small_int = false; small_int = false;
break; break;
} }
if (int_val < old_val) { // add next digit and check for overflow
// If new value became less than previous, it's overflow if (mp_small_int_mul_overflow(int_val, base)) {
overflow = true; overflow = true;
} else if ((old_val ^ int_val) & WORD_MSBIT_HIGH) { }
// If signed number changed sign - it's overflow int_val = int_val * base + dig;
if (!MP_SMALL_INT_FITS(int_val)) {
overflow = true; overflow = true;
} }
} }
if (dec) { if (dec) {
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_DECIMAL, qstr_from_strn(str, len)); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_DECIMAL, qstr_from_strn(str, len));
} else if (small_int && !overflow && MP_PARSE_FITS_SMALL_INT(int_val)) { } else if (small_int && !overflow && MP_SMALL_INT_FITS(int_val)) {
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, int_val); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_SMALL_INT, int_val);
} else { } else {
pn = mp_parse_node_new_leaf(MP_PARSE_NODE_INTEGER, qstr_from_strn(str, len)); pn = mp_parse_node_new_leaf(MP_PARSE_NODE_INTEGER, qstr_from_strn(str, len));
......
...@@ -37,13 +37,6 @@ struct _mp_lexer_t; ...@@ -37,13 +37,6 @@ struct _mp_lexer_t;
// - xx...x10010: a string of bytes; bits 5 and above are the qstr holding the value // - xx...x10010: a string of bytes; bits 5 and above are the qstr holding the value
// - xx...x10110: a token; bits 5 and above are mp_token_kind_t // - xx...x10110: a token; bits 5 and above are mp_token_kind_t
// TODO: these can now be unified with MP_OBJ_FITS_SMALL_INT(x)
// makes sure the top 2 bits of x are all cleared (positive number) or all set (negavite number)
// these macros can probably go somewhere else because they are used more than just in the parser
#define MP_UINT_HIGH_2_BITS (~((~((machine_uint_t)0)) >> 2))
// parser's small ints are different from VM small int
#define MP_PARSE_FITS_SMALL_INT(x) (((((machine_uint_t)(x)) & MP_UINT_HIGH_2_BITS) == 0) || ((((machine_uint_t)(x)) & MP_UINT_HIGH_2_BITS) == MP_UINT_HIGH_2_BITS))
#define MP_PARSE_NODE_NULL (0) #define MP_PARSE_NODE_NULL (0)
#define MP_PARSE_NODE_SMALL_INT (0x1) #define MP_PARSE_NODE_SMALL_INT (0x1)
#define MP_PARSE_NODE_ID (0x02) #define MP_PARSE_NODE_ID (0x02)
......
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "obj.h" #include "obj.h"
#include "parsenumbase.h" #include "parsenumbase.h"
#include "parsenum.h" #include "parsenum.h"
#include "smallint.h"
#if MICROPY_ENABLE_FLOAT #if MICROPY_ENABLE_FLOAT
#include <math.h> #include <math.h>
...@@ -70,16 +71,16 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base) { ...@@ -70,16 +71,16 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base) {
machine_int_t int_val = 0; machine_int_t int_val = 0;
const char *restrict str_val_start = str; const char *restrict str_val_start = str;
for (; str < top; str++) { for (; str < top; str++) {
machine_int_t old_val = int_val; // get next digit as a value
int dig = *str; int dig = *str;
if (unichar_isdigit(dig) && dig - '0' < base) { if (unichar_isdigit(dig) && dig - '0' < base) {
// 0-9 digit // 0-9 digit
int_val = base * int_val + dig - '0'; dig = dig - '0';
} else if (base == 16) { } else if (base == 16) {
dig |= 0x20; dig |= 0x20;
if ('a' <= dig && dig <= 'f') { if ('a' <= dig && dig <= 'f') {
// a-f hex digit // a-f hex digit
int_val = base * int_val + dig - 'a' + 10; dig = dig - 'a' + 10;
} else { } else {
// unknown character // unknown character
break; break;
...@@ -88,11 +89,13 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base) { ...@@ -88,11 +89,13 @@ mp_obj_t mp_parse_num_integer(const char *restrict str, uint len, int base) {
// unknown character // unknown character
break; break;
} }
if (int_val < old_val) {
// If new value became less than previous, it's overflow // add next digit and check for overflow
if (mp_small_int_mul_overflow(int_val, base)) {
goto overflow; goto overflow;
} else if ((old_val ^ int_val) & WORD_MSBIT_HIGH) { }
// If signed number changed sign - it's overflow int_val = int_val * base + dig;
if (!MP_SMALL_INT_FITS(int_val)) {
goto overflow; goto overflow;
} }
} }
......
...@@ -413,7 +413,7 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) { ...@@ -413,7 +413,7 @@ mp_obj_t mp_binary_op(int op, mp_obj_t lhs, mp_obj_t rhs) {
goto unsupported_op; goto unsupported_op;
} }
// TODO: We just should make mp_obj_new_int() inline and use that // TODO: We just should make mp_obj_new_int() inline and use that
if (MP_OBJ_FITS_SMALL_INT(lhs_val)) { if (MP_SMALL_INT_FITS(lhs_val)) {
return MP_OBJ_NEW_SMALL_INT(lhs_val); return MP_OBJ_NEW_SMALL_INT(lhs_val);
} else { } else {
return mp_obj_new_int(lhs_val); return mp_obj_new_int(lhs_val);
......
...@@ -28,6 +28,7 @@ ...@@ -28,6 +28,7 @@
#include "mpconfig.h" #include "mpconfig.h"
#include "qstr.h" #include "qstr.h"
#include "obj.h" #include "obj.h"
#include "smallint.h"
bool mp_small_int_mul_overflow(machine_int_t x, machine_int_t y) { bool mp_small_int_mul_overflow(machine_int_t x, machine_int_t y) {
// Check for multiply overflow; see CERT INT32-C // Check for multiply overflow; see CERT INT32-C
......
...@@ -26,6 +26,11 @@ ...@@ -26,6 +26,11 @@
// Functions for small integer arithmetic // Functions for small integer arithmetic
// In SMALL_INT, the next-to-highest bit is used as the sign, so the top two bits must match for a value to be in range
#define MP_SMALL_INT_MIN ((mp_small_int_t)(((machine_int_t)WORD_MSBIT_HIGH) >> 1))
#define MP_SMALL_INT_MAX ((mp_small_int_t)(~(MP_SMALL_INT_MIN)))
#define MP_SMALL_INT_FITS(n) ((((n) ^ ((n) << 1)) & WORD_MSBIT_HIGH) == 0)
bool mp_small_int_mul_overflow(machine_int_t x, machine_int_t y); bool mp_small_int_mul_overflow(machine_int_t x, machine_int_t y);
machine_int_t mp_small_int_modulo(machine_int_t dividend, machine_int_t divisor); machine_int_t mp_small_int_modulo(machine_int_t dividend, machine_int_t divisor);
machine_int_t mp_small_int_floor_divide(machine_int_t num, machine_int_t denom); machine_int_t mp_small_int_floor_divide(machine_int_t num, machine_int_t denom);
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment