A failure of raw FLOP count optimization #103

Open
dgasmith opened this issue Aug 30, 2019 · 4 comments

@dgasmith (Owner)

A nice example of where optimal performs worse than greedy in practice: optimal (9 s), greedy (1.1 s), optimal with a 10**7 memory limit (0.8 s).

import numpy as np

x = np.random.randn(100, 300, 75, 10)
w = np.random.randn(100, 300, 3)
At = np.random.randn(8, 10)
G = np.random.randn(10, 3)
Bt = np.random.randn(10, 10)

np.einsum("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal")
np.einsum("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="greedy")
np.einsum("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize=("optimal", 10**7))

Copied my response here:

import opt_einsum as oe
oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal")[1]

  Complete contraction:  mn,nr,fcr,nh,octh->oftm
         Naive scaling:  8
     Optimized scaling:  5
      Naive FLOP count:  2.700e+12
  Optimized FLOP count:  5.072e+9
   Theoretical speedup:  532.355
  Largest intermediate:  2.250e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   4           GEMM            fcr,nr->fcn                  mn,nh,octh,fcn->oftm
   5           GEMM          octh,nh->octn                     mn,fcn,octn->oftm
   5              0         octn,fcn->otnf                         mn,otnf->oftm
   5           TDOT          otnf,mn->oftm                            oftm->oftm
oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="greedy")[1]

  Complete contraction:  mn,nr,fcr,nh,octh->oftm
         Naive scaling:  8
     Optimized scaling:  6
      Naive FLOP count:  2.700e+12
  Optimized FLOP count:  1.412e+10
   Theoretical speedup:  191.286
  Largest intermediate:  2.250e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   6           TDOT        octh,fcr->othfr                  mn,nr,nh,othfr->oftm
   6           TDOT        othfr,nh->otfrn                     mn,nr,otfrn->oftm
   5              0         otfrn,nr->otfn                         mn,otfn->oftm
   5           GEMM          otfn,mn->oftm                            oftm->oftm
oe.contract_path("mn,nr,fcr,nh,octh->oftm", At, G, w, Bt, x, optimize="optimal", memory_limit=int(10**7))[1]

  Complete contraction:  mn,nr,fcr,nh,octh->oftm
         Naive scaling:  8
     Optimized scaling:  6
      Naive FLOP count:  2.700e+12
  Optimized FLOP count:  3.601e+10
   Theoretical speedup:  74.970
  Largest intermediate:  6.000e+6 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3              0             nr,mn->nrm                 fcr,nh,octh,nrm->oftm
   4           GEMM            nrm,nh->rmh                    fcr,octh,rmh->oftm
   5           GEMM          rmh,fcr->mhfc                       octh,mhfc->oftm
   6           TDOT        mhfc,octh->oftm                            oftm->oftm

Any contraction marked BLAS=0 falls back to np.einsum (for the einsum backend).

Using optimal without a memory limit, the most expensive contraction by far (the one involving the oc indices) falls back to an einsum operation. For every other contraction path the most expensive operations are handled by GEMM (best; no tensor copies) or TDOT (requires tensor copies).

We have looked at adding additional logic along these lines (a rough sketch follows the list):

  • BLAS: cost = cost / ncores
  • TDOT: cost = cost / ncores + copy cost
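
An illustrative sketch of that adjustment (not part of the current code; the core count and copy penalty are placeholders):

import multiprocessing

NCORES = multiprocessing.cpu_count()

def adjusted_cost(flops, kernel, copy_elements=0):
    # BLAS-backed kernels parallelize across cores; TDOT additionally pays
    # for the explicit transposes/copies it needs before calling GEMM.
    if kernel in ("GEMM", "TDOT"):
        cost = flops / NCORES
        if kernel == "TDOT":
            cost += copy_elements
        return cost
    # BLAS=0 contractions fall back to plain einsum and keep the raw count.
    return flops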

This gets a bit tricky as the heuristics become very blurry for several reasons:

  • Is there FMA?
  • Do we care about Strassen for GEMM w/MKL?
  • What is your RAM throughput? Memory bandwidth is the bottleneck for transposes.
  • Which index is copied in the tensor? Left indices will be much slower than right indices for C-contiguous arrays.
  • For other backends like JAX/Tensorflow/PyTorch these heuristics will fail (e.g., V100 tensor cores).

Our original pass at smarter heuristics for einsum was not met with much enthusiasm, as the use cases for einsum are incredibly diverse. If we could define a 99% case we could likely optimize for it, but so far we haven't been successful in describing the bounds of those cases.

Original issue here:
numpy/numpy#14332

@shoyer commented Nov 27, 2019

I wonder if the right way to handle this for now is to add a low level API for indicating costs, either in terms of a multiple of the original cost or with custom logic for particular shapes.

Google TPUs also have dedicated matrix multiplication units (like NVidia's tensor cores), and I suspect we will see even more hardware like this in the future. The logic gets pretty specialized to particular platforms, so I think it would be difficult to handle all the options ahead of time.

@dgasmith (Owner, Author)

Would it be better to allow a scaling factor for different operation types? I think only the calling function would know whether the operation could fit specific TPU requirements.

costs = {"EINSUM": 1.25,
         "GEMM": 0.5 / num_threads,
         "TDOT: 0.5 / num_threads,
         "TDOT_TRANSPOSE": 0.75}
oe.contract_path(einsum_string, *views, costs=costs)

This would allow the cost function to be simple: total_cost += costs.get(ctype, 1.0) * contraction_cost
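
Summed over a candidate path it could look like this (purely a sketch; "contractions" as a list of (ctype, contraction_cost) pairs is assumed here, it is not an existing structure):

def weighted_path_cost(contractions, costs):
    # Accumulate each contraction's raw cost, scaled by its kernel type.
    total_cost = 0.0
    for ctype, contraction_cost in contractions:
        total_cost += costs.get(ctype, 1.0) * contraction_cost
    return total_cost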

@jcmgray (Collaborator) commented Nov 27, 2019

My preference would be to make the default 'cost' calculation as simple as possible, i.e. just the number of operations:

sum(
    compute_size_by_dict(indices_involved, size_dict)
    for indices_involved in contractions
)

and thus, at least as a baseline, ignore the modifiers in the current flop_count for whether a contraction is an inner product or how many terms are involved. This is the cost that the 'dp' optimizer minimizes and is used elsewhere in the tensor network literature. Also, in practice I find it is the best estimator!
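
For the example in this issue that baseline works out as follows (the index sizes and intermediates are read off from the array shapes and the optimal path above, just to illustrate):

from opt_einsum.helpers import compute_size_by_dict

# Index sizes taken from the arrays in the original example.
size_dict = {"m": 8, "n": 10, "r": 3, "f": 100, "c": 300, "o": 100, "t": 75, "h": 10}

# Intermediates produced along the optimal path shown above.
intermediates = ["fcn", "octn", "otnf", "oftm"]

baseline_cost = sum(
    compute_size_by_dict(indices_involved, size_dict)
    for indices_involved in intermediates
)
# 300_000 + 22_500_000 + 7_500_000 + 6_000_000 = 36_300_000 elements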

Another motivation is that, e.g. for complex data types, addition is 2 FLOPs whilst multiplication is 6, and there is likely other instruction-set-optimized stuff going on, so just doubling the cost for an inner product isn't necessarily natural!

Maybe one could then have a separate and more advanced FLOP/cost estimator that takes into account the nature of each contraction, and other customizable factors like you mention. This would only really help to understand the cost of a contraction once it is found; otherwise it might be a lot of work to support a custom cost in all the current path finders.

@dgasmith (Owner, Author)

It's a good point about inner products these days. When this project was first started and the FLOP code was written, FMA was pretty bleeding edge and not generally available. The first Skylake Intel CPUs came out that year (AMD had it in 2013, but it wasn't very common), and FMA/AVX has since propagated to most hardware, so that choice is now fairly wrong.

Long way of saying that this does need an overhaul. The other question to answer first is "does this matter?" We can think about several regimes:

  • Small contractions where it may be better to pick a worse scaling algorithm that uses optimized routines.
  • Medium contractions where it may be worth organizing contractions for optimized routines over lowest FLOP count at the same scaling.
  • Large contractions where advanced algorithms play an enhanced role (Strassen).

All of these use cases require the scaling to be injected into the path finding itself, and the logic overhead would slow the algorithms down quite a bit as well. It may be worth hacking in something exploratory to check whether the above matters; this could be as simple as being able to supply your own FLOP cost function (see the sketch below).
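
Purely as a hypothetical interface (no such keyword exists today), an injectable cost function might look like:

def my_cost(indices_involved, size_dict, kernel):
    # Favor paths whose expensive steps can be dispatched to GEMM.
    size = 1
    for idx in indices_involved:
        size *= size_dict[idx]
    return size * (0.5 if kernel == "GEMM" else 1.0)

# oe.contract_path(eq, *arrays, cost_fn=my_cost)  # hypothetical keyword, for illustration only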

If we rename FLOPs to OPs, I think we become less clear, as an OP could refer to a SIMD (or similar) call. Is there a better way to phrase the current "FLOP" categories?
