Skip to content

Commit

Permalink
added dynamic programming path (#102)
Browse files Browse the repository at this point in the history
* added dynamic programming path

* removed duplicate import

* code cleanup and some micro optimisations

* added dynamic programming to list of optimisers to be tested

* bugfix (for dp) for summation indices occurring only in one input

* added decomposition into disconnected subgraphs for dp path

* bugfix for outer products in dp path; improved performance for outer products

* added tests for dp path

* dp performance improvement: converted indices to integers

* dp performance improvement: using now bitmaps (integers) as tensor sets

* added helper function `path_cost` to compute cost of contraction paths

* added cost limit for dp and random-greedy-128 based estimate for it

* moved dp parameter cost_limit to constructor

* dp: now using successively increasing cost cap

* removed no longer needed function path_cost

* edited docstrings
  • Loading branch information
mrader1248 authored and dgasmith committed Sep 5, 2019
1 parent f0ec70f commit 4c123e0
Show file tree
Hide file tree
Showing 3 changed files with 332 additions and 3 deletions.
325 changes: 324 additions & 1 deletion opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@

from . import helpers

__all__ = ["optimal", "BranchBound", "branch", "greedy", "auto", "get_path_fn"]
__all__ = [
"optimal", "BranchBound", "branch", "greedy", "auto", "get_path_fn",
"DynamicProgrammingOptimizer", "dynamic_programming"
]


_UNLIMITED_MEM = {-1, None, float('inf')}
Expand Down Expand Up @@ -652,6 +655,324 @@ def greedy(inputs, output, size_dict, memory_limit=None, choose_fn=None, cost_fn
return ssa_to_linear(ssa_path)


def _tree_to_sequence(c):
"""
Converts a contraction tree to a contraction path as it has to be
returned by path optimizers. A contraction tree can either be an int
(=no contraction) or a tuple containing the terms to be contracted. An
arbitrary number (>= 1) of terms can be contracted at once. Note that
contractions are commutative, e.g. (j, k, l) = (k, l, j). Note that in
general, solutions are not unique.
Parameters
----------
c : tuple or int
Contraction tree
Returns
-------
path : list[set[int]]
Contraction path
Examples
--------
>>> _tree_to_sequence(((1,2),(0,(4,5,3))))
[(1, 2), (1, 2, 3), (0, 2), (0, 1)]
"""

# ((1,2),(0,(4,5,3))) --> [(1, 2), (1, 2, 3), (0, 2), (0, 1)]
#
# 0 0 0 (1,2) --> ((1,2),(0,(3,4,5)))
# 1 3 (1,2) --> (0,(3,4,5))
# 2 --> 4 --> (3,4,5)
# 3 5
# 4 (1,2)
# 5
#
# this function iterates through the table shown above from right to left;

if type(c) == int:
return []

c = [c] # list of remaining contractions (lower part of columns shown above)
t = [] # list of elementary tensors (upper part of colums)
s = [] # resulting contraction sequence

while len(c) > 0:
j = c.pop(-1)
s.insert(0, tuple())

for i in sorted([i for i in j if type(i) == int]):
s[0] += (sum(1 for q in t if q < i),)
t.insert(s[0][-1], i)

for i in [i for i in j if type(i) != int]:
s[0] += (len(t) + len(c),)
c.append(i)

return s


def _find_disconnected_subgraphs(inputs, output):
"""
Finds disconnected subgraphs in the given list of inputs. Inputs are
connected if they share summation indices. Note: Disconnected subgraphs
can be contracted independently before forming outer products.
Parameters
----------
inputs : list[set]
List of sets that represent the lhs side of the einsum subscript
output : set
Set that represents the rhs side of the overall einsum subscript
Returns
-------
subgraphs : list[set[int]]
List containing sets of indices for each subgraph
Examples
--------
>>> _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("bd"))
[{0, 2}, {1}]
>>> _find_disconnected_subgraphs([set("ab"), set("c"), set("ad")], set("abd"))
[{0}, {1}, {2}]
"""

subgraphs = []
unused_inputs = set(range(len(inputs)))

i_sum = set.union(*inputs) - output # all summation indices

while len(unused_inputs) > 0:
g = set()
q = [unused_inputs.pop()]
while len(q) > 0:
j = q.pop()
g.add(j)
i_tmp = i_sum & inputs[j]
n = {k for k in unused_inputs if len(i_tmp & inputs[k]) > 0}
q.extend(n)
unused_inputs.difference_update(n)

subgraphs.append(g)

return subgraphs


def _bitmapset_indices(s):
"""
Returns a generator object allowing to iterate over the elements contained
in a bitmap set.
Parameters
----------
s : int
The bitmap set to iterate over
Returns
-------
path : generator
Generator object to iterate over the elements in s
Examples
--------
>>> type(_bitmapset_indices(0b1001011))
generator
>>> list(_bitmapset_indices(0b1001011))
[0, 1, 3, 6]
"""
j = 0
while s != 0:
if s & 1 != 0:
yield j
s >>= 1
j += 1


class DynamicProgrammingOptimizer(PathOptimizer):
    """
    Finds the optimal path of pairwise contractions without intermediate outer
    products based on a dynamic programming approach presented in
    Phys. Rev. E 90, 033315 (2014) (the corresponding preprint is publicly
    available at https://arxiv.org/abs/1304.6112). This method is especially
    well-suited in the area of tensor network states, where it usually
    outperforms all the other optimization strategies.

    This algorithm shows exponential scaling with the number of inputs
    in the worst case scenario (see example below). If the graph to be
    contracted consists of disconnected subgraphs, the algorithm scales
    linearly in the number of disconnected subgraphs and only exponentially
    with the number of inputs per subgraph.
    """

    def __call__(self, inputs, output, size_dict, memory_limit=None):
        """
        Parameters
        ----------
        inputs : list
            List of sets that represent the lhs side of the einsum subscript
        output : set
            Set that represents the rhs side of the overall einsum subscript
        size_dict : dictionary
            Dictionary of index sizes
        memory_limit : int
            The maximum number of elements in a temporary array

        Returns
        -------
        path : list
            The contraction order (a list of tuples of ints).

        Raises
        ------
        RuntimeError
            If no contraction of a subgraph satisfies ``memory_limit``.

        Examples
        --------
        >>> n_in = 3  # exponential scaling
        >>> n_out = 2 # linear scaling
        >>> s = dict()
        >>> i_all = []
        >>> for _ in range(n_out):
        >>>     i = [set() for _ in range(n_in)]
        >>>     for j in range(n_in):
        >>>         for k in range(j+1, n_in):
        >>>             c = oe.get_symbol(len(s))
        >>>             i[j].add(c)
        >>>             i[k].add(c)
        >>>             s[c] = 2
        >>>     i_all.extend(i)
        >>> o = DynamicProgrammingOptimizer()
        >>> o(i_all, set(), s)
        [(1, 2), (0, 4), (1, 2), (0, 2), (0, 1)]
        """

        # convert all indices to integers (makes set operations ~10 % faster)
        symbol2int = {c: j for j, c in enumerate(set.union(*inputs) | output)}
        inputs = [set(symbol2int[c] for c in i) for i in inputs]
        output = set(symbol2int[c] for c in output)
        size_dict = {symbol2int[c]: v for c, v in size_dict.items() if c in symbol2int}
        size_dict = [size_dict[j] for j in range(len(size_dict))]

        # all summation indices occurring exactly in one input:
        i_single = set(
            c for c in set.union(*inputs) - output
            if sum(1 for i in inputs if c in i) == 1
        )

        # contraction expressions for all inputs that have already been
        # reduced to scalars:
        inputs_done = [
            (j,) for j, i in enumerate(inputs)
            if len(i - i_single) == 0
        ]

        # remaining input index sets and corresponding contraction expressions;
        # indices from i_single are removed and if a single-tensor contraction
        # is performed, the contraction expression is (j,) instead of j;
        remaining = [
            (i - i_single, j if i.isdisjoint(i_single) else (j,))
            for j, i in enumerate(inputs)
            if len(i - i_single) > 0
        ]
        if remaining:
            inputs, inputs_contractions = zip(*remaining)
        else:
            # every input was reduced to a scalar by a single-tensor sum;
            # unpacking `zip(*[])` would raise, so handle this explicitly
            inputs, inputs_contractions = (), ()

        # a list of all necessary contraction expressions for each of the
        # disconnected subgraphs and their size
        subgraph_contractions = inputs_done
        subgraph_contractions_size = [1] * len(inputs_done)

        # the bitmap set of all tensors is computed as it is needed to
        # compute set differences: s1 - s2 transforms into
        # s1 & (all_tensors ^ s2)
        all_tensors = (1 << len(inputs)) - 1

        # factor by which the cost cap grows per iteration; it has to be
        # larger than 1, otherwise the cap would never grow if the smallest
        # dimension is 1 and the search below would not terminate
        cost_increment = max(min(size_dict), 2) if size_dict else 2

        subgraphs = _find_disconnected_subgraphs(inputs, output) if inputs else []

        for g in subgraphs:

            # dynamic programming approach to compute x[n] for subgraph g;
            # x[n][set of n tensors] = (indices, cost, contraction)
            # the set of n tensors is represented by a bitmap: if bit j is 1,
            # tensor j is in the set, e.g. 0b100101 = {0,2,5}; set unions
            # (intersections) can then be computed by bitwise or (and);
            x = [None] * 2 + [dict() for _ in range(len(g) - 1)]
            x[1] = {1 << j: (inputs[j], 0, inputs_contractions[j]) for j in g}
            n_tensors = len(x[1])

            # convert set of tensors g to a bitmap set:
            g = functools.reduce(lambda a, b: a | b, (1 << j for j in g))

            # all indices occurring in this subgraph
            subgraph_inds = set.union(*(inputs[j] for j in _bitmapset_indices(g)))

            # try to find contraction with cost <= cost_cap and increase
            # cost_cap successively if no such contraction is found;
            # this is a major performance improvement; start with product of
            # output index dimensions as initial cost_cap
            cost_cap = helpers.compute_size_by_dict(subgraph_inds & output, size_dict)

            # upper bound for the total cost of any contraction of this
            # subgraph: fewer than n_tensors pairwise contractions, each of
            # which costs at most the product of all index dimensions
            # (dimensions below 1 are clamped to keep this a true bound);
            # if nothing is found at this cap, nothing can be found at all
            max_cost = n_tensors
            for ind in subgraph_inds:
                max_cost *= max(size_dict[ind], 1)

            while len(x[-1]) == 0:
                for n in range(2, n_tensors + 1):
                    xn = x[n]

                    # try to combine solutions from x[m] and x[n-m]
                    for m in range(1, n // 2 + 1):
                        for s1, (i1, cost1, cntrct1) in x[m].items():
                            for s2, (i2, cost2, cntrct2) in x[n - m].items():

                                # only if s1 and s2 are disjoint
                                if s1 & s2 != 0:
                                    continue

                                # avoid e.g. s1={0}, s2={1} and s1={1}, s2={0}
                                if m == n - m and s1 >= s2:
                                    continue

                                i1_cut_i2_wo_output = (i1 & i2) - output

                                # ignore outer products:
                                if len(i1_cut_i2_wo_output) == 0:
                                    continue

                                i1_union_i2 = i1 | i2
                                cost = cost1 + cost2 + helpers.compute_size_by_dict(i1_union_i2, size_dict)
                                if cost > cost_cap:
                                    continue

                                # only keep the cheapest way to contract s
                                s = s1 | s2
                                if s in xn and cost >= xn[s][1]:
                                    continue

                                # set of remaining tensors (=g-s)
                                r = g & (all_tensors ^ s)

                                # indices of remaining tensors:
                                i_r = (set.union(*(inputs[j] for j in _bitmapset_indices(r)))
                                       if r != 0 else set())

                                # contraction indices:
                                i_cntrct = i1_cut_i2_wo_output - i_r

                                # indices and size of the resulting tensor
                                i = i1_union_i2 - i_cntrct
                                mem = helpers.compute_size_by_dict(i, size_dict)
                                if memory_limit is None or mem <= memory_limit:
                                    xn[s] = (i, cost, (cntrct1, cntrct2))

                # no candidate was rejected because of its cost in this pass,
                # hence only the memory limit can have prevented a solution
                if len(x[-1]) == 0 and cost_cap >= max_cost:
                    raise RuntimeError("No contraction found for given `memory_limit`.")

                # increase cost cap for next iteration (guaranteed to grow):
                cost_cap = cost_increment * max(cost_cap, 1)

            i, cost, contraction = next(iter(x[-1].values()))
            subgraph_contractions.append(contraction)
            subgraph_contractions_size.append(helpers.compute_size_by_dict(i, size_dict))

        # sort the subgraph contractions by the size of the subgraphs in
        # ascending order (will give the cheapest contractions); note that
        # outer products should be performed pairwise (to use BLAS functions)
        subgraph_contractions = [
            subgraph_contractions[j]
            for j in np.argsort(subgraph_contractions_size)
        ]

        # build the final contraction tree
        tree = functools.reduce(lambda a, b: (a, b), subgraph_contractions)

        return _tree_to_sequence(tree)


def dynamic_programming(inputs, output, size_dict, memory_limit=None):
    """
    Functional entry point to the dynamic programming path: creates a fresh
    ``DynamicProgrammingOptimizer`` and returns the result of calling it with
    the given arguments.
    """
    return DynamicProgrammingOptimizer()(inputs, output, size_dict, memory_limit)


_AUTO_CHOICES = {}
for i in range(1, 5):
_AUTO_CHOICES[i] = optimal
Expand Down Expand Up @@ -680,6 +1001,8 @@ def auto(inputs, output, size_dict, memory_limit=None):
'greedy': greedy,
'eager': greedy,
'opportunistic': greedy,
'dp': dynamic_programming,
'dynamic-programming': dynamic_programming
}


Expand Down
2 changes: 1 addition & 1 deletion opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@
]


all_optimizers = ['optimal', 'branch-all', 'branch-2', 'branch-1', 'greedy', 'random-greedy']
all_optimizers = ['optimal', 'branch-all', 'branch-2', 'branch-1', 'greedy', 'random-greedy', 'dp']


@pytest.mark.parametrize("string", tests)
Expand Down
8 changes: 7 additions & 1 deletion opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,29 @@
}),
}

