diff --git a/api_matop.go b/api_matop.go index 75c2452..bf412ea 100644 --- a/api_matop.go +++ b/api_matop.go @@ -127,13 +127,20 @@ func Diag(t Tensor) (retVal Tensor, err error) { // ByIndices allows for selection of value of `a` byt the indices listed in the `indices` tensor. // The `indices` tensor has to be a vector-like tensor of ints. func ByIndices(a, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if axis >= a.Shape().Dims() { + return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) + } if sbi, ok := a.Engine().(ByIndiceser); ok { return sbi.SelectByIndices(a, indices, axis, opts...) } return nil, errors.Errorf("Unable to select by indices. Egnine %T does not support that.", a.Engine()) } +// ByIndicesB is the backpropagation of ByIndices. func ByIndicesB(a, b, indices Tensor, axis int, opts ...FuncOpt) (retVal Tensor, err error) { + if axis >= a.Shape().Dims() { + return nil, errors.Errorf("Cannot select by indices on axis %d. Input only has %d dims", axis, a.Shape().Dims()) + } if sbi, ok := a.Engine().(ByIndiceser); ok { return sbi.SelectByIndicesB(a, b, indices, axis, opts...) } diff --git a/defaultengine_selbyidx.go b/defaultengine_selbyidx.go index ab7f4f1..cdcc318 100644 --- a/defaultengine_selbyidx.go +++ b/defaultengine_selbyidx.go @@ -86,8 +86,13 @@ func (e StdEng) selectByIdx(axis int, indices []int, typ reflect.Type, dataA, da dstCoord := make([]int, apRet.shape.Dims()) if isInnermost { - prevStride := apA.strides[axis-1] - retPrevStride := apRet.strides[axis-1] + prevAxis := axis - 1 + if prevAxis < 0 { + // this may be the case if input is a vector + prevAxis = 0 + } + prevStride := apA.strides[prevAxis] + retPrevStride := apRet.strides[prevAxis] for i, idx := range indices { srcCoord[axis] = idx dstCoord[axis] = i @@ -194,8 +199,13 @@ func (e StdEng) selectByIndicesB(axis int, indices []int, typ reflect.Type, data srcCoord := make([]int, apRet.shape.Dims()) if isInnermost { - retPrevStride := apB.strides[axis-1] - prevStride := apRet.strides[axis-1] + prevAxis := axis - 1 + if prevAxis < 0 { + // this may be the case if input is a vector + prevAxis = 0 + } + retPrevStride := apB.strides[prevAxis] + prevStride := apRet.strides[prevAxis] for i, idx := range indices { dstCoord[axis] = idx srcCoord[axis] = i diff --git a/dense_selbyidx_test.go b/dense_selbyidx_test.go index ca6b34f..86369be 100644 --- a/dense_selbyidx_test.go +++ b/dense_selbyidx_test.go @@ -6,121 +6,108 @@ import ( "github.com/stretchr/testify/assert" ) -func TestDense_SelectByIndices(t *testing.T) { - assert := assert.New(t) - - a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4)) - indices := New(WithBacking([]int{1, 1})) - - e := StdEng{} - - a1, err := e.SelectByIndices(a, indices, 1) - if err != nil { - t.Errorf("%v", err) - } - correct1 := []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23} - assert.Equal(correct1, a1.Data()) - - a0, err := e.SelectByIndices(a, indices, 0) - if err != nil { - t.Errorf("%v", err) - } - correct0 := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} - assert.Equal(correct0, a0.Data()) +type selByIndicesTest struct { + Name string + Data interface{} + Shape Shape + Indices []int + Axis int + WillErr bool + + Correct interface{} + CorrectShape Shape +} - a2, err := e.SelectByIndices(a, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - correct2 := []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21} - assert.Equal(correct2, a2.Data()) +var selByIndicesTests = []selByIndicesTest{ + {Name: "3-tensor, axis 0", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + Correct: []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15}, CorrectShape: Shape{2, 2, 4}}, - // !safe - aUnsafe := a.Clone().(*Dense) - indices = New(WithBacking([]int{1, 1, 1})) - aUnsafeSelect, err := e.SelectByIndices(aUnsafe, indices, 0, UseUnsafe()) - if err != nil { - t.Errorf("%v", err) - } - correctUnsafe := []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} - assert.Equal(correctUnsafe, aUnsafeSelect.Data()) + {Name: "3-tensor, axis 1", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 1, WillErr: false, + Correct: []float64{4, 5, 6, 7, 4, 5, 6, 7, 12, 13, 14, 15, 12, 13, 14, 15, 20, 21, 22, 23, 20, 21, 22, 23}, CorrectShape: Shape{3, 2, 4}}, - // 3 indices, just to make sure the sanity of the algorithm - indices = New(WithBacking([]int{1, 1, 1})) - a1, err = e.SelectByIndices(a, indices, 1) - if err != nil { - t.Errorf("%v", err) - } - correct1 = []float64{ - 4, 5, 6, 7, - 4, 5, 6, 7, - 4, 5, 6, 7, + {Name: "3-tensor, axis 2", Data: Range(Float64, 0, 24), Shape: Shape{3, 2, 4}, Indices: []int{1, 1}, Axis: 2, WillErr: false, + Correct: []float64{1, 1, 5, 5, 9, 9, 13, 13, 17, 17, 21, 21}, CorrectShape: Shape{3, 2, 2}}, - 12, 13, 14, 15, - 12, 13, 14, 15, - 12, 13, 14, 15, + {Name: "Vector, axis 0", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 0, WillErr: false, + Correct: []int{1, 1}, CorrectShape: Shape{2}}, - 20, 21, 22, 23, - 20, 21, 22, 23, - 20, 21, 22, 23, - } - assert.Equal(correct1, a1.Data()) + {Name: "Vector, axis 1", Data: Range(Int, 0, 5), Shape: Shape{5}, Indices: []int{1, 1}, Axis: 1, WillErr: true, + Correct: []int{1, 1}, CorrectShape: Shape{2}}, + {Name: "(4,2) Matrix, with (10) indices", Data: Range(Float32, 0, 8), Shape: Shape{4, 2}, Indices: []int{1, 1, 1, 1, 0, 2, 2, 2, 2, 0}, Axis: 0, WillErr: false, + Correct: []float32{2, 3, 2, 3, 2, 3, 2, 3, 0, 1, 4, 5, 4, 5, 4, 5, 4, 5, 0, 1}, CorrectShape: Shape{10, 2}}, + {Name: "(2,1) Matrx (colvec)m with (10) indies", Data: Range(Float64, 0, 2), Shape: Shape{2, 1}, Indices: []int{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, Axis: 0, WillErr: false, + Correct: []float64{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}, CorrectShape: Shape{10}, + }, +} - a0, err = e.SelectByIndices(a, indices, 0) - if err != nil { - t.Errorf("%v", err) +func TestDense_SelectByIndices(t *testing.T) { + assert := assert.New(t) + for i, tc := range selByIndicesTests { + T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) + indices := New(WithBacking(tc.Indices)) + ret, err := ByIndices(T, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + assert.Equal(tc.Correct, ret.Data()) + assert.True(tc.CorrectShape.Eq(ret.Shape())) } - correct0 = []float64{8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15, 8, 9, 10, 11, 12, 13, 14, 15} - assert.Equal(correct0, a0.Data()) +} - a2, err = e.SelectByIndices(a, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - correct2 = []float64{1, 1, 1, 5, 5, 5, 9, 9, 9, 13, 13, 13, 17, 17, 17, 21, 21, 21} - assert.Equal(correct2, a2.Data()) +var selByIndicesBTests = []struct { + selByIndicesTest + + CorrectGrad interface{} + CorrectGradShape Shape +}{ + { + selByIndicesTest: selByIndicesTests[0], + CorrectGrad: []float64{0, 0, 0, 0, 0, 0, 0, 0, 16, 18, 20, 22, 24, 26, 28, 30, 0, 0, 0, 0, 0, 0, 0, 0}, + CorrectGradShape: Shape{3, 2, 4}, + }, + { + selByIndicesTest: selByIndicesTests[1], + CorrectGrad: []float64{0, 0, 0, 0, 8, 10, 12, 14, 0, 0, 0, 0, 24, 26, 28, 30, 0, 0, 0, 0, 40, 42, 44, 46}, + CorrectGradShape: Shape{3, 2, 4}, + }, + { + selByIndicesTest: selByIndicesTests[2], + CorrectGrad: []float64{0, 2, 0, 0, 0, 10, 0, 0, 0, 18, 0, 0, 0, 26, 0, 0, 0, 34, 0, 0, 0, 42, 0, 0}, + CorrectGradShape: Shape{3, 2, 4}, + }, + { + selByIndicesTest: selByIndicesTests[3], + CorrectGrad: []int{0, 2, 0, 0, 0}, + CorrectGradShape: Shape{5}, + }, + { + selByIndicesTest: selByIndicesTests[5], + CorrectGrad: []float32{4, 6, 8, 12, 8, 12, 0, 0}, + CorrectGradShape: Shape{4, 2}, + }, + { + selByIndicesTest: selByIndicesTests[6], + CorrectGrad: []float64{0, 10}, + CorrectGradShape: Shape{2, 1}, + }, } func TestDense_SelectByIndicesB(t *testing.T) { - a := New(WithBacking([]float64{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}), WithShape(3, 2, 4)) - indices := New(WithBacking([]int{1, 1})) - - t.Logf("a\n%v", a) - - e := StdEng{} - - a1, err := e.SelectByIndices(a, indices, 1) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a1\n%v", a1) - - a1Grad, err := e.SelectByIndicesB(a, a1, indices, 1) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a1Grad \n%v", a1Grad) - - a0, err := e.SelectByIndices(a, indices, 0) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a0\n%v", a0) - a0Grad, err := e.SelectByIndicesB(a, a0, indices, 0) - if err != nil { - t.Errorf("%v", err) + assert := assert.New(t) + for i, tc := range selByIndicesBTests { + T := New(WithShape(tc.Shape...), WithBacking(tc.Data)) + indices := New(WithBacking(tc.Indices)) + ret, err := ByIndices(T, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + grad, err := ByIndicesB(T, ret, indices, tc.Axis) + if checkErr(t, tc.WillErr, err, tc.Name, i) { + continue + } + assert.Equal(tc.CorrectGrad, grad.Data(), "%v", tc.Name) + assert.True(tc.CorrectGradShape.Eq(grad.Shape()), "%v - Grad shape should be %v. Got %v instead", tc.Name, tc.CorrectGradShape, grad.Shape()) } - t.Logf("a0Grad\n%v", a0Grad) - a2, err := e.SelectByIndices(a, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("\n%v", a2) - a2Grad, err := e.SelectByIndicesB(a, a2, indices, 2) - if err != nil { - t.Errorf("%v", err) - } - t.Logf("a2Grad\n%v", a2Grad) }