Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions muxer.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ type Muxer struct {
packetSize int
tablesRetransmitPeriod int // period in PES packets

transportStreamID uint16
pmtPID uint16

pm *programMap // pid -> programNumber
pmUpdated bool
pmt PMTData
Expand Down Expand Up @@ -68,6 +71,22 @@ func MuxerOptTablesRetransmitPeriod(newPeriod int) func(*Muxer) {
}
}

// WithTransportStreamID sets the transport stream ID written into PAT.
// Default is 0.
func WithTransportStreamID(id uint16) func(*Muxer) {
return func(m *Muxer) {
m.transportStreamID = id
}
}

// WithPMTPID sets the PID used for PMT packets and advertised in PAT.
// Default is 0x1000.
func WithPMTPID(pid uint16) func(*Muxer) {
return func(m *Muxer) {
m.pmtPID = pid
}
}

// TODO MuxerOptAutodetectPCRPID selecting first video PID for each PMT, falling back to first audio, falling back to any other

func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
Expand All @@ -78,6 +97,8 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
packetSize: MpegTsPacketSize, // no 192-byte packet support yet
tablesRetransmitPeriod: 40,

pmtPID: pmtStartPID,

pm: newProgramMap(),
pmt: PMTData{
ElementaryStreams: []*PMTElementaryStream{},
Expand All @@ -97,14 +118,14 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer {
m.bufWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf})
m.bitsWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: m.w})

// TODO multiple programs support
m.pm.setUnlocked(pmtStartPID, programNumberStart)
m.pmUpdated = true

for _, opt := range opts {
opt(m)
}

// TODO multiple programs support
m.pm.setUnlocked(m.pmtPID, programNumberStart)
m.pmUpdated = true

// to output tables at the very start
m.tablesRetransmitCounter = m.tablesRetransmitPeriod