# note that these tests have no unique solution due to the chosen dimensions
path_edge_tests = [
['greedy', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
['branch-all', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
['branch-2', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
['optimal', 'eb,cb,fb->cef', ((0, 2), (0, 1))],
['dp', 'eb,cb,fb->cef', ((1, 2), (0, 1))],
['greedy', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
['branch-all', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
['branch-2', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
['optimal', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
['optimal', 'dd,fb,be,cdb->cef', ((0, 3), (0, 1), (0, 1))],
['dp', 'dd,fb,be,cdb->cef', ((0, 3), (0, 2), (0, 1))],
['greedy', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
['branch-all', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
['branch-2', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
['optimal', 'bca,cdb,dbf,afc->', ((1, 2), (0, 2), (0, 1))],
['dp', 'bca,cdb,dbf,afc->', ((1, 2), (1, 2), (0, 1))],
['greedy', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 1), (0, 1))],
['branch-all', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
['branch-2', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
['optimal', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
['dp', 'dcc,fce,ea,dbf->ab', ((1, 2), (0, 2), (0, 1))],
]


Expand Down Expand Up @@ -191,7 +197,7 @@ def test_greedy_edge_cases():
assert check_path(path, [(0, 1), (0, 2), (0, 1)])


@pytest.mark.parametrize("optimize", ['greedy', 'branch-2', 'branch-all', 'optimal'])
@pytest.mark.parametrize("optimize", ['greedy', 'branch-2', 'branch-all', 'optimal', 'dp'])
def test_can_optimize_outer_products(optimize):
a, b, c = [np.random.randn(10, 10) for _ in range(3)]
d = np.random.randn(10, 2)
Expand Down

0 comments on commit 4c123e0

Please sign in to comment.