Added proposed multi_index_select version #34

Open
wants to merge 6 commits into master

Conversation

bcsaldias
Contributor

This PR also fixes the default dimension names for the nonzero method.

@bcsaldias
Contributor Author

Ouch! I accidentally sent two different functions in the same PR.

I have a proposal that I'd love to get your feedback on. It's called multi_index_select(tensor, dims, indices).

Basically, dims refers to the names of the dimensions to be indexed by indices; each entry of dims must be the name of a dimension of tensor (no particular order required).

@srush
Contributor

srush commented Feb 2, 2019

This looks great. The isalnum stuff could go right in if you wanted to send a separate PR.

One minor thing: I want to move all the asserts to RuntimeErrors. That was my fault, but let's do that for future PRs.
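For illustration, a minimal sketch of the assert-to-RuntimeError pattern being asked for here (check_dim_name is a hypothetical helper, not code from this repo):

def check_dim_name(dim):
    # Validate a dimension name by raising instead of asserting, so the
    # check survives `python -O` and gives a clearer error message.
    if not isinstance(dim, str) or not dim.isalnum():
        raise RuntimeError("Dimension names must be alphanumeric strings, got %r" % (dim,))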

A couple requests on multi_index:

  • Can you add a longer comment for multi_index_select? It's a bit complicated. In particular, I feel like it should take an index_dim argument, right? Also, unlike other functions, the order of dims matters here, right?

  • Ideally, this would just work with the regular index_select. Is there a reason it needs to be a different function? Maybe it just gets called if you pass more than one dim?

@bcsaldias
Contributor Author

  1. Yes. I need to add documentation there.

I see multi_index_select as a flexible function, where you can say "I have this index, say [[3, 1], [2, 2]], where each element's dimensions correspond to 'dimc' and 'dimb' respectively, and I want to index tensor T whose dimensions are (dima, dimb, dimc, dimd)." In this scenario, we would get a tensor of dimensions (2, dima, dimd).

Ex:
tensor.shape -> OrderedDict([("dima", 5), ("dimb", 4), ("dimc", 3), ("dimd", 7)])
indices = [[3, 1], [2, 2]]
dims = ('dimc', 'dimb')
elements = multi_index_select(tensor, dims, indices)
elements.shape -> OrderedDict([("elementsdim", 2), ("dima", 5), ("dimd", 7)])

Is it way too flexible? I feel like this approach takes great advantage of the dimension abstraction proposed by namedtensor.

I could definitely develop other approaches, but wanted to propose this idea first.
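For concreteness, a rough sketch of this behaviour on plain torch tensors (multi_index_select_sketch and its plain-tensor signature are illustrative assumptions, not the actual namedtensor implementation):

import torch

def multi_index_select_sketch(tensor, names, dims, indices):
    # names: dimension names aligned with tensor.shape
    # dims: names of the dimensions being indexed
    # indices: one row per selected element, one position per name in dims
    axes = [names.index(d) for d in dims]
    rows = []
    for row in indices:
        # Start from a full slice and pin the chosen axes to this row's positions.
        slicer = [slice(None)] * tensor.dim()
        for axis, pos in zip(axes, row):
            slicer[axis] = pos
        rows.append(tensor[tuple(slicer)])
    # Stack the selected slices along a new leading "elementsdim" axis.
    return torch.stack(rows, dim=0)

# Shapes from the example above (index values kept in range):
t = torch.randn(5, 4, 3, 7)  # (dima, dimb, dimc, dimd)
out = multi_index_select_sketch(t, ("dima", "dimb", "dimc", "dimd"),
                                ("dimc", "dimb"), [[2, 1], [1, 2]])
print(out.shape)  # torch.Size([2, 5, 7]) -> (elementsdim, dima, dimd)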

@bcsaldias changed the title from "Added 'isalnum' restriction to dimension names" to "Added proposed multi_index_select version" on Feb 2, 2019
@srush
Contributor

srush commented Feb 3, 2019

Yes, I mostly agree with this; I think we are on the same page. However, I would 1) like indices to be a named tensor too, not a list, 2) have an argument that says which dimension holds the indices, and 3) not make it a different function from index_select.

tensor.shape -> OrderedDict([("dima", 5), ("dimb", 4), ("dimc", 3), ("dimd", 7)])
indices = ntorch.tensor([[3, 1], [2, 2]], names=("elements", "dims"))
dims = ('dimc', 'dimb')
elements = tensor.index_select(dims, indices, index_dim="dims")
elements.shape -> OrderedDict([("elements", 2), ("dima", 5), ("dimd", 7)])

where the standard index_select is a special case

tensor.shape -> OrderedDict([("dima", 5), ("dimb", 4), ("dimc", 3), ("dimd", 7)])
indices = ntorch.tensor([3, 2], names=("elements",))
dims = ('dimc',)
elements = tensor.index_select(dims, indices)
elements.shape -> OrderedDict([("elements", 2), ("dima", 5), ("dimd", 7), ("dimb", 4)])
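A rough sketch of that dispatch idea on plain torch tensors (the names, signature, and the positional index_dim here are illustrative assumptions; in the named version index_dim would be a dimension name such as "dims"):

import torch

def index_select_sketch(tensor, names, dims, indices, index_dim=1):
    # names: dimension names aligned with tensor.shape
    # dims: one name (standard case) or several names (multi-index case)
    # indices: 1-D positions, or a 2-D tensor with one axis running over dims
    # index_dim: which axis of `indices` runs over the entries of `dims`
    if isinstance(dims, str) or len(dims) == 1:
        dim = dims if isinstance(dims, str) else dims[0]
        # Standard case: defer to the usual index_select.
        return torch.index_select(tensor, names.index(dim), indices.reshape(-1))
    if index_dim == 0:
        indices = indices.t()  # make the dims axis the last axis
    axes = [names.index(d) for d in dims]
    out = []
    for row in indices.tolist():
        slicer = [slice(None)] * tensor.dim()
        for axis, pos in zip(axes, row):
            slicer[axis] = pos
        out.append(tensor[tuple(slicer)])
    return torch.stack(out, dim=0)

# Both calls from the examples above, with in-range index values:
t = torch.randn(5, 4, 3, 7)  # (dima, dimb, dimc, dimd)
names = ("dima", "dimb", "dimc", "dimd")
multi = index_select_sketch(t, names, ("dimc", "dimb"),
                            torch.tensor([[2, 1], [1, 2]]), index_dim=1)
single = index_select_sketch(t, names, ("dimc",), torch.tensor([2, 1]))
print(multi.shape, single.shape)  # torch.Size([2, 5, 7]) torch.Size([5, 4, 2, 7])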

@bcsaldias
Contributor Author

Yes, that makes total sense.

  1. Yes, indices is already a named tensor (I failed to mention that in my last comment).

  2. I see. I was missing that! We cannot assume it's the first dimension.
    For example, in this case it would be "elements", which is dim 0:
    indices = ntorch.tensor([[3, 1], [2, 2]], names=("elements", "dims"))
    However, in this other case it would be "elements", which is dim 1:
    indices = ntorch.tensor([[3, 1], [2, 2]], names=("dims", "elements"))

  3. Right. It doesn't need to be a different function.

Thanks for the feedback.
