diff --git a/.gitignore b/.gitignore index e436cff..9896977 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,6 @@ *~ #*# *__pycache__* -*.pyc \ No newline at end of file +*.pyc + +venv/** \ No newline at end of file diff --git a/pl/dpll.py b/pl/dpll.py new file mode 100644 index 0000000..012743c --- /dev/null +++ b/pl/dpll.py @@ -0,0 +1,166 @@ +import numpy as np +import networkx as nx +from utils import elaborate_clauses, get_ordered_symbols + +def find_unit_clause(symbols, clauses, model): + save_index = None + for clause in clauses: + i = 0 + c_none = 0 + sat = False + while i < len(clause) and not sat: + literal_index = symbols.index(clause[i].replace('!', '')) + if model[literal_index] == None: + save_index = i + c_none = c_none + 1 + else: + if('!' not in clause[i]): + sat = sat or model[literal_index] + else: + sat = sat or not model[literal_index] + i = i + 1 + if c_none == 1 and not sat: + return clause[save_index] + return None + +def find_first(symbols, clauses, model): + i = 0 + list_of_symbols = get_ordered_symbols(clauses) + while i < len(model): + if model[i] == None and symbols[i] in list_of_symbols: + return i + i = i + 1 + return None + +def dpll_satisfailable(clauses): + symbols = get_ordered_symbols(clauses) + model = [None] * len(symbols) + + search_tree = nx.Graph() + search_tree.add_node(0) + + labels = {0:'Start'} + return (dpll(clauses, symbols, model, search_tree, 0, labels), search_tree, labels) + +def dpll(clauses, symbols, model, search_tree, parent_node, labels): + false_clauses, unsat_clauses = elaborate_clauses(symbols, clauses, model) + print(model) + if len(false_clauses) > 0: + # I have at least one clause unsatisfailable with the current model + id = search_tree.number_of_nodes() + search_tree.add_node(id) + search_tree.add_edge(parent_node, id) + labels[id] = 'FAIL' + return False + elif len(unsat_clauses) == 0: + id = search_tree.number_of_nodes() + search_tree.add_node(id) + search_tree.add_edge(parent_node, id) + labels[id] = 'SAT' + return True + else: + lonely_literal = find_unit_clause(symbols, unsat_clauses, model) + if lonely_literal != None: + literal = lonely_literal + new_model = [x for x in model] + if ('!' in literal): + new_model[symbols.index(literal.replace('!', ''))] = False + else: + new_model[symbols.index(literal.replace('!', ''))] = True + + id = search_tree.number_of_nodes() + search_tree.add_node(id) + search_tree.add_edge(parent_node, id) + labels[id] = literal + + return dpll(clauses, symbols, new_model, search_tree, id, labels) + else: + # Search for the first unassigned literal in a unsatisfied clause + literal_index = find_first(symbols, unsat_clauses, model) + if literal_index == None: + # I found no literal + id = search_tree.number_of_nodes() + search_tree.add_node(id) + search_tree.add_edge(parent_node, id) + labels[id] = 'FAIL' + return False + # Model setting the literal to true + new_model_true = [x for x in model] + new_model_true[literal_index] = True + + id = search_tree.number_of_nodes() + search_tree.add_node(id) + search_tree.add_edge(parent_node, id) + labels[id] = symbols[literal_index] + + ans = dpll(clauses, symbols, new_model_true, search_tree, id, labels) + + # Model setting the literal to false + new_model_false = [x for x in model] + new_model_false[literal_index] = False + + id = search_tree.number_of_nodes() + search_tree.add_node(id) + search_tree.add_edge(parent_node, id) + labels[id] = '!' + symbols[literal_index] + + ans = ans or dpll(clauses, symbols, new_model_false, search_tree, id, labels) + + return ans + +if __name__ == '__main__': + import argparse + import matplotlib.pyplot as plt + from utils import parse_formula, get_ordered_symbols, generate_cnf, hierarchy_pos + + parser = argparse.ArgumentParser() + + parser.add_argument('--solve', type=str, nargs='+', + help=f"List of literals in atrix form, where length of each clause is expressed in clause_lengths", + default=None) + parser.add_argument('--clause_lengths', type=int, nargs='+', + help=f"A list that expresses the lenght of each clauses in solve", + default=None) + + parser.add_argument('-c', '--conjunctions', type=int, + help=f"Number of conjunctions in the formula", + default=4) + parser.add_argument('-d', '--disjunctions', type=int, nargs=2, + help=f"Limit of minimum and maximum disjunctions in the formula", + default=(1, 4)) + + parser.add_argument('-l', '--literals', type=str, nargs="+", + help=f"Literal list from which the formula is created", + default=['A', 'B', 'C', 'D', 'E', 'F']) + + parser.add_argument('-s', '--seed', type=int, + help=f"Random seed number", + default=int(np.random.random() * 1000)) + + args = parser.parse_args() + + np.random.seed(args.seed) + print("SEED: ", args.seed) + symbols = args.literals + + if len(symbols) < args.disjunctions[1]: + raise ValueError("Number of literals " + + "is not sufficient to generate a formula " + + "with maximum disjunctions") + + if args.solve == None: + cnf = generate_cnf(symbols, args.conjunctions, args.disjunctions[1], min_disj=args.disjunctions[0]) + elif args.solve != None and args.clause_lengths != None: + cnf = parse_formula(args.solve, args.clause_lengths) + else: + raise ValueError('Formula or disjunctions lengths not provided') + + print(cnf) + + input() + (ans, search_tree, labels) = dpll_satisfailable(cnf) + + print(f"Solution of symbols {get_ordered_symbols(cnf)} is {ans}") + + nx.draw(search_tree, pos=hierarchy_pos(search_tree, 0), labels=labels) + plt.show() \ No newline at end of file diff --git a/pl/horn.py b/pl/horn.py new file mode 100644 index 0000000..811d1e8 --- /dev/null +++ b/pl/horn.py @@ -0,0 +1,69 @@ +import numpy as np +from utils import generate_cnf, elaborate_clauses +from dpll import find_unit_clause + +def generate_horn(symbols, n_conj, max_disj, min_disj): + clauses = generate_cnf(symbols, n_conj, max_disj, min_disj, p_negative=1.0) + + for clause in clauses: + for i in range(len(clause)): + if np.random.random() <= 1 / len(clause): + clause[i] = clause[i].replace('!', '') + break + + return clauses + +def horn_satisfiable(symbols, clauses, model): + literal = find_unit_clause(symbols, clauses, model) + while(literal != None): + if ('!' in literal): + model[symbols.index(literal.replace('!', ''))] = False + else: + model[symbols.index(literal.replace('!', ''))] = True + literal = find_unit_clause(symbols, clauses, model) + unsat, _ = elaborate_clauses(symbols, clauses, model) + if len(unsat) > 0: + return None + else: + return [False if x is None else x for x in model] + +if __name__ == '__main__': + import argparse + from utils import get_ordered_symbols + + parser = argparse.ArgumentParser() + + parser.add_argument('-c', '--conjunctions', type=int, + help=f"Number of conjunctions in the formula", + default=4) + parser.add_argument('-d', '--disjunctions', type=int, nargs=2, + help=f"Limit of minimum and maximum disjunctions in the formula", + default=(1, 4)) + + parser.add_argument('-l', '--literals', type=str, nargs="+", + help=f"Literal list from which the formula is created", + default=['A', 'B', 'C', 'D', 'E', 'F']) + + parser.add_argument('-s', '--seed', type=int, + help=f"Random seed number", + default=int(np.random.random() * 1000)) + + args = parser.parse_args() + + np.random.seed(args.seed) + print("SEED: ", args.seed) + symbols = args.literals + + if len(symbols) < args.disjunctions[1]: + raise ValueError("Number of literals " + + "is not sufficient to generate a formula " + + "with maximum disjunctions") + + horn = generate_horn(symbols, args.conjunctions, args.disjunctions[1], min_disj=args.disjunctions[0]) + print(horn) + + input() + model = [None] * len(symbols) + ans = horn_satisfiable(symbols, horn, model) + + print(f"Solution of symbols {get_ordered_symbols(horn)} is {ans}") \ No newline at end of file diff --git a/pl/utils.py b/pl/utils.py new file mode 100644 index 0000000..1b73ee0 --- /dev/null +++ b/pl/utils.py @@ -0,0 +1,174 @@ +import numpy as np +import re +import networkx as nx +import itertools + +def select_choice(options, choices): + '''Simulated stochastic process that deterministically pick an + option given a pre-determined list of choices. Choices are + cycled through. Options are weighted. + + ''' + w_sum = sum(w for _, w in options) + + print("options:", [(o, w/w_sum) for o,w in options]) + print("choices:", choices) + + choice = choices.pop(0) + choices.append(choice) # cycling + + p = 0 + for opt, w in options: + p += w/w_sum + if choice <= p: + return opt, p + +def hierarchy_pos(G, root=None, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5): + ''' + From Joel's answer at https://stackoverflow.com/a/29597209/2966723. + Licensed under Creative Commons Attribution-Share Alike + + If the graph is a tree this will return the positions to plot this in a + hierarchical layout. + + G: the graph (must be a tree) + + root: the root node of current branch + - if the tree is directed and this is not given, + the root will be found and used + - if the tree is directed and this is given, then + the positions will be just for the descendants of this node. + - if the tree is undirected and not given, + then a random choice will be used. + + width: horizontal space allocated for this branch - avoids overlap with other branches + + vert_gap: gap between levels of hierarchy + + vert_loc: vertical location of root + + xcenter: horizontal location of root + ''' + if not nx.is_tree(G): + raise TypeError('cannot use hierarchy_pos on a graph that is not a tree') + + if root is None: + if isinstance(G, nx.DiGraph): + root = next(iter(nx.topological_sort(G))) #allows back compatibility with nx version 1.11 + else: + root = np.random.choice(list(G.nodes)) + + def _hierarchy_pos(G, root, width=1., vert_gap = 0.2, vert_loc = 0, xcenter = 0.5, pos = None, parent = None): + ''' + see hierarchy_pos docstring for most arguments + + pos: a dict saying where all nodes go if they have been assigned + parent: parent of this branch. - only affects it if non-directed + + ''' + + if pos is None: + pos = {root:(xcenter,vert_loc)} + else: + pos[root] = (xcenter, vert_loc) + children = list(G.neighbors(root)) + if not isinstance(G, nx.DiGraph) and parent is not None: + children.remove(parent) + if len(children)!=0: + dx = width/len(children) + nextx = xcenter - width/2 - dx/2 + for child in children: + nextx += dx + pos = _hierarchy_pos(G,child, width = dx, vert_gap = vert_gap, + vert_loc = vert_loc-vert_gap, xcenter=nextx, + pos=pos, parent = root) + return pos + + + return _hierarchy_pos(G, root, width, vert_gap, vert_loc, xcenter) + +def parse_formula(array, disjunctions_lengths): + ''' + Parses an array of strings into a cnf formula with + len(disjunction_lengths) conjunctions and the specified + disjunctions lenghts for each clause + ''' + cnf = [] + + if len(array) != sum(disjunctions_lengths): + raise ValueError("The number of given literals is not sufficient to construct the formula") + + symbol_c = 0 + i = 0 + while i < len(disjunctions_lengths): + j = 0 + clause = [] + while j < disjunctions_lengths[i]: + clause.append(array[symbol_c]) + symbol_c = symbol_c + 1 + j = j + 1 + cnf.append(sorted(clause, key=lambda x: re.sub('[^A-Za-z]+', '', x).lower())) + i = i + 1 + + return cnf + +def get_ordered_symbols(clauses): + ''' + Return a list with all the symbols present in a set of clauses + ''' + symbols = set([x.replace('!', '') for x in list(itertools.chain.from_iterable(clauses))]) + return sorted(list(symbols), key=lambda x: re.sub('[^A-Za-z]+', '', x).lower()) + +def generate_cnf(symbols, n_conj, max_disj, min_disj = 1, p_negative = 0.5): + ''' + Creates a random formula in cnd form. + The formula is expressed in a matrix where: + each cell in each row is in the same disjuction and + each row is in conjuction with the others + ''' + + cnf = [] + for conj in range(0, n_conj): + choices = np.random.choice( + len(symbols), + size=np.random.randint(min_disj, max_disj + 1), + replace=False) + literals = [] + for x in range(len(choices)): + if np.random.random() <= p_negative: + literals.append('!' + symbols[choices[x]]) + else: + literals.append('' + symbols[choices[x]]) + cnf.append(sorted(literals, key=lambda x: re.sub('[^A-Za-z]+', '', x).lower())) + return cnf + +def elaborate_clauses(symbols, clauses, model): + ''' + Search for unsatisfied clauses given a model. + If a clause contains empty (None) values, it is not returned + only if the known values make it satisfied + ''' + + unsolvable = [] + unsatisfied = [] + for clause in clauses: + i = 0 + sat = False + still_solvable = False + while i < len(clause) and not sat: + literal = symbols.index(clause[i].replace('!', '')) + if model[literal] != None: + if('!' not in clause[i]): + sat = sat or model[literal] + else: + sat = sat or not model[literal] + else: + still_solvable = True + i = i + 1 + if not sat: + if still_solvable: + unsatisfied.append(clause) + else: + unsolvable.append(clause) + + return (unsolvable, unsatisfied) \ No newline at end of file diff --git a/pl/walksat.py b/pl/walksat.py new file mode 100644 index 0000000..28acc67 --- /dev/null +++ b/pl/walksat.py @@ -0,0 +1,112 @@ +import numpy as np +from utils import select_choice, elaborate_clauses + +def maximize_satisfied(symbols, clauses, model, target_clause): + best_literal = None + min_unsat = len(clauses) # Maximize sat should be equal to minimizing unsat + + for literal in target_clause: + new_model = [x for x in model] + new_model[symbols.index(literal.replace('!', ''))] = not model[symbols.index(literal.replace('!', ''))] + + unsat, _ = elaborate_clauses(symbols, clauses, new_model) + curr_sat = len(unsat) + print(f"Literal {literal.replace('!', '')} leaves the model with {curr_sat} unsatisfied clauses") + if(curr_sat < min_unsat): + min_unsat = curr_sat + best_literal = literal.replace('!', '') + + return best_literal + +def walksat(symbols, clauses, choices, max_flips=1000, p = 0.5): + model = np.zeros(len(symbols), dtype=bool) + for i in range(len(model)): + if np.random.random() < 0.5: + model[i] = True + else: + model[i] = False + + print("Starting model: ", model) + for _ in range(max_flips): + input() + unsat_clauses, _ = elaborate_clauses(symbols, clauses, model) + if len(unsat_clauses) == 0: + return model + (clause, _) = select_choice([(x, 1) for x in unsat_clauses], choices) + print(f"Selected clause: {clause}") + if select_choice([('Informed', 1 - p), ('Uninformed', p)], choices)[0] == 'Informed': + print('Selected informed flip') + to_flip = maximize_satisfied(symbols, clauses, model, clause) + model[symbols.index(to_flip)] = not model[symbols.index(to_flip)] + pass + else: + print('Selected random flip') + options = [(x.replace('!', ''), 1) for x in clause] + (to_flip, _) = select_choice(options, choices) + model[symbols.index(to_flip)] = not model[symbols.index(to_flip)] + print(f"Flipped {to_flip} that becomes {model[symbols.index(to_flip)]}") + return None + +if __name__ == '__main__': + import argparse + from utils import parse_formula, generate_cnf, get_ordered_symbols + + parser = argparse.ArgumentParser() + + parser.add_argument('--solve', type=str, nargs='+', + help=f"List of literals in atrix form, where length of each clause is expressed in clause_lengths", + default=None) + parser.add_argument('--clause_lengths', type=int, nargs='+', + help=f"A list that expresses the lenght of each clauses in solve", + default=None) + + parser.add_argument('-c', '--conjunctions', type=int, + help=f"Number of conjunctions in the formula", + default=4) + parser.add_argument('-d', '--disjunctions', type=int, nargs=2, + help=f"Limit of minimum and maximum disjunctions in the formula", + default=(3, 4)) + + parser.add_argument('-p', '--uninformed_probability', type=float, + help=f"Sets the probability of uninformed_search", + default=0.5) + + parser.add_argument('-l', '--literals', type=str, nargs="+", + help=f"Literal list from which the formula is created", + default=['A', 'B', 'C', 'D', 'E', 'F']) + + parser.add_argument('-s', '--seed', type=int, + help=f"Random seed number", + default=int(np.random.random() * 1000)) + parser.add_argument('--choices', type=float, nargs='+', + help=f"Predetermined non-deterministic choices", + default=None) + + args = parser.parse_args() + + np.random.seed(args.seed) + temp_choices = [round(np.random.uniform(0, 1), 2) for x in range(0, 3)] + if args.choices == None: + args.choices = temp_choices + + print("SEED: ", args.seed) + symbols = args.literals + + if len(symbols) < args.disjunctions[1]: + raise ValueError("Number of literals " + + "is not sufficient to generate a formula " + + "with maximum disjunctions") + + if args.solve == None: + cnf = generate_cnf(symbols, args.conjunctions, args.disjunctions[1], min_disj=args.disjunctions[0]) + elif args.solve != None and args.clause_lengths != None: + cnf = parse_formula(args.solve, args.clause_lengths) + else: + raise ValueError('Formula or disjunctions lengths not provided') + print(cnf) + + symbols = get_ordered_symbols(cnf) + ans = walksat(symbols, cnf, args.choices, p = args.uninformed_probability) + + print(f"Formula: {cnf}") + print(f"Solution of symbols {symbols} is {ans}") \ No newline at end of file