Expand Down Expand Up @@ -322,6 +343,7 @@ func (m *Muxer) WriteTables() (int, error) {

func (m *Muxer) generatePAT() error {
d := m.pm.toPATDataUnlocked()
d.TransportStreamID = m.transportStreamID

versionNumber := m.patVersion.get()
if m.pmUpdated {
Expand Down Expand Up @@ -432,7 +454,7 @@ func (m *Muxer) generatePMT() error {
Header: PacketHeader{
HasPayload: true,
PayloadUnitStartIndicator: true,
PID: pmtStartPID, // FIXME multiple programs support
PID: m.pmtPID, // FIXME multiple programs support
ContinuityCounter: uint8(m.pmtCC.inc()),
},
Payload: m.buf.Bytes(),
Expand Down
138 changes: 138 additions & 0 deletions roundtrip_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
package astits

import (
"bytes"
"context"
"errors"
"os"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type pesRecord struct {
pid uint16
pes *PESData
af *PacketAdaptationField
}

func TestRoundTrip(t *testing.T) {
originalBytes, err := os.ReadFile("testdata/ts/silent_audio.ts")
Copy link
Copy Markdown
Owner

@asticode asticode Mar 12, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your new test is failing since you forgot to add the silent_audio.ts file to this PR. Could you add it?

require.NoError(t, err)

// Phase 1: Demux the original TS file
dmx := NewDemuxer(context.Background(), bytes.NewReader(originalBytes), DemuxerOptPacketSize(MpegTsPacketSize))

var originalPAT *PATData
var originalPMT *PMTData
var originalPMTPID uint16 = 0xFFFF

for {
d, err := dmx.NextData()
if errors.Is(err, ErrNoMorePackets) {
break
}
require.NoError(t, err)

if d.PAT != nil {
originalPAT = d.PAT
originalPMTPID = d.PAT.Programs[0].ProgramMapID
}
if d.PMT != nil {
originalPMT = d.PMT
}

if originalPMT != nil && originalPAT != nil {
break
}
}
require.NotNil(t, originalPAT)
require.NotNil(t, originalPMT)
require.NotEqual(t, 0xFFFF, originalPMTPID)

// Phase 2: Mux everything back into a new TS stream, preserving PAT/PMT identifiers
var buf bytes.Buffer
muxer := NewMuxer(context.Background(), &buf,
WithTransportStreamID(originalPAT.TransportStreamID),
WithPMTPID(originalPMTPID),
)

for _, es := range originalPMT.ElementaryStreams {
err := muxer.AddElementaryStream(PMTElementaryStream{
ElementaryPID: es.ElementaryPID,
StreamType: es.StreamType,
ElementaryStreamDescriptors: es.ElementaryStreamDescriptors,
})
require.NoError(t, err)
}
muxer.SetPCRPID(originalPMT.PCRPID)
muxer.pmt.ProgramDescriptors = originalPMT.ProgramDescriptors
_, err = muxer.WriteTables()
require.NoError(t, err)

// Phase 3: Demux the round-tripped output
dmx2 := NewDemuxer(context.Background(), bytes.NewReader(buf.Bytes()), DemuxerOptPacketSize(MpegTsPacketSize))

var rtPAT *PATData
var rtPMT *PMTData

for {
d, err := dmx2.NextData()
if errors.Is(err, ErrNoMorePackets) {
break
}
require.NoError(t, err)

if d.PAT != nil {
rtPAT = d.PAT
}
if d.PMT != nil {
rtPMT = d.PMT
}

if rtPAT != nil && rtPMT != nil {
break
}
}
require.NotNil(t, rtPAT)
require.NotNil(t, rtPMT)

// Phase 4: Validate round-trip preserved all meaningful information
// --- PAT ---
assert.Equal(t, originalPAT.TransportStreamID, rtPAT.TransportStreamID, "PAT TransportStreamID mismatch")
require.Equal(t, len(originalPAT.Programs), len(rtPAT.Programs), "PAT program count mismatch")
for i, origProg := range originalPAT.Programs {
assert.Equalf(t, origProg.ProgramNumber, rtPAT.Programs[i].ProgramNumber,
"PAT Programs[%d].ProgramNumber mismatch", i)
assert.Equalf(t, origProg.ProgramMapID, rtPAT.Programs[i].ProgramMapID,
"PAT Programs[%d].ProgramMapID mismatch", i)
}

// --- PMT ---
assert.Equal(t, originalPMT.PCRPID, rtPMT.PCRPID)
assert.Equal(t, originalPMT.ProgramNumber, rtPMT.ProgramNumber)
require.Equal(t, len(originalPMT.ProgramDescriptors), len(rtPMT.ProgramDescriptors))
for i, desc := range originalPMT.ProgramDescriptors {
assert.Equalf(t, desc.Tag, rtPMT.ProgramDescriptors[i].Tag,
"PMT ProgramDescriptors[%d].Tag mismatch", i)
assert.Equalf(t, desc.Length, rtPMT.ProgramDescriptors[i].Length,
"PMT ProgramDescriptors[%d].Length mismatch", i)
}
require.Equal(t, len(originalPMT.ElementaryStreams), len(rtPMT.ElementaryStreams))
for i, es := range originalPMT.ElementaryStreams {
rtES := rtPMT.ElementaryStreams[i]
assert.Equalf(t, es.ElementaryPID, rtES.ElementaryPID,
"PMT ElementaryStreams[%d].ElementaryPID mismatch", i)
assert.Equalf(t, es.StreamType, rtES.StreamType,
"PMT ElementaryStreams[%d].StreamType mismatch", i)
require.Equalf(t, len(es.ElementaryStreamDescriptors), len(rtES.ElementaryStreamDescriptors),
"PMT ElementaryStreams[%d].ElementaryStreamDescriptors count mismatch", i)
for j, desc := range es.ElementaryStreamDescriptors {
assert.Equalf(t, desc.Tag, rtES.ElementaryStreamDescriptors[j].Tag,
"PMT ElementaryStreams[%d].Descriptors[%d].Tag mismatch", i, j)
assert.Equalf(t, desc.Length, rtES.ElementaryStreamDescriptors[j].Length,
"PMT ElementaryStreams[%d].Descriptors[%d].Length mismatch", i, j)
}
}
}
Binary file added testdata/ts/silent_audio.ts
Binary file not shown.
Loading