/*
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#include <limits.h>
#include <string.h>

#include "config.h"

#include "libavutil/avassert.h"
#include "libavutil/buffer.h"
#include "libavutil/common.h"

#include "cbs.h"
#include "cbs_internal.h"


/**
 * Every coded bitstream type compiled into this build.
 *
 * ff_cbs_init() scans this table for the entry whose codec_id matches the
 * requested codec.  Each entry is enabled by its CONFIG_CBS_* switch.
 */
static const CodedBitstreamType *const cbs_type_table[] = {
#if CONFIG_CBS_AV1
    &ff_cbs_type_av1,
#endif
#if CONFIG_CBS_H264
    &ff_cbs_type_h264,
#endif
#if CONFIG_CBS_H265
    &ff_cbs_type_h265,
#endif
#if CONFIG_CBS_JPEG
    &ff_cbs_type_jpeg,
#endif
#if CONFIG_CBS_MPEG2
    &ff_cbs_type_mpeg2,
#endif
#if CONFIG_CBS_VP9
    &ff_cbs_type_vp9,
#endif
};

52
const enum AVCodecID ff_cbs_all_codec_ids[] = {
53 54 55
#if CONFIG_CBS_AV1
    AV_CODEC_ID_AV1,
#endif
56 57 58 59 60 61
#if CONFIG_CBS_H264
    AV_CODEC_ID_H264,
#endif
#if CONFIG_CBS_H265
    AV_CODEC_ID_H265,
#endif
62 63 64
#if CONFIG_CBS_JPEG
    AV_CODEC_ID_MJPEG,
#endif
65 66
#if CONFIG_CBS_MPEG2
    AV_CODEC_ID_MPEG2VIDEO,
67 68 69
#endif
#if CONFIG_CBS_VP9
    AV_CODEC_ID_VP9,
70 71 72 73
#endif
    AV_CODEC_ID_NONE
};

74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
int ff_cbs_init(CodedBitstreamContext **ctx_ptr,
                enum AVCodecID codec_id, void *log_ctx)
{
    CodedBitstreamContext *ctx;
    const CodedBitstreamType *type;
    int i;

    type = NULL;
    for (i = 0; i < FF_ARRAY_ELEMS(cbs_type_table); i++) {
        if (cbs_type_table[i]->codec_id == codec_id) {
            type = cbs_type_table[i];
            break;
        }
    }
    if (!type)
        return AVERROR(EINVAL);

    ctx = av_mallocz(sizeof(*ctx));
    if (!ctx)
        return AVERROR(ENOMEM);

    ctx->log_ctx = log_ctx;
    ctx->codec   = type;

98 99 100 101 102 103
    if (type->priv_data_size) {
        ctx->priv_data = av_mallocz(ctx->codec->priv_data_size);
        if (!ctx->priv_data) {
            av_freep(&ctx);
            return AVERROR(ENOMEM);
        }
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124
    }

    ctx->decompose_unit_types = NULL;

    ctx->trace_enable = 0;
    ctx->trace_level  = AV_LOG_TRACE;

    *ctx_ptr = ctx;
    return 0;
}

/**
 * Free a context created by ff_cbs_init() and set *ctx_ptr to NULL.
 *
 * Runs the codec's close callback (if any) before releasing the shared
 * write buffer and the codec private data.  Does nothing if *ctx_ptr is NULL.
 */
void ff_cbs_close(CodedBitstreamContext **ctx_ptr)
{
    CodedBitstreamContext *ctx = *ctx_ptr;

    if (!ctx)
        return;

    if (ctx->codec && ctx->codec->close)
        ctx->codec->close(ctx);

    av_freep(&ctx->write_buffer);
    av_freep(&ctx->priv_data);
    av_freep(ctx_ptr);
}

/**
 * Release a unit's references to its decomposed content and its raw data,
 * and clear the associated fields.  ctx is not used.
 */
static void cbs_unit_uninit(CodedBitstreamContext *ctx,
                            CodedBitstreamUnit *unit)
{
    av_buffer_unref(&unit->content_ref);
    unit->content = NULL;

    av_buffer_unref(&unit->data_ref);
    unit->data             = NULL;
    unit->data_size        = 0;
    unit->data_bit_padding = 0;
}

142 143
void ff_cbs_fragment_reset(CodedBitstreamContext *ctx,
                           CodedBitstreamFragment *frag)
144 145 146 147 148 149 150
{
    int i;

    for (i = 0; i < frag->nb_units; i++)
        cbs_unit_uninit(ctx, &frag->units[i]);
    frag->nb_units = 0;

151 152
    av_buffer_unref(&frag->data_ref);
    frag->data             = NULL;
153 154 155 156
    frag->data_size        = 0;
    frag->data_bit_padding = 0;
}

157 158 159 160 161 162 163 164 165
void ff_cbs_fragment_free(CodedBitstreamContext *ctx,
                          CodedBitstreamFragment *frag)
{
    ff_cbs_fragment_reset(ctx, frag);

    av_freep(&frag->units);
    frag->nb_units_allocated = 0;
}

166 167 168 169 170 171
static int cbs_read_fragment_content(CodedBitstreamContext *ctx,
                                     CodedBitstreamFragment *frag)
{
    int err, i, j;

    for (i = 0; i < frag->nb_units; i++) {
172 173
        CodedBitstreamUnit *unit = &frag->units[i];

174 175
        if (ctx->decompose_unit_types) {
            for (j = 0; j < ctx->nb_decompose_unit_types; j++) {
176
                if (ctx->decompose_unit_types[j] == unit->type)
177 178 179 180 181 182
                    break;
            }
            if (j >= ctx->nb_decompose_unit_types)
                continue;
        }

183 184 185 186
        av_buffer_unref(&unit->content_ref);
        unit->content = NULL;

        av_assert0(unit->data && unit->data_ref);
187

188
        err = ctx->codec->read_unit(ctx, unit);
189
        if (err == AVERROR(ENOSYS)) {
190
            av_log(ctx->log_ctx, AV_LOG_VERBOSE,
191
                   "Decomposition unimplemented for unit %d "
192
                   "(type %"PRIu32").\n", i, unit->type);
193 194
        } else if (err < 0) {
            av_log(ctx->log_ctx, AV_LOG_ERROR, "Failed to read unit %d "
195
                   "(type %"PRIu32").\n", i, unit->type);
196 197 198 199 200 201 202
            return err;
        }
    }

    return 0;
}

203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223
static int cbs_fill_fragment_data(CodedBitstreamContext *ctx,
                                  CodedBitstreamFragment *frag,
                                  const uint8_t *data, size_t size)
{
    av_assert0(!frag->data && !frag->data_ref);

    frag->data_ref =
        av_buffer_alloc(size + AV_INPUT_BUFFER_PADDING_SIZE);
    if (!frag->data_ref)
        return AVERROR(ENOMEM);

    frag->data      = frag->data_ref->data;
    frag->data_size = size;

    memcpy(frag->data, data, size);
    memset(frag->data + size, 0,
           AV_INPUT_BUFFER_PADDING_SIZE);

    return 0;
}

224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241
int ff_cbs_read_extradata(CodedBitstreamContext *ctx,
                          CodedBitstreamFragment *frag,
                          const AVCodecParameters *par)
{
    int err;

    err = cbs_fill_fragment_data(ctx, frag, par->extradata,
                                 par->extradata_size);
    if (err < 0)
        return err;

    err = ctx->codec->split_fragment(ctx, frag, 1);
    if (err < 0)
        return err;

    return cbs_read_fragment_content(ctx, frag);
}

242 243 244 245 246 247
int ff_cbs_read_packet(CodedBitstreamContext *ctx,
                       CodedBitstreamFragment *frag,
                       const AVPacket *pkt)
{
    int err;

248 249 250 251 252 253 254 255 256 257 258 259 260
    if (pkt->buf) {
        frag->data_ref = av_buffer_ref(pkt->buf);
        if (!frag->data_ref)
            return AVERROR(ENOMEM);

        frag->data      = pkt->data;
        frag->data_size = pkt->size;

    } else {
        err = cbs_fill_fragment_data(ctx, frag, pkt->data, pkt->size);
        if (err < 0)
            return err;
    }
261 262 263 264 265 266 267 268 269 270 271 272 273 274

    err = ctx->codec->split_fragment(ctx, frag, 0);
    if (err < 0)
        return err;

    return cbs_read_fragment_content(ctx, frag);
}

/**
 * Read a raw buffer into frag: copy it, split it into units and decompose
 * them.
 *
 * @return 0 on success or a negative AVERROR code
 */
int ff_cbs_read(CodedBitstreamContext *ctx,
                CodedBitstreamFragment *frag,
                const uint8_t *data, size_t size)
{
    int err;

    err = cbs_fill_fragment_data(ctx, frag, data, size);
    if (err < 0)
        return err;

    err = ctx->codec->split_fragment(ctx, frag, 0);
    if (err < 0)
        return err;

    return cbs_read_fragment_content(ctx, frag);
}

286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311
static int cbs_write_unit_data(CodedBitstreamContext *ctx,
                               CodedBitstreamUnit *unit)
{
    PutBitContext pbc;
    int ret;

    if (!ctx->write_buffer) {
        // Initial write buffer size is 1MB.
        ctx->write_buffer_size = 1024 * 1024;

    reallocate_and_try_again:
        ret = av_reallocp(&ctx->write_buffer, ctx->write_buffer_size);
        if (ret < 0) {
            av_log(ctx->log_ctx, AV_LOG_ERROR, "Unable to allocate a "
                   "sufficiently large write buffer (last attempt "
                   "%"SIZE_SPECIFIER" bytes).\n", ctx->write_buffer_size);
            return ret;
        }
    }

    init_put_bits(&pbc, ctx->write_buffer, ctx->write_buffer_size);

    ret = ctx->codec->write_unit(ctx, unit, &pbc);
    if (ret < 0) {
        if (ret == AVERROR(ENOSPC)) {
            // Overflow.
312 313 314
            if (ctx->write_buffer_size == INT_MAX / 8)
                return AVERROR(ENOMEM);
            ctx->write_buffer_size = FFMIN(2 * ctx->write_buffer_size, INT_MAX / 8);
315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338
            goto reallocate_and_try_again;
        }
        // Write failed for some other reason.
        return ret;
    }

    // Overflow but we didn't notice.
    av_assert0(put_bits_count(&pbc) <= 8 * ctx->write_buffer_size);

    if (put_bits_count(&pbc) % 8)
        unit->data_bit_padding = 8 - put_bits_count(&pbc) % 8;
    else
        unit->data_bit_padding = 0;

    flush_put_bits(&pbc);

    ret = ff_cbs_alloc_unit_data(ctx, unit, put_bits_count(&pbc) / 8);
    if (ret < 0)
        return ret;

    memcpy(unit->data, ctx->write_buffer, unit->data_size);

    return 0;
}
339 340 341 342 343 344 345

int ff_cbs_write_fragment_data(CodedBitstreamContext *ctx,
                               CodedBitstreamFragment *frag)
{
    int err, i;

    for (i = 0; i < frag->nb_units; i++) {
346 347 348
        CodedBitstreamUnit *unit = &frag->units[i];

        if (!unit->content)
349 350
            continue;

351 352 353
        av_buffer_unref(&unit->data_ref);
        unit->data = NULL;

354
        err = cbs_write_unit_data(ctx, unit);
355 356
        if (err < 0) {
            av_log(ctx->log_ctx, AV_LOG_ERROR, "Failed to write unit %d "
357
                   "(type %"PRIu32").\n", i, unit->type);
358 359
            return err;
        }
360
        av_assert0(unit->data && unit->data_ref);
361 362
    }

363 364 365
    av_buffer_unref(&frag->data_ref);
    frag->data = NULL;

366 367 368 369 370
    err = ctx->codec->assemble_fragment(ctx, frag);
    if (err < 0) {
        av_log(ctx->log_ctx, AV_LOG_ERROR, "Failed to assemble fragment.\n");
        return err;
    }
371
    av_assert0(frag->data && frag->data_ref);
372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404

    return 0;
}

/**
 * Write a fragment and store a padded copy of the result as par->extradata,
 * replacing any previous extradata.
 *
 * On allocation failure par->extradata is NULL and par->extradata_size is 0.
 *
 * @return 0 on success or a negative AVERROR code
 */
int ff_cbs_write_extradata(CodedBitstreamContext *ctx,
                           AVCodecParameters *par,
                           CodedBitstreamFragment *frag)
{
    int err;

    err = ff_cbs_write_fragment_data(ctx, frag);
    if (err < 0)
        return err;

    av_freep(&par->extradata);
    // Keep the size consistent with the freed pointer.  Otherwise a failed
    // allocation below would leave a NULL buffer with a stale non-zero size.
    par->extradata_size = 0;

    par->extradata = av_malloc(frag->data_size +
                               AV_INPUT_BUFFER_PADDING_SIZE);
    if (!par->extradata)
        return AVERROR(ENOMEM);

    memcpy(par->extradata, frag->data, frag->data_size);
    memset(par->extradata + frag->data_size, 0,
           AV_INPUT_BUFFER_PADDING_SIZE);
    par->extradata_size = frag->data_size;

    return 0;
}

/**
 * Write a fragment and point pkt at the result.
 *
 * pkt takes a new reference to the fragment's data buffer, and its previous
 * buffer reference is released.  Only pkt->buf, pkt->data and pkt->size
 * are modified.
 *
 * @return 0 on success or a negative AVERROR code
 */
int ff_cbs_write_packet(CodedBitstreamContext *ctx,
                        AVPacket *pkt,
                        CodedBitstreamFragment *frag)
{
    AVBufferRef *buf;
    int err;

    err = ff_cbs_write_fragment_data(ctx, frag);
    if (err < 0)
        return err;

    buf = av_buffer_ref(frag->data_ref);
    if (!buf)
        return AVERROR(ENOMEM);

    av_buffer_unref(&pkt->buf);

    pkt->buf  = buf;
    pkt->data = frag->data;
    pkt->size = frag->data_size;

    return 0;
}


/**
 * Print a section name on its own line at ctx->trace_level.
 * Nothing is logged while tracing is disabled.
 */
void ff_cbs_trace_header(CodedBitstreamContext *ctx,
                         const char *name)
{
    if (ctx->trace_enable)
        av_log(ctx->log_ctx, ctx->trace_level, "%s\n", name);
}

/**
 * Log one syntax element at ctx->trace_level, if tracing is enabled.
 *
 * @param position   bit position to print in the first column
 * @param str        element name.  While subscript values remain, each
 *                   "[...]" group is replaced by the next value in brackets;
 *                   any later groups are copied verbatim.
 * @param subscripts NULL, or an array whose first entry is the number of
 *                   subscript values that follow it
 * @param bits       textual bit pattern printed alongside the name
 * @param value      value to print; must lie in [INT_MIN, UINT32_MAX]
 */
void ff_cbs_trace_syntax_element(CodedBitstreamContext *ctx, int position,
                                 const char *str, const int *subscripts,
                                 const char *bits, int64_t value)
{
    char name[256];
    size_t name_len, bits_len;
    int pad, subs, i, j, k, n;

    if (!ctx->trace_enable)
        return;

    av_assert0(value >= INT_MIN && value <= UINT32_MAX);

    subs = subscripts ? subscripts[0] : 0;
    n = 0;
    for (i = j = 0; str[i];) {
        if (str[i] == '[') {
            if (n < subs) {
                ++n;
                k = snprintf(name + j, sizeof(name) - j, "[%d", subscripts[n]);
                av_assert0(k > 0 && j + k < sizeof(name));
                j += k;
                // Skip the placeholder text.  The closing ']' is copied by
                // the plain-character branch on the next iteration.
                for (++i; str[i] && str[i] != ']'; i++);
                av_assert0(str[i] == ']');
            } else {
                while (str[i] && str[i] != ']') {
                    // Bound this copy like every other write into name[].
                    av_assert0(j + 1 < sizeof(name));
                    name[j++] = str[i++];
                }
                av_assert0(str[i] == ']');
            }
        } else {
            av_assert0(j + 1 < sizeof(name));
            name[j++] = str[i++];
        }
    }
    av_assert0(j + 1 < sizeof(name));
    name[j] = 0;
    av_assert0(n == subs);

    name_len = strlen(name);
    bits_len = strlen(bits);

    // Right-align the bit string at a fixed column unless the name is too
    // long, in which case just leave two spaces before it.
    if (name_len + bits_len > 60)
        pad = bits_len + 2;
    else
        pad = 61 - name_len;

    av_log(ctx->log_ctx, ctx->trace_level, "%-10d  %s%*s = %"PRId64"\n",
           position, name, pad, bits, value);
}

/**
 * Read an unsigned value of width bits (1..32) and check it against
 * [range_min, range_max].
 *
 * The element is traced before the range check, so an out-of-range value
 * is still logged.  *write_to is only written on success.
 *
 * @return 0 on success, AVERROR_INVALIDDATA if the bitstream ends early or
 *         the value is out of range
 */
int ff_cbs_read_unsigned(CodedBitstreamContext *ctx, GetBitContext *gbc,
                         int width, const char *name,
                         const int *subscripts, uint32_t *write_to,
                         uint32_t range_min, uint32_t range_max)
{
    uint32_t value;
    // Only meaningful when tracing.  Initialised so the compiler cannot
    // flag a maybe-uninitialized read.
    int position = 0;

    av_assert0(width > 0 && width <= 32);

    if (get_bits_left(gbc) < width) {
        av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid value at "
               "%s: bitstream ended.\n", name);
        return AVERROR_INVALIDDATA;
    }

    if (ctx->trace_enable)
        position = get_bits_count(gbc);

    value = get_bits_long(gbc, width);

    if (ctx->trace_enable) {
        char bits[33];
        int i;
        for (i = 0; i < width; i++)
            bits[i] = value >> (width - i - 1) & 1 ? '1' : '0';
        bits[i] = 0;

        ff_cbs_trace_syntax_element(ctx, position, name, subscripts,
                                    bits, value);
    }

    if (value < range_min || value > range_max) {
        av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
               "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
               name, value, range_min, range_max);
        return AVERROR_INVALIDDATA;
    }

    *write_to = value;
    return 0;
}

/**
 * Write an unsigned value of width bits (1..32) after checking it against
 * [range_min, range_max].
 *
 * @return 0 on success, AVERROR_INVALIDDATA if the value is out of range,
 *         AVERROR(ENOSPC) if fewer than width bits are left in pbc
 */
int ff_cbs_write_unsigned(CodedBitstreamContext *ctx, PutBitContext *pbc,
                          int width, const char *name,
                          const int *subscripts, uint32_t value,
                          uint32_t range_min, uint32_t range_max)
{
    av_assert0(width > 0 && width <= 32);

    if (value < range_min || value > range_max) {
        av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
               "%"PRIu32", but must be in [%"PRIu32",%"PRIu32"].\n",
               name, value, range_min, range_max);
        return AVERROR_INVALIDDATA;
    }

    // ENOSPC makes cbs_write_unit_data() retry with a larger buffer.
    if (put_bits_left(pbc) < width)
        return AVERROR(ENOSPC);

    if (ctx->trace_enable) {
        char bits[33];
        int i;
        for (i = 0; i < width; i++)
            bits[i] = value >> (width - i - 1) & 1 ? '1' : '0';
        bits[i] = 0;

        ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
                                    name, subscripts, bits, value);
    }

    // 32-bit values take the dedicated put_bits32() path.
    if (width < 32)
        put_bits(pbc, width, value);
    else
        put_bits32(pbc, value);

    return 0;
}

564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642
int ff_cbs_read_signed(CodedBitstreamContext *ctx, GetBitContext *gbc,
                       int width, const char *name,
                       const int *subscripts, int32_t *write_to,
                       int32_t range_min, int32_t range_max)
{
    int32_t value;
    int position;

    av_assert0(width > 0 && width <= 32);

    if (get_bits_left(gbc) < width) {
        av_log(ctx->log_ctx, AV_LOG_ERROR, "Invalid value at "
               "%s: bitstream ended.\n", name);
        return AVERROR_INVALIDDATA;
    }

    if (ctx->trace_enable)
        position = get_bits_count(gbc);

    value = get_sbits_long(gbc, width);

    if (ctx->trace_enable) {
        char bits[33];
        int i;
        for (i = 0; i < width; i++)
            bits[i] = value & (1U << (width - i - 1)) ? '1' : '0';
        bits[i] = 0;

        ff_cbs_trace_syntax_element(ctx, position, name, subscripts,
                                    bits, value);
    }

    if (value < range_min || value > range_max) {
        av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
               "%"PRId32", but must be in [%"PRId32",%"PRId32"].\n",
               name, value, range_min, range_max);
        return AVERROR_INVALIDDATA;
    }

    *write_to = value;
    return 0;
}

