diff --git a/muxer.go b/muxer.go index 315da9b..4163017 100644 --- a/muxer.go +++ b/muxer.go @@ -29,6 +29,9 @@ type Muxer struct { packetSize int tablesRetransmitPeriod int // period in PES packets + transportStreamID uint16 + pmtPID uint16 + pm *programMap // pid -> programNumber pmUpdated bool pmt PMTData @@ -68,6 +71,22 @@ func MuxerOptTablesRetransmitPeriod(newPeriod int) func(*Muxer) { } } +// WithTransportStreamID sets the transport stream ID written into PAT. +// Default is 0. +func WithTransportStreamID(id uint16) func(*Muxer) { + return func(m *Muxer) { + m.transportStreamID = id + } +} + +// WithPMTPID sets the PID used for PMT packets and advertised in PAT. +// Default is 0x1000. +func WithPMTPID(pid uint16) func(*Muxer) { + return func(m *Muxer) { + m.pmtPID = pid + } +} + // TODO MuxerOptAutodetectPCRPID selecting first video PID for each PMT, falling back to first audio, falling back to any other func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer { @@ -78,6 +97,8 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer { packetSize: MpegTsPacketSize, // no 192-byte packet support yet tablesRetransmitPeriod: 40, + pmtPID: pmtStartPID, + pm: newProgramMap(), pmt: PMTData{ ElementaryStreams: []*PMTElementaryStream{}, @@ -97,14 +118,14 @@ func NewMuxer(ctx context.Context, w io.Writer, opts ...func(*Muxer)) *Muxer { m.bufWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: &m.buf}) m.bitsWriter = astikit.NewBitsWriter(astikit.BitsWriterOptions{Writer: m.w}) - // TODO multiple programs support - m.pm.setUnlocked(pmtStartPID, programNumberStart) - m.pmUpdated = true - for _, opt := range opts { opt(m) } + // TODO multiple programs support + m.pm.setUnlocked(m.pmtPID, programNumberStart) + m.pmUpdated = true + // to output tables at the very start m.tablesRetransmitCounter = m.tablesRetransmitPeriod @@ -322,6 +343,7 @@ func (m *Muxer) WriteTables() (int, error) { func (m *Muxer) generatePAT() error { d := m.pm.toPATDataUnlocked() + d.TransportStreamID = m.transportStreamID versionNumber := m.patVersion.get() if m.pmUpdated { @@ -432,7 +454,7 @@ func (m *Muxer) generatePMT() error { Header: PacketHeader{ HasPayload: true, PayloadUnitStartIndicator: true, - PID: pmtStartPID, // FIXME multiple programs support + PID: m.pmtPID, // FIXME multiple programs support ContinuityCounter: uint8(m.pmtCC.inc()), }, Payload: m.buf.Bytes(), diff --git a/roundtrip_test.go b/roundtrip_test.go new file mode 100644 index 0000000..cf56e83 --- /dev/null +++ b/roundtrip_test.go @@ -0,0 +1,138 @@ +package astits + +import ( + "bytes" + "context" + "errors" + "os" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type pesRecord struct { + pid uint16 + pes *PESData + af *PacketAdaptationField +} + +func TestRoundTrip(t *testing.T) { + originalBytes, err := os.ReadFile("testdata/ts/silent_audio.ts") + require.NoError(t, err) + + // Phase 1: Demux the original TS file + dmx := NewDemuxer(context.Background(), bytes.NewReader(originalBytes), DemuxerOptPacketSize(MpegTsPacketSize)) + + var originalPAT *PATData + var originalPMT *PMTData + var originalPMTPID uint16 = 0xFFFF + + for { + d, err := dmx.NextData() + if errors.Is(err, ErrNoMorePackets) { + break + } + require.NoError(t, err) + + if d.PAT != nil { + originalPAT = d.PAT + originalPMTPID = d.PAT.Programs[0].ProgramMapID + } + if d.PMT != nil { + originalPMT = d.PMT + } + + if originalPMT != nil && originalPAT != nil { + break + } + } + require.NotNil(t, originalPAT) + require.NotNil(t, originalPMT) + require.NotEqual(t, 0xFFFF, originalPMTPID) + + // Phase 2: Mux everything back into a new TS stream, preserving PAT/PMT identifiers + var buf bytes.Buffer + muxer := NewMuxer(context.Background(), &buf, + WithTransportStreamID(originalPAT.TransportStreamID), + WithPMTPID(originalPMTPID), + ) + + for _, es := range originalPMT.ElementaryStreams { + err := muxer.AddElementaryStream(PMTElementaryStream{ + ElementaryPID: es.ElementaryPID, + StreamType: es.StreamType, + ElementaryStreamDescriptors: es.ElementaryStreamDescriptors, + }) + require.NoError(t, err) + } + muxer.SetPCRPID(originalPMT.PCRPID) + muxer.pmt.ProgramDescriptors = originalPMT.ProgramDescriptors + _, err = muxer.WriteTables() + require.NoError(t, err) + + // Phase 3: Demux the round-tripped output + dmx2 := NewDemuxer(context.Background(), bytes.NewReader(buf.Bytes()), DemuxerOptPacketSize(MpegTsPacketSize)) + + var rtPAT *PATData + var rtPMT *PMTData + + for { + d, err := dmx2.NextData() + if errors.Is(err, ErrNoMorePackets) { + break + } + require.NoError(t, err) + + if d.PAT != nil { + rtPAT = d.PAT + } + if d.PMT != nil { + rtPMT = d.PMT + } + + if rtPAT != nil && rtPMT != nil { + break + } + } + require.NotNil(t, rtPAT) + require.NotNil(t, rtPMT) + + // Phase 4: Validate round-trip preserved all meaningful information + // --- PAT --- + assert.Equal(t, originalPAT.TransportStreamID, rtPAT.TransportStreamID, "PAT TransportStreamID mismatch") + require.Equal(t, len(originalPAT.Programs), len(rtPAT.Programs), "PAT program count mismatch") + for i, origProg := range originalPAT.Programs { + assert.Equalf(t, origProg.ProgramNumber, rtPAT.Programs[i].ProgramNumber, + "PAT Programs[%d].ProgramNumber mismatch", i) + assert.Equalf(t, origProg.ProgramMapID, rtPAT.Programs[i].ProgramMapID, + "PAT Programs[%d].ProgramMapID mismatch", i) + } + + // --- PMT --- + assert.Equal(t, originalPMT.PCRPID, rtPMT.PCRPID) + assert.Equal(t, originalPMT.ProgramNumber, rtPMT.ProgramNumber) + require.Equal(t, len(originalPMT.ProgramDescriptors), len(rtPMT.ProgramDescriptors)) + for i, desc := range originalPMT.ProgramDescriptors { + assert.Equalf(t, desc.Tag, rtPMT.ProgramDescriptors[i].Tag, + "PMT ProgramDescriptors[%d].Tag mismatch", i) + assert.Equalf(t, desc.Length, rtPMT.ProgramDescriptors[i].Length, + "PMT ProgramDescriptors[%d].Length mismatch", i) + } + require.Equal(t, len(originalPMT.ElementaryStreams), len(rtPMT.ElementaryStreams)) + for i, es := range originalPMT.ElementaryStreams { + rtES := rtPMT.ElementaryStreams[i] + assert.Equalf(t, es.ElementaryPID, rtES.ElementaryPID, + "PMT ElementaryStreams[%d].ElementaryPID mismatch", i) + assert.Equalf(t, es.StreamType, rtES.StreamType, + "PMT ElementaryStreams[%d].StreamType mismatch", i) + require.Equalf(t, len(es.ElementaryStreamDescriptors), len(rtES.ElementaryStreamDescriptors), + "PMT ElementaryStreams[%d].ElementaryStreamDescriptors count mismatch", i) + for j, desc := range es.ElementaryStreamDescriptors { + assert.Equalf(t, desc.Tag, rtES.ElementaryStreamDescriptors[j].Tag, + "PMT ElementaryStreams[%d].Descriptors[%d].Tag mismatch", i, j) + assert.Equalf(t, desc.Length, rtES.ElementaryStreamDescriptors[j].Length, + "PMT ElementaryStreams[%d].Descriptors[%d].Length mismatch", i, j) + } + } +} diff --git a/testdata/ts/silent_audio.ts b/testdata/ts/silent_audio.ts new file mode 100644 index 0000000..ac398e0 Binary files /dev/null and b/testdata/ts/silent_audio.ts differ