API performance optimization

master
Petr Masopust 6 years ago
parent 73a20885c7
commit e8477fad03
  1. 7
      inference/pom.xml
  2. 53
      inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
  3. 103
      inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java
  4. 79
      inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java

@ -30,6 +30,13 @@
<version>7.0.0</version> <version>7.0.0</version>
<scope>test</scope> <scope>test</scope>
</dependency> </dependency>
<!-- external dependencies -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.28</version>
</dependency>
</dependencies> </dependencies>
<build> <build>

@ -13,11 +13,8 @@ import java.io.InputStream;
import java.nio.file.Files; import java.nio.file.Files;
import java.nio.file.Path; import java.nio.file.Path;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit; import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.stream.Collectors;
public class CacheInference { public class CacheInference {
private static final TTL[] TTL_VALUES = TTL.values(); private static final TTL[] TTL_VALUES = TTL.values();
@ -71,40 +68,32 @@ public class CacheInference {
private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError { private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError {
final float[] arr = new float[20]; final float[] arr = new float[20];
arr[0] = data.getInputPrice().floatValue(); arr[0] = (float) data.getInputPrice();
arr[1] = labels.get("type").getOrDefault(data.getType(), 0); arr[1] = labels.get("type").getOrDefault(data.getType(), 0);
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinTimestampList(data.getInboundDeparture()), 0); arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0);
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinTimestampList(data.getInboundArrival()), 0); arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0);
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0); arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0);
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0); arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0);
arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(joinList(data.getInboundAirlines()), 0); arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0);
arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0); arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0);
arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(joinTimestampList(data.getOutboundDeparture()), 0); arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0);
arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinTimestampList(data.getOutboundArrival()), 0); arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0);
arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0); arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0);
arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0); arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0);
arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(joinList(data.getOutboundAirlines()), 0); arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0);
arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0); arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0);
arr[14] = computeDuration(data.getInboundDeparture(), data.getOutboundDeparture()); arr[14] = data.getNights();
arr[15] = computePrebooking(data.getOutboundDeparture(), now); arr[15] = computePrebooking(data.getOutDepartureDate(), now);
arr[16] = now.getDayOfMonth(); arr[16] = now.getDayOfMonth();
arr[17] = now.getDayOfWeek().getValue() - 1; arr[17] = now.getDayOfWeek().getValue() - 1;
arr[18] = data.getOutboundDeparture().get(0).getDayOfMonth(); arr[18] = data.getOutDepartureDate().getDayOfMonth();
arr[19] = data.getOutboundDeparture().get(0).getDayOfWeek().getValue() - 1; arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1;
return new DMatrix(arr, 1, arr.length); return new DMatrix(arr, 1, arr.length);
} }
private float computeDuration(final List<ZonedDateTime> indeparture, final List<ZonedDateTime> outdeparture) { private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) {
if (indeparture.isEmpty()) { return ChronoUnit.DAYS.between(now, outdeparture);
return 0;
}
return ChronoUnit.DAYS.between(outdeparture.get(0), indeparture.get(0));
}
private float computePrebooking(final List<ZonedDateTime> outdeparture, final ZonedDateTime now) {
return ChronoUnit.DAYS.between(now, outdeparture.get(0));
} }
private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException { private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException {
@ -113,14 +102,6 @@ public class CacheInference {
return mapper.readValue(labels, typeRef); return mapper.readValue(labels, typeRef);
} }
private String joinTimestampList(final List<ZonedDateTime> data) {
return data.stream().map(DateTimeFormatter.ISO_LOCAL_DATE_TIME::format).collect(Collectors.joining("|"));
}
private String joinList(final List<String> data) {
return String.join("|", data);
}
private int argmax(final float[] data) { private int argmax(final float[] data) {
int idx = 0; int idx = 0;
float max = Float.MIN_VALUE; float max = Float.MIN_VALUE;

@ -1,32 +1,31 @@
package cz.aprar.bonitoo.inference; package cz.aprar.bonitoo.inference;
import java.time.ZonedDateTime; import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
public class FlightData { public class FlightData {
private final String type; private final String type;
private final List<ZonedDateTime> inboundDeparture; private final String inboundDeparture;
private final List<ZonedDateTime> inboundArrival; private final String inboundArrival;
private final List<String> inboundOrigin; private final String inboundOrigin;
private final List<String> inboundDestination; private final String inboundDestination;
private final List<String> inboundAirlines; private final String inboundAirlines;
private final String inboundMCXAirlines; private final String inboundMCXAirlines;
private final List<ZonedDateTime> outboundDeparture; private final String outboundDeparture;
private final List<ZonedDateTime> outboundArrival; private final String outboundArrival;
private final List<String> outboundOrigin; private final String outboundOrigin;
private final List<String> outboundDestination; private final String outboundDestination;
private final List<String> outboundAirlines; private final String outboundAirlines;
private final String outboundMCXAirlines; private final String outboundMCXAirlines;
private final Double inputPrice; private final double inputPrice;
private final int nights;
public FlightData(final String type, final List<ZonedDateTime> inboundDeparture, private final ZonedDateTime outDepartureDate;
final List<ZonedDateTime> inboundArrival, final List<String> inboundOrigin,
final List<String> inboundDestination, final List<String> inboundAirlines, public FlightData(final String type, final String inboundDeparture, final String inboundArrival,
final String inboundMCXAirlines, final List<ZonedDateTime> outboundDeparture, final String inboundOrigin, final String inboundDestination, final String inboundAirlines,
final List<ZonedDateTime> outboundArrival, final List<String> outboundOrigin, final String inboundMCXAirlines, final String outboundDeparture, final String outboundArrival,
final List<String> outboundDestination, final List<String> outboundAirlines, final String outboundOrigin, final String outboundDestination, final String outboundAirlines,
final String outboundMCXAirlines, final Double inputPrice) { final String outboundMCXAirlines, final double inputPrice, final int nights,
final ZonedDateTime outDepartureDate) {
this.type = type; this.type = type;
this.inboundDeparture = inboundDeparture; this.inboundDeparture = inboundDeparture;
this.inboundArrival = inboundArrival; this.inboundArrival = inboundArrival;
@ -41,61 +40,71 @@ public class FlightData {
this.outboundAirlines = outboundAirlines; this.outboundAirlines = outboundAirlines;
this.outboundMCXAirlines = outboundMCXAirlines; this.outboundMCXAirlines = outboundMCXAirlines;
this.inputPrice = inputPrice; this.inputPrice = inputPrice;
this.nights = nights;
this.outDepartureDate = outDepartureDate;
} }
public String getType() { public String getType() {
return type; return type;
} }
public List<ZonedDateTime> getInboundDeparture() { public String getInboundDeparture() {
return Collections.unmodifiableList(inboundDeparture); return inboundDeparture;
} }
public List<ZonedDateTime> getInboundArrival() { public String getInboundArrival() {
return Collections.unmodifiableList(inboundArrival); return inboundArrival;
} }
public List<String> getInboundOrigin() { public String getInboundOrigin() {
return Collections.unmodifiableList(inboundOrigin); return inboundOrigin;
} }
public List<String> getInboundDestination() { public String getInboundDestination() {
return Collections.unmodifiableList(inboundDestination); return inboundDestination;
} }
public List<ZonedDateTime> getOutboundDeparture() { public String getInboundAirlines() {
return Collections.unmodifiableList(outboundDeparture); return inboundAirlines;
} }
public List<ZonedDateTime> getOutboundArrival() { public String getInboundMCXAirlines() {
return Collections.unmodifiableList(outboundArrival); return inboundMCXAirlines;
} }
public List<String> getOutboundOrigin() { public String getOutboundDeparture() {
return Collections.unmodifiableList(outboundOrigin); return outboundDeparture;
} }
public List<String> getOutboundDestination() { public String getOutboundArrival() {
return Collections.unmodifiableList(outboundDestination); return outboundArrival;
} }
public Double getInputPrice() { public String getOutboundOrigin() {
return inputPrice; return outboundOrigin;
} }
public List<String> getInboundAirlines() { public String getOutboundDestination() {
return Collections.unmodifiableList(inboundAirlines); return outboundDestination;
} }
public List<String> getOutboundAirlines() { public String getOutboundAirlines() {
return Collections.unmodifiableList(outboundAirlines); return outboundAirlines;
}
public String getInboundMCXAirlines() {
return inboundMCXAirlines;
} }
public String getOutboundMCXAirlines() { public String getOutboundMCXAirlines() {
return outboundMCXAirlines; return outboundMCXAirlines;
} }
public double getInputPrice() {
return inputPrice;
}
public int getNights() {
return nights;
}
public ZonedDateTime getOutDepartureDate() {
return outDepartureDate;
}
} }

@ -17,7 +17,8 @@ import static java.util.Collections.emptyList;
import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertEquals;
public class CacheInferenceTest { public class CacheInferenceTest {
private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, ZoneId.of("UTC")); private static final ZoneId UTC = ZoneId.of("UTC");
private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, UTC);
private CacheInference ci; private CacheInference ci;
@ -39,94 +40,98 @@ public class CacheInferenceTest {
return new Object[][]{ return new Object[][]{
{new FlightData( {new FlightData(
"WS", "WS",
toTimestampList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
toTimestampList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
toList("MCO", "FRA"), toList("MCO", "FRA"),
toList("FRA", "PRG"), toList("FRA", "PRG"),
toList("LH", "LH"), toList("LH", "LH"),
"LH", "LH",
toTimestampList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
toTimestampList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
toList("PRG", "FRA"), toList("PRG", "FRA"),
toList("FRA", "MCO"), toList("FRA", "MCO"),
toList("LH", "LH"), toList("LH", "LH"),
"LH", "LH",
39766.0 39766.0,
7,
ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC)
), TTL.D3}, ), TTL.D3},
{new FlightData( {new FlightData(
"PYTON", "PYTON",
emptyList(),
emptyList(),
emptyList(),
emptyList(),
emptyList(),
"", "",
toTimestampList("2019-12-18T05:45:00"), "",
toTimestampList("2019-12-18T08:05:00"), "",
"",
"",
"",
toList("2019-12-18T05:45:00"),
toList("2019-12-18T08:05:00"),
toList("KRK"), toList("KRK"),
toList("BVA"), toList("BVA"),
toList("FR"), toList("FR"),
"FR", "FR",
336.258 336.258,
0,
ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC)
), TTL.D14}, ), TTL.D14},
{new FlightData( {new FlightData(
"AVIA", "AVIA",
toTimestampList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
toTimestampList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
toList("LAX", "LHR"), toList("LAX", "LHR"),
toList("LHR", "PRG"), toList("LHR", "PRG"),
toList("AA", "BA"), toList("AA", "BA"),
"AA", "AA",
toTimestampList("2020-01-28T10:35:00", "2020-01-28T14:40:00"), toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"),
toTimestampList("2020-01-28T12:45:00", "2020-01-29T01:45:00"), toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"),
toList("PRG", "HEL"), toList("PRG", "HEL"),
toList("HEL", "LAX"), toList("HEL", "LAX"),
toList("AY", "AY"), toList("AY", "AY"),
"AY", "AY",
5971.77978 5971.77978,
4,
ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC)
), TTL.D2}, ), TTL.D2},
{new FlightData( {new FlightData(
"HH", "HH",
toTimestampList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
toTimestampList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
toList("YVR", "YUL"), toList("YVR", "YUL"),
toList("YUL", "VIE"), toList("YUL", "VIE"),
toList("LH", "LH"), toList("LH", "LH"),
"LH", "LH",
toTimestampList("2019-10-18T08:10:00", "2019-10-18T11:30:00"), toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"),
toTimestampList("2019-10-18T09:40:00", "2019-10-18T21:25:00"), toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"),
toList("VIE", "FRA"), toList("VIE", "FRA"),
toList("FRA", "YVR"), toList("FRA", "YVR"),
toList("LH", "LH"), toList("LH", "LH"),
"LH", "LH",
17723.0 17723.0,
12,
ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC)
), TTL.D2}, ), TTL.D2},
{new FlightData( {new FlightData(
"unknown", "unknown",
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList("unknown"), toList("unknown"),
toList("unknown"), toList("unknown"),
toList("unknown"), toList("unknown"),
"unknown", "unknown",
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList("unknown"), toList("unknown"),
toList("unknown"), toList("unknown"),
toList("unknown"), toList("unknown"),
"unknown", "unknown",
0.0 0.0,
0,
ZonedDateTime.now()
), TTL.D1} ), TTL.D1}
}; };
} }
private List<String> toList(final String... data) { private String toList(final String... data) {
return Arrays.asList(data); return String.join("|", data);
}
private List<ZonedDateTime> toTimestampList(final String... data) {
return Arrays.stream(data)
.map(x -> x + "Z")
.map(x -> ZonedDateTime.parse(x, DateTimeFormatter.ISO_ZONED_DATE_TIME)).collect(Collectors.toList());
} }
} }

Loading…
Cancel
Save