Commit 07241cd3 authored by Paul Sokolovsky's avatar Paul Sokolovsky
Browse files

py/objstringio: If created from immutable object, follow copy on write policy.

Don't create copy of immutable object's contents until .write() is called
on BytesIO.
parent b24ccfc6
...@@ -68,10 +68,23 @@ STATIC mp_uint_t stringio_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *er ...@@ -68,10 +68,23 @@ STATIC mp_uint_t stringio_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *er
return size; return size;
} }
STATIC void stringio_copy_on_write(mp_obj_stringio_t *o) {
const void *buf = o->vstr->buf;
o->vstr->buf = m_new(char, o->vstr->len);
memcpy(o->vstr->buf, buf, o->vstr->len);
o->vstr->fixed_buf = false;
o->ref_obj = MP_OBJ_NULL;
}
STATIC mp_uint_t stringio_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) { STATIC mp_uint_t stringio_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
(void)errcode; (void)errcode;
mp_obj_stringio_t *o = MP_OBJ_TO_PTR(o_in); mp_obj_stringio_t *o = MP_OBJ_TO_PTR(o_in);
check_stringio_is_open(o); check_stringio_is_open(o);
if (o->vstr->fixed_buf) {
stringio_copy_on_write(o);
}
mp_uint_t new_pos = o->pos + size; mp_uint_t new_pos = o->pos + size;
if (new_pos < size) { if (new_pos < size) {
// Writing <size> bytes will overflow o->pos beyond limit of mp_uint_t. // Writing <size> bytes will overflow o->pos beyond limit of mp_uint_t.
...@@ -155,11 +168,11 @@ STATIC mp_obj_t stringio___exit__(size_t n_args, const mp_obj_t *args) { ...@@ -155,11 +168,11 @@ STATIC mp_obj_t stringio___exit__(size_t n_args, const mp_obj_t *args) {
} }
STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(stringio___exit___obj, 4, 4, stringio___exit__); STATIC MP_DEFINE_CONST_FUN_OBJ_VAR_BETWEEN(stringio___exit___obj, 4, 4, stringio___exit__);
STATIC mp_obj_stringio_t *stringio_new(const mp_obj_type_t *type, mp_uint_t alloc) { STATIC mp_obj_stringio_t *stringio_new(const mp_obj_type_t *type) {
mp_obj_stringio_t *o = m_new_obj(mp_obj_stringio_t); mp_obj_stringio_t *o = m_new_obj(mp_obj_stringio_t);
o->base.type = type; o->base.type = type;
o->vstr = vstr_new(alloc);
o->pos = 0; o->pos = 0;
o->ref_obj = MP_OBJ_NULL;
return o; return o;
} }
...@@ -170,17 +183,28 @@ STATIC mp_obj_t stringio_make_new(const mp_obj_type_t *type_in, size_t n_args, s ...@@ -170,17 +183,28 @@ STATIC mp_obj_t stringio_make_new(const mp_obj_type_t *type_in, size_t n_args, s
bool initdata = false; bool initdata = false;
mp_buffer_info_t bufinfo; mp_buffer_info_t bufinfo;
mp_obj_stringio_t *o = stringio_new(type_in);
if (n_args > 0) { if (n_args > 0) {
if (MP_OBJ_IS_INT(args[0])) { if (MP_OBJ_IS_INT(args[0])) {
sz = mp_obj_get_int(args[0]); sz = mp_obj_get_int(args[0]);
} else { } else {
mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ); mp_get_buffer_raise(args[0], &bufinfo, MP_BUFFER_READ);
if (MP_OBJ_IS_STR_OR_BYTES(args[0])) {
o->vstr = m_new_obj(vstr_t);
vstr_init_fixed_buf(o->vstr, bufinfo.len, bufinfo.buf);
o->vstr->len = bufinfo.len;
o->ref_obj = args[0];
return MP_OBJ_FROM_PTR(o);
}
sz = bufinfo.len; sz = bufinfo.len;
initdata = true; initdata = true;
} }
} }
mp_obj_stringio_t *o = stringio_new(type_in, sz); o->vstr = vstr_new(sz);
if (initdata) { if (initdata) {
stringio_write(MP_OBJ_FROM_PTR(o), bufinfo.buf, bufinfo.len, NULL); stringio_write(MP_OBJ_FROM_PTR(o), bufinfo.buf, bufinfo.len, NULL);
......
...@@ -33,6 +33,8 @@ typedef struct _mp_obj_stringio_t { ...@@ -33,6 +33,8 @@ typedef struct _mp_obj_stringio_t {
vstr_t *vstr; vstr_t *vstr;
// StringIO has single pointer used for both reading and writing // StringIO has single pointer used for both reading and writing
mp_uint_t pos; mp_uint_t pos;
// Underlying object buffered by this StringIO
mp_obj_t ref_obj;
} mp_obj_stringio_t; } mp_obj_stringio_t;
#endif // MICROPY_INCLUDED_PY_OBJSTRINGIO_H #endif // MICROPY_INCLUDED_PY_OBJSTRINGIO_H
# Make sure that write operations on io.BytesIO don't
# change original object it was constructed from.
try:
import uio as io
except ImportError:
import io
b = b"foobar"
a = io.BytesIO(b)
a.write(b"1")
print(b)
print(a.getvalue())
b = bytearray(b"foobar")
a = io.BytesIO(b)
a.write(b"1")
print(b)
print(a.getvalue())
# Creating BytesIO from immutable object should not immediately
# copy its content.
try:
import uio
import micropython
micropython.mem_total
except (ImportError, AttributeError):
print("SKIP")
raise SystemExit
data = b"1234" * 256
before = micropython.mem_total()
buf = uio.BytesIO(data)
after = micropython.mem_total()
print(after - before < len(data))
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment