diff --git a/src/yyjson.c b/src/yyjson.c index 6fd6f72..51025ab 100644 --- a/src/yyjson.c +++ b/src/yyjson.c @@ -2137,6 +2137,11 @@ static_inline bool size_add_is_overflow(usize size, usize add) { return size > (size + add); } +/** Returns whether the size is overflow after multiplication. */ +static_inline bool size_mul_is_overflow(usize size, usize mul) { + return mul != 0 && size > USIZE_MAX / mul; +} + /** Returns whether the size is power of 2 (size should not be 0). */ static_inline bool size_is_pow2(usize size) { return (size & (size - 1)) == 0; @@ -2804,62 +2809,103 @@ yyjson_mut_val *yyjson_mut_val_mut_copy(yyjson_mut_doc *doc, return NULL; } -/* Count the number of values and the total length of the strings. */ -static void yyjson_mut_stat(const yyjson_mut_val *val, +/* Count value and string pool sizes; return false on overflow. */ +static bool yyjson_mut_stat(const yyjson_mut_val *val, usize *val_sum, usize *str_sum) { - yyjson_type type = unsafe_yyjson_get_type(val); + yyjson_type type; + + type = unsafe_yyjson_get_type(val); + if (unlikely(size_add_is_overflow(*val_sum, 1))) return false; *val_sum += 1; + if (type == YYJSON_TYPE_ARR || type == YYJSON_TYPE_OBJ) { yyjson_mut_val *child = (yyjson_mut_val *)val->uni.ptr; usize len = unsafe_yyjson_get_len(val), i; - len <<= (u8)(type == YYJSON_TYPE_OBJ); + + if (type == YYJSON_TYPE_OBJ) { + if (unlikely(len > USIZE_MAX / 2)) return false; + len <<= 1; + } + + if (unlikely(size_add_is_overflow(*val_sum, len))) return false; *val_sum += len; + for (i = 0; i < len; i++) { yyjson_type stype = unsafe_yyjson_get_type(child); if (stype == YYJSON_TYPE_STR || stype == YYJSON_TYPE_RAW) { - *str_sum += unsafe_yyjson_get_len(child) + 1; + usize str_len = unsafe_yyjson_get_len(child); + if (unlikely(size_add_is_overflow(str_len, 1) || + size_add_is_overflow(*str_sum, str_len + 1))) { + return false; + } + *str_sum += str_len + 1; } else if (stype == YYJSON_TYPE_ARR || stype == YYJSON_TYPE_OBJ) { - yyjson_mut_stat(child, val_sum, str_sum); + if (unlikely(!yyjson_mut_stat(child, val_sum, str_sum))) { + return false; + } *val_sum -= 1; } child = child->next; } } else if (type == YYJSON_TYPE_STR || type == YYJSON_TYPE_RAW) { - *str_sum += unsafe_yyjson_get_len(val) + 1; + usize len = unsafe_yyjson_get_len(val); + if (unlikely(size_add_is_overflow(len, 1) || + size_add_is_overflow(*str_sum, len + 1))) { + return false; + } + *str_sum += len + 1; } + + return true; } /* Copy mutable values to immutable value pool. */ static usize yyjson_imut_copy(yyjson_val **val_ptr, char **buf_ptr, const yyjson_mut_val *mval) { yyjson_val *val = *val_ptr; - yyjson_type type = unsafe_yyjson_get_type(mval); + yyjson_type type; + + type = unsafe_yyjson_get_type(mval); + if (type == YYJSON_TYPE_ARR || type == YYJSON_TYPE_OBJ) { yyjson_mut_val *child = (yyjson_mut_val *)mval->uni.ptr; usize len = unsafe_yyjson_get_len(mval), i; usize val_sum = 1; + if (type == YYJSON_TYPE_OBJ) { + if (unlikely(len > USIZE_MAX / 2)) return 0; if (len) child = child->next->next; len <<= 1; } else { if (len) child = child->next; } + *val_ptr = val + 1; for (i = 0; i < len; i++) { - val_sum += yyjson_imut_copy(val_ptr, buf_ptr, child); + usize child_sum = yyjson_imut_copy(val_ptr, buf_ptr, child); + if (unlikely(child_sum == 0 || + size_add_is_overflow(val_sum, child_sum))) { + return 0; + } + val_sum += child_sum; child = child->next; } val->tag = mval->tag; + if (unlikely(size_mul_is_overflow(val_sum, sizeof(yyjson_val)))) { + return 0; + } val->uni.ofs = val_sum * sizeof(yyjson_val); return val_sum; } else if (type == YYJSON_TYPE_STR || type == YYJSON_TYPE_RAW) { char *buf = *buf_ptr; usize len = unsafe_yyjson_get_len(mval); + memcpy((void *)buf, (const void *)mval->uni.str, len); buf[len] = '\0'; val->tag = mval->tag; val->uni.str = buf; *val_ptr = val + 1; + if (unlikely(size_add_is_overflow(len, 1))) return 0; *buf_ptr = buf + len + 1; return 1; } else { @@ -2878,7 +2924,7 @@ yyjson_doc *yyjson_mut_doc_imut_copy(const yyjson_mut_doc *mdoc, yyjson_doc *yyjson_mut_val_imut_copy(const yyjson_mut_val *mval, const yyjson_alc *alc) { - usize val_num = 0, str_sum = 0, hdr_size, buf_size; + usize val_num = 0, str_sum = 0, hdr_size, val_size, buf_size; yyjson_doc *doc = NULL; yyjson_val *val_hdr = NULL; @@ -2889,11 +2935,15 @@ yyjson_doc *yyjson_mut_val_imut_copy(const yyjson_mut_val *mval, if (!alc) alc = &YYJSON_DEFAULT_ALC; /* traverse the input value to get pool size */ - yyjson_mut_stat(mval, &val_num, &str_sum); + if (unlikely(!yyjson_mut_stat(mval, &val_num, &str_sum))) return NULL; + if (unlikely(size_add_is_overflow(str_sum, 1))) return NULL; /* create doc and val pool */ hdr_size = size_align_up(sizeof(yyjson_doc), sizeof(yyjson_val)); - buf_size = hdr_size + val_num * sizeof(yyjson_val); + if (unlikely(size_mul_is_overflow(val_num, sizeof(yyjson_val)))) return NULL; + val_size = val_num * sizeof(yyjson_val); + if (unlikely(size_add_is_overflow(hdr_size, val_size))) return NULL; + buf_size = hdr_size + val_size; doc = (yyjson_doc *)alc->malloc(alc->ctx, buf_size); if (!doc) return NULL; memset(doc, 0, sizeof(yyjson_doc)); @@ -2913,6 +2963,11 @@ yyjson_doc *yyjson_mut_val_imut_copy(const yyjson_mut_val *mval, /* copy vals and strs */ doc->val_read = yyjson_imut_copy(&val_hdr, &str_hdr, mval); + if (unlikely(doc->val_read == 0)) { + if (doc->str_pool) alc->free(alc->ctx, doc->str_pool); + alc->free(alc->ctx, (void *)doc); + return NULL; + } doc->dat_read = str_sum + 1; return doc; }