
contract_expression with batch dimension #151

Closed
PgLoLo opened this issue Jul 21, 2020 · 2 comments

Comments

@PgLoLo

PgLoLo commented Jul 21, 2020

Is it possible to construct a contract_expression with optional batch dimensions? Consider this example:
contract('...i,i->...', a, b)
The expression above can be used with arrays a of different shapes, i.e. with any number of leading batch dimensions. But if I construct a contract_expression in the following way:
contract_expression('...i,i->...', (1024, 16), (16,))
it only accepts a 2-dimensional tensor as its first argument.

@dgasmith
Owner

I believe this is related to the batch discussion in #95.

@jcmgray
Collaborator

jcmgray commented Jul 21, 2020

You can supply arrays with different shapes, but not a different number of dimensions (ndim), to a ContractExpression, and they will be evaluated with the same contraction path, i.e. you could vary 1024, 16, etc. Alternatively, you could generate the path for one set of shapes and use it to build many different contract expressions (though it might no longer be the best path for other sets of shapes!).
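Both options can be sketched roughly as follows, assuming current opt_einsum behaviour: a single expression evaluated on arrays of the same rank but different sizes, and a path computed once with oe.contract_path and passed via optimize= to build expressions for other shapes. For a two-operand equation like this one the path is trivial; reusing it pays off mainly for larger tensor networks.

```python
import numpy as np
import opt_einsum as oe

eq = "...i,i->..."
b = np.random.randn(16)

# 1) Same ndim, different sizes: one expression handles all of these.
expr = oe.contract_expression(eq, (1024, 16), (16,))
print(expr(np.random.randn(1024, 16), b).shape)                # (1024,)
print(expr(np.random.randn(100, 16), b).shape)                 # (100,)
print(expr(np.random.randn(100, 8), np.random.randn(8)).shape) # (100,)

# 2) Compute a path once, then reuse it to build expressions for other shapes.
#    The path is a list of operand-index tuples, so it does not depend on ranks.
path, _ = oe.contract_path(eq, np.random.randn(1024, 16), b)
expr3d = oe.contract_expression(eq, (7, 1024, 16), (16,), optimize=path)
print(expr3d(np.random.randn(7, 1024, 16), b).shape)           # (7, 1024)
```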

opt_einsum also works nicely with jax (which has an efficient version of numpy.vectorize):

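```python
import jax.numpy
import numpy as np
import opt_einsum as oe

eq = 'a,a->'
expr = oe.contract_expression(eq, (16,), (16,))
# vectorize maps the core signature over any leading batch dimensions
vexpr = jax.numpy.vectorize(expr, signature='(a),(a)->()')

x = np.random.randn(7, 32, 1024, 16)
y = np.random.randn(16)
vexpr(x, y).shape
# (7, 32, 1024)
```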
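If JAX is not available, a NumPy-only alternative for this particular kind of equation, where all batch dimensions lead the first operand, is to fold them into a single axis, call one fixed-rank expression, and unfold the result. This is a sketch under that assumption rather than an opt_einsum feature; flat_contract is a hypothetical name.

```python
import numpy as np
import opt_einsum as oe

# A rank-2 expression: one folded batch axis `b` plus the contracted axis `a`.
expr = oe.contract_expression("ba,a->b", (1024, 16), (16,))

def flat_contract(x, y):
    """Evaluate '...a,a->...' for any number of leading batch dims of x."""
    batch_shape = x.shape[:-1]
    # Fold all batch dims into one axis; the expression accepts any size there.
    out = expr(x.reshape(-1, x.shape[-1]), y)
    # Restore the original batch dimensions.
    return out.reshape(batch_shape)

x = np.random.randn(7, 32, 1024, 16)
y = np.random.randn(16)
print(flat_contract(x, y).shape)  # (7, 32, 1024)
```

Unlike vectorize, this only works when the batch axes are contiguous and leading, but it keeps a single precompiled path for every batch rank.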

@dgasmith dgasmith closed this as completed Nov 7, 2021

3 participants