Skip to content

Commit 0dc082a

Browse files
committed
handle duplicate roots
1 parent aa022bc commit 0dc082a

2 files changed

Lines changed: 49 additions & 21 deletions

File tree

src/circuit.cpp

Lines changed: 9 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -380,22 +380,6 @@ void Circuit::add_root(Node* new_root) {
380380
}
381381
roots.push_back(new_root);
382382

383-
384-
if (new_root->layer != 0) {
385-
// root layer's order might have changed:
386-
// set ix back to those according to `roots` list.
387-
// Since `roots` may contain duplicate node refs:
388-
// we first set all ix to -1 to detect the duplicate refs.
389-
for (size_t i = 0; i < roots.size(); ++i)
390-
roots[i]->ix = -1;
391-
392-
int root_idx = 0;
393-
for (size_t i = 0; i < roots.size(); ++i) {
394-
if (roots[i]->ix == -1)
395-
roots[i]->ix = root_idx++;
396-
}
397-
}
398-
399383
/*
400384
if (nb_layers() > 1) {
401385
if (roots.size() != layers[new_root->layer].size()) {
@@ -458,11 +442,12 @@ std::pair<Arrays, Arrays> Circuit::tensorize() {
458442
// per layer, a vector representing the layer
459443
Arrays csr_ndarrays;
460444

461-
if (layers.size() == 1)
462-
// add node for roots
463-
for (Node* root: roots)
464-
add_node(root->dummy_parent());
465-
445+
// add root layer on top
446+
for (std::size_t i=0; i<roots.size(); i++) {
447+
Node* root = roots[i]->dummy_parent();
448+
root->hash = i;
449+
add_node(root);
450+
}
466451

467452
for (std::size_t i = 1; i < nb_layers(); ++i) {
468453
std::vector<long int> child_counts(layers[i].size(), 0);
@@ -498,6 +483,9 @@ std::pair<Arrays, Arrays> Circuit::tensorize() {
498483
indices_ndarrays.push_back(indices_ndarray);
499484
csr_ndarrays.push_back(csr_ndarray);
500485
}
486+
// remove root layer again
487+
layers.pop_back();
488+
501489
return std::make_pair(indices_ndarrays, csr_ndarrays);
502490
}
503491

tests/test_manual.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import klay
22
import torch
3+
from pysdd.sdd import SddManager
34

45

56
def test_node_equality():
@@ -57,3 +58,42 @@ def test_multi_rooted_ordering():
5758
expected = torch.tensor([0.8 * 0.6, 0.4 * 0.8])
5859
print(m(w), expected)
5960
assert torch.allclose(m(w), expected)
61+
62+
63+
def test_single_layer_multi_root():
64+
c = klay.Circuit()
65+
l1, l2 = c.literal_node(1), c.literal_node(-2)
66+
c.set_root(l1)
67+
c.set_root(l2)
68+
c.set_root(l1)
69+
70+
m = c.to_torch_module(semiring='real')
71+
weights = torch.tensor([0.4, 0.8])
72+
expected = torch.tensor([0.4, 0.2, 0.4])
73+
assert torch.allclose(m(weights), expected)
74+
75+
76+
def test_sdd_literal():
77+
sdd_mgr = SddManager(var_count=2)
78+
a, b = sdd_mgr.vars
79+
80+
c = klay.Circuit()
81+
c.add_sdd(a)
82+
m = c.to_torch_module(semiring='real')
83+
weights = torch.tensor([0.4])
84+
expected = torch.tensor([0.4])
85+
assert torch.allclose(m(weights), expected)
86+
87+
def test_sdd_multiroot():
88+
sdd_mgr = SddManager(var_count=2)
89+
a, b = sdd_mgr.vars
90+
91+
c = klay.Circuit()
92+
c.add_sdd(a)
93+
c.add_sdd(a & b)
94+
c.add_sdd(a & b & b)
95+
c.add_sdd(a & a)
96+
m = c.to_torch_module(semiring='real')
97+
weights = torch.tensor([0.4, 0.5])
98+
expected = torch.tensor([0.4, 0.2, 0.2, 0.4])
99+
assert torch.allclose(m(weights), expected)

0 commit comments

Comments
 (0)