/**
 * Emit a two's complement value of 1..32 bits, provided it falls within
 * [range_min, range_max].
 *
 * @return 0 when written, AVERROR_INVALIDDATA for an out-of-range value,
 *         AVERROR(ENOSPC) when pbc cannot hold another width bits
 */
int ff_cbs_write_signed(CodedBitstreamContext *ctx, PutBitContext *pbc,
                        int width, const char *name,
                        const int *subscripts, int32_t value,
                        int32_t range_min, int32_t range_max)
{
    av_assert0(width > 0 && width <= 32);

    if (!(value >= range_min && value <= range_max)) {
        av_log(ctx->log_ctx, AV_LOG_ERROR, "%s out of range: "
               "%"PRId32", but must be in [%"PRId32",%"PRId32"].\n",
               name, value, range_min, range_max);
        return AVERROR_INVALIDDATA;
    }

    if (put_bits_left(pbc) < width)
        return AVERROR(ENOSPC);

    if (ctx->trace_enable) {
        char pattern[33];
        int bit;

        // Most significant bit first.
        for (bit = 0; bit < width; bit++) {
            uint32_t mask = 1U << (width - 1 - bit);
            pattern[bit] = (value & mask) ? '1' : '0';
        }
        pattern[width] = 0;

        ff_cbs_trace_syntax_element(ctx, put_bits_count(pbc),
                                    name, subscripts, pattern, value);
    }

    if (width == 32)
        put_bits32(pbc, value);
    else
        put_sbits(pbc, width, value);

    return 0;
}

643

644 645 646 647 648 649 650 651 652 653 654 655
int ff_cbs_alloc_unit_content(CodedBitstreamContext *ctx,
                              CodedBitstreamUnit *unit,
                              size_t size,
                              void (*free)(void *opaque, uint8_t *data))
{
    av_assert0(!unit->content && !unit->content_ref);

    unit->content = av_mallocz(size);
    if (!unit->content)
        return AVERROR(ENOMEM);

    unit->content_ref = av_buffer_create(unit->content, size,
656
                                         free, NULL, 0);
657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682
    if (!unit->content_ref) {
        av_freep(&unit->content);
        return AVERROR(ENOMEM);
    }

    return 0;
}

/**
 * Allocate size bytes of unit data, followed by zeroed padding, for a unit
 * that has no data yet.  The first size bytes are not initialised.
 *
 * @return 0 on success, AVERROR(ENOMEM) if allocation fails or size plus
 *         padding does not fit in an int
 */
int ff_cbs_alloc_unit_data(CodedBitstreamContext *ctx,
                           CodedBitstreamUnit *unit,
                           size_t size)
{
    av_assert0(!unit->data && !unit->data_ref);

    // Keep size + padding within INT_MAX.  An int-sized allocator argument
    // would otherwise be truncated, and writes of size bytes would overrun.
    if (size > INT_MAX - AV_INPUT_BUFFER_PADDING_SIZE)
        return AVERROR(ENOMEM);

    unit->data_ref = av_buffer_alloc(size + AV_INPUT_BUFFER_PADDING_SIZE);
    if (!unit->data_ref)
        return AVERROR(ENOMEM);

    unit->data      = unit->data_ref->data;
    unit->data_size = size;

    memset(unit->data + size, 0, AV_INPUT_BUFFER_PADDING_SIZE);

    return 0;
}

683 684 685 686 687 688
static int cbs_insert_unit(CodedBitstreamContext *ctx,
                           CodedBitstreamFragment *frag,
                           int position)
{
    CodedBitstreamUnit *units;

689 690 691 692 693 694 695 696 697 698 699 700
    if (frag->nb_units < frag->nb_units_allocated) {
        units = frag->units;

        if (position < frag->nb_units)
            memmove(units + position + 1, units + position,
                    (frag->nb_units - position) * sizeof(*units));
    } else {
        units = av_malloc_array(frag->nb_units + 1, sizeof(*units));
        if (!units)
            return AVERROR(ENOMEM);

        ++frag->nb_units_allocated;
701

702 703 704 705 706 707 708
        if (position > 0)
            memcpy(units, frag->units, position * sizeof(*units));

        if (position < frag->nb_units)
            memcpy(units + position + 1, frag->units + position,
                   (frag->nb_units - position) * sizeof(*units));
    }
709 710 711

    memset(units + position, 0, sizeof(*units));

712 713 714 715 716
    if (units != frag->units) {
        av_free(frag->units);
        frag->units = units;
    }

717 718 719 720 721 722 723 724 725
    ++frag->nb_units;

    return 0;
}

