Skip to content

Commit

Permalink
#4421 - Improve performance of the overlap iterator
Browse files Browse the repository at this point in the history
- Fix SlidingWindow and added test
  • Loading branch information
reckart committed Jan 2, 2024
1 parent 360a882 commit bcf2e4c
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import opennlp.tools.namefind.NameSample;
import opennlp.tools.util.ObjectStream;

public class NameSampleStream
class NameSampleStream
implements ObjectStream<NameSample>, AutoCloseable
{
private List<NameSample> samples;
Expand All @@ -41,6 +41,7 @@ public NameSample read()
if (iterator != null && iterator.hasNext()) {
return iterator.next();
}

return null;
}

Expand All @@ -56,5 +57,4 @@ public void close()
samples = null;
iterator = null;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,12 @@ public Range predict(PredictionContext aContext, CAS aCas, int aBegin, int aEnd)
break;
}

var firstToken = tokens.get(0);
var lastToken = tokens.get(tokens.size() - 1);

predictionCount++;
predictedRangeBegin = Math.min(predictedRangeBegin, tokens.get(0).getBegin());
predictedRangeEnd = Math.max(predictedRangeEnd, tokens.get(tokens.size() - 1).getEnd());
predictedRangeBegin = Math.min(predictedRangeBegin, firstToken.getBegin());
predictedRangeEnd = Math.max(predictedRangeEnd, lastToken.getEnd());

var tokenTexts = tokens.stream() //
.map(AnnotationFS::getCoveredText) //
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ private class SlidingWindowIterator
{
private final Iterator<T> tokenIterator;

private List<T> nextWindow;
private LinkedList<T> nextWindow;

public SlidingWindowIterator(Iterator<T> aTokenIterator)
{
tokenIterator = aTokenIterator;
nextWindow = makeSample(tokenIterator, new LinkedList<T>(), windowSize, windowOverlap);
nextWindow = makeSample(tokenIterator, new LinkedList<T>());
}

@Override
Expand All @@ -81,24 +81,29 @@ public boolean hasNext()
public List<T> next()
{
var currentWindow = nextWindow;
nextWindow = makeSample(tokenIterator, nextWindow, windowSize, windowOverlap);
nextWindow = makeSample(tokenIterator, nextWindow);
return currentWindow;
}

private List<T> makeSample(Iterator<T> aFreshTokenIterator, List<T> aTokens, int aMaxLength,
int aOverlap)
private LinkedList<T> makeSample(Iterator<T> aFreshTokenIterator, LinkedList<T> aPrevWindow)
{
var result = new LinkedList<T>();

if (!aFreshTokenIterator.hasNext()) {
return Collections.emptyList();
return result;
}

var result = new LinkedList<T>();

// Add tokens overlapping with previous sample
var size = 0;
if (aOverlap > 0) {
var overlapIterator = result.descendingIterator();
if (windowOverlap > 0) {
var overlapIterator = aPrevWindow.descendingIterator();

while (overlapIterator.hasNext()) {
if (size >= windowOverlap && !result.isEmpty()) {
// Overlap size reached
break;
}

var token = overlapIterator.next();
var tokenText = token.getCoveredText();

Expand All @@ -107,17 +112,21 @@ private List<T> makeSample(Iterator<T> aFreshTokenIterator, List<T> aTokens, int
}

size += tokenText.length();
if (size >= aOverlap && !result.isEmpty()) {
// Overlap size reached
break;
}
result.add(0, token);

result.add(token);
}

Collections.reverse(result);
}

// Add fresh tokens
var freshTokenAdded = false;
while (aFreshTokenIterator.hasNext()) {
if (size >= windowSize && freshTokenAdded) {
// Maximum sample size reached
break;
}

var token = aFreshTokenIterator.next();
var tokenText = token.getCoveredText();

Expand All @@ -126,17 +135,13 @@ private List<T> makeSample(Iterator<T> aFreshTokenIterator, List<T> aTokens, int
}

size += tokenText.length();
if (size >= aMaxLength && freshTokenAdded) {
// Maximum sample size reached
break;
}

result.add(token);
freshTokenAdded = true;
}

if (!freshTokenAdded) {
return Collections.emptyList();
return new LinkedList<T>();
}

return result;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Licensed to the Technische Universität Darmstadt under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The Technische Universität Darmstadt
* licenses this file to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License.
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package de.tudarmstadt.ukp.inception.recommendation.imls.opennlp.ner;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.ArrayList;
import java.util.HashSet;

import org.apache.uima.fit.factory.JCasBuilder;
import org.apache.uima.fit.factory.JCasFactory;
import org.junit.jupiter.api.Test;

import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Sentence;
import de.tudarmstadt.ukp.dkpro.core.api.segmentation.type.Token;

class SlidingWindowTest
{

@Test
void test() throws Exception
{
var sentences = 10;
var sentenceLength = 10;

var cas = JCasFactory.createJCas();
var casBuilder = new JCasBuilder(cas);

var expectedTokens = new HashSet<String>();

for (int s = 0; s < sentences; s++) {
int sBegin = casBuilder.getPosition();

for (int t = 0; t < sentenceLength; t++) {
var token = String.format("%02d", s * sentenceLength + t);
expectedTokens.add(token);
casBuilder.add(token, Token.class);
casBuilder.add(" ");
}

casBuilder.add(sBegin, Sentence.class);
casBuilder.add(".\n");
}

casBuilder.close();

var sut = new SlidingWindow<>(cas.getCas(), Token.class, 20, 10);

var actualTokens = new HashSet<String>();
int base = 0;
for (var unit : sut) {
unit.stream().map(Token::getCoveredText).forEach(actualTokens::add);

var expected = new ArrayList<String>();
for (int n = 0; n < sentenceLength; n++) {
expected.add(String.format("%02d", base + n));
}

assertThat(unit.stream().map(Token::getCoveredText).toList())
.containsExactlyElementsOf(expected);
base += sentenceLength / 2;
}

assertThat(actualTokens).containsExactlyInAnyOrderElementsOf(expectedTokens);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
<Loggers>
<Logger name="org.deeplearning4j.optimize.listeners" level="INFO"/>
<Logger name="de.tudarmstadt.ukp.dkpro.core.api.datasets.DatasetFactory" level="INFO"/>
<Logger name="de.tudarmstadt.ukp.inception.recommendation" level="TRACE"/>
<Logger name="de.tudarmstadt.ukp.inception.recommendation" level="INFO"/>
<Logger name="de.tudarmstadt.ukp.inception.recommendation.api.util.OverlapIterator" level="INFO"/>

<Root level="ERROR">
Expand Down

0 comments on commit bcf2e4c

Please sign in to comment.