Skip to content

Commit 71cbeeb

Browse files
committed
Add test of bloom filter forwarding metadata
1 parent 4289532 commit 71cbeeb

File tree

1 file changed

+30
-1
lines changed

1 file changed

+30
-1
lines changed

python/rapidsmpf/rapidsmpf/tests/streaming/test_bloom_filter.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from rapidsmpf.streaming.core.actor import define_actor, run_actor_network
1616
from rapidsmpf.streaming.core.leaf_actor import pull_from_channel, push_to_channel
1717
from rapidsmpf.streaming.core.message import Message
18+
from rapidsmpf.streaming.cudf import ChannelMetadata
1819
from rapidsmpf.streaming.cudf.bloom_filter import BloomFilter
1920
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
2021
from rapidsmpf.testing import assert_eq
@@ -34,6 +35,30 @@ def make_table(values: np.ndarray, stream: Stream) -> TableChunk:
3435
return TableChunk.from_pylibcudf_table(table, stream, exclusive_view=True)
3536

3637

38+
@define_actor()
39+
async def add_metadata(
40+
ctx: Context, ch_in: Channel[TableChunk], ch_out: Channel[TableChunk]
41+
) -> None:
42+
await ch_out.send_metadata(ctx, Message(0, ChannelMetadata(1)))
43+
await ch_out.drain_metadata(ctx)
44+
while (msg := await ch_in.recv(ctx)) is not None:
45+
await ch_out.send(ctx, msg)
46+
await ch_out.drain(ctx)
47+
48+
49+
@define_actor()
50+
async def receive_metadata(
51+
ctx: Context, ch_in: Channel[TableChunk], ch_out: Channel[TableChunk]
52+
) -> None:
53+
m = await ch_in.recv_metadata(ctx)
54+
assert m is not None
55+
meta = ChannelMetadata.from_message(m)
56+
assert meta.local_count == 1
57+
while (msg := await ch_in.recv(ctx)) is not None:
58+
await ch_out.send(ctx, msg)
59+
await ch_out.drain(ctx)
60+
61+
3762
@define_actor()
3863
async def bloom_pipeline(
3964
ctx: Context,
@@ -76,12 +101,16 @@ def run_bloom_filter_pipeline(
76101

77102
ch_build: Channel[TableChunk] = context.create_channel()
78103
ch_probe: Channel[TableChunk] = context.create_channel()
104+
ch_probe_meta: Channel[TableChunk] = context.create_channel()
105+
ch_out_meta: Channel[TableChunk] = context.create_channel()
79106
ch_out: Channel[TableChunk] = context.create_channel()
80107

81108
actors: list[CppActor | PyActor] = [
82109
push_to_channel(context, ch_build, [build_msg]),
83110
push_to_channel(context, ch_probe, [probe_msg]),
84-
bloom_pipeline(context, bloom, ch_build, ch_probe, ch_out),
111+
add_metadata(context, ch_probe, ch_probe_meta),
112+
bloom_pipeline(context, bloom, ch_build, ch_probe_meta, ch_out_meta),
113+
receive_metadata(context, ch_out_meta, ch_out),
85114
]
86115
pull_actor, deferred = pull_from_channel(context, ch_out)
87116
actors.append(pull_actor)

0 commit comments

Comments
 (0)