/**
 * Insert a unit with decomposed content into frag.
 *
 * @param position    index to insert at, or -1 to append
 * @param content     content pointer stored in the unit
 * @param content_buf if non-NULL, a new reference to it is stored as the
 *                    unit's content_ref (the caller keeps its own
 *                    reference).  If NULL the unit holds no reference, and
 *                    the caller presumably keeps content alive; verify
 *                    against callers.
 * @return 0 on success, AVERROR(ENOMEM) on allocation failure
 */
int ff_cbs_insert_unit_content(CodedBitstreamContext *ctx,
                               CodedBitstreamFragment *frag,
                               int position,
                               CodedBitstreamUnitType type,
                               void *content,
                               AVBufferRef *content_buf)
{
    CodedBitstreamUnit *unit;
    AVBufferRef *content_ref;
    int err;

    if (position == -1)
        position = frag->nb_units;
    av_assert0(position >= 0 && position <= frag->nb_units);

    if (content_buf) {
        content_ref = av_buffer_ref(content_buf);
        if (!content_ref)
            return AVERROR(ENOMEM);
    } else {
        content_ref = NULL;
    }

    err = cbs_insert_unit(ctx, frag, position);
    if (err < 0) {
        av_buffer_unref(&content_ref);
        return err;
    }

    unit = &frag->units[position];
    unit->type        = type;
    unit->content     = content;
    unit->content_ref = content_ref;

    return 0;
}

/**
 * Insert a unit holding raw data into frag.
 *
 * @param position  index to insert at, or -1 to append
 * @param data      start of the unit's data
 * @param data_size number of bytes at data
 * @param data_buf  if non-NULL, a new reference to it backs data (the
 *                  caller keeps its own reference).  If NULL, ownership of
 *                  data passes to the unit: it is wrapped in a new buffer,
 *                  and it is freed here if this function fails.
 * @return 0 on success, AVERROR(ENOMEM) on allocation failure
 */
int ff_cbs_insert_unit_data(CodedBitstreamContext *ctx,
                            CodedBitstreamFragment *frag,
                            int position,
                            CodedBitstreamUnitType type,
                            uint8_t *data, size_t data_size,
                            AVBufferRef *data_buf)
{
    CodedBitstreamUnit *unit;
    AVBufferRef *data_ref;
    int err;

    if (position == -1)
        position = frag->nb_units;
    av_assert0(position >= 0 && position <= frag->nb_units);

    if (data_buf)
        data_ref = av_buffer_ref(data_buf);
    else
        data_ref = av_buffer_create(data, data_size, NULL, NULL, 0);
    if (!data_ref) {
        // Ownership of data was transferred when no data_buf was given.
        if (!data_buf)
            av_free(data);
        return AVERROR(ENOMEM);
    }

    err = cbs_insert_unit(ctx, frag, position);
    if (err < 0) {
        av_buffer_unref(&data_ref);
        return err;
    }

    unit = &frag->units[position];
    unit->type      = type;
    unit->data      = data;
    unit->data_size = data_size;
    unit->data_ref  = data_ref;

    return 0;
}

799 800 801
void ff_cbs_delete_unit(CodedBitstreamContext *ctx,
                        CodedBitstreamFragment *frag,
                        int position)
802
{
803 804
    av_assert0(0 <= position && position < frag->nb_units
                             && "Unit to be deleted not in fragment.");
805 806 807 808 809

    cbs_unit_uninit(ctx, &frag->units[position]);

    --frag->nb_units;

810
    if (frag->nb_units > 0)
811 812 813 814
        memmove(frag->units + position,
                frag->units + position + 1,
                (frag->nb_units - position) * sizeof(*frag->units));
}