1515from rapidsmpf .streaming .core .actor import define_actor , run_actor_network
1616from rapidsmpf .streaming .core .leaf_actor import pull_from_channel , push_to_channel
1717from rapidsmpf .streaming .core .message import Message
18+ from rapidsmpf .streaming .cudf import ChannelMetadata
1819from rapidsmpf .streaming .cudf .bloom_filter import BloomFilter
1920from rapidsmpf .streaming .cudf .table_chunk import TableChunk
2021from rapidsmpf .testing import assert_eq
@@ -34,6 +35,30 @@ def make_table(values: np.ndarray, stream: Stream) -> TableChunk:
3435 return TableChunk .from_pylibcudf_table (table , stream , exclusive_view = True )
3536
3637
38+ @define_actor ()
39+ async def add_metadata (
40+ ctx : Context , ch_in : Channel [TableChunk ], ch_out : Channel [TableChunk ]
41+ ) -> None :
42+ await ch_out .send_metadata (ctx , Message (0 , ChannelMetadata (1 )))
43+ await ch_out .drain_metadata (ctx )
44+ while (msg := await ch_in .recv (ctx )) is not None :
45+ await ch_out .send (ctx , msg )
46+ await ch_out .drain (ctx )
47+
48+
49+ @define_actor ()
50+ async def receive_metadata (
51+ ctx : Context , ch_in : Channel [TableChunk ], ch_out : Channel [TableChunk ]
52+ ) -> None :
53+ m = await ch_in .recv_metadata (ctx )
54+ assert m is not None
55+ meta = ChannelMetadata .from_message (m )
56+ assert meta .local_count == 1
57+ while (msg := await ch_in .recv (ctx )) is not None :
58+ await ch_out .send (ctx , msg )
59+ await ch_out .drain (ctx )
60+
61+
3762@define_actor ()
3863async def bloom_pipeline (
3964 ctx : Context ,
@@ -76,12 +101,16 @@ def run_bloom_filter_pipeline(
76101
77102 ch_build : Channel [TableChunk ] = context .create_channel ()
78103 ch_probe : Channel [TableChunk ] = context .create_channel ()
104+ ch_probe_meta : Channel [TableChunk ] = context .create_channel ()
105+ ch_out_meta : Channel [TableChunk ] = context .create_channel ()
79106 ch_out : Channel [TableChunk ] = context .create_channel ()
80107
81108 actors : list [CppActor | PyActor ] = [
82109 push_to_channel (context , ch_build , [build_msg ]),
83110 push_to_channel (context , ch_probe , [probe_msg ]),
84- bloom_pipeline (context , bloom , ch_build , ch_probe , ch_out ),
111+ add_metadata (context , ch_probe , ch_probe_meta ),
112+ bloom_pipeline (context , bloom , ch_build , ch_probe_meta , ch_out_meta ),
113+ receive_metadata (context , ch_out_meta , ch_out ),
85114 ]
86115 pull_actor , deferred = pull_from_channel (context , ch_out )
87116 actors .append (pull_actor )
0 commit comments