API performance optimization

master
Petr Masopust 6 years ago
parent 73a20885c7
commit e8477fad03
  1. 7
      inference/pom.xml
  2. 53
      inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
  3. 103
      inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java
  4. 79
      inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java

@ -30,6 +30,13 @@
<version>7.0.0</version>
<scope>test</scope>
</dependency>
<!-- external dependencies -->
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>1.7.28</version>
</dependency>
</dependencies>
<build>

@ -13,11 +13,8 @@ import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
public class CacheInference {
private static final TTL[] TTL_VALUES = TTL.values();
@ -71,40 +68,32 @@ public class CacheInference {
/**
 * Encodes one flight as a one-row {@code DMatrix}.
 *
 * <p>Allocates a 20-element float array and fills it with the input price, label codes
 * looked up in {@code labels} (falling back to 0 when the inner map has no entry for the
 * key) and numbers derived from dates, then wraps the array in a matrix of 1 row and
 * {@code arr.length} columns.
 *
 * <p>NOTE(review): in this view indices 0, 2-6, 8-12, 14, 15, 18 and 19 are each assigned
 * twice with different expressions. This looks like the before/after lines of a diff
 * rendered without +/- markers; confirm which line of each pair is the live one.
 *
 * <p>NOTE(review): {@code labels.get(...)} is dereferenced without a null check, so a missing
 * outer key would presumably throw a NullPointerException; verify the label file always
 * contains every key used here.
 *
 * @param data flight whose attributes are encoded
 * @param now reference time used for indices 15, 16 and 17
 * @return a matrix with one row and {@code arr.length} columns
 * @throws XGBoostError declared by the signature; presumably raised by the {@code DMatrix} constructor
 */
private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError {
final float[] arr = new float[20];
arr[0] = data.getInputPrice().floatValue();
arr[0] = (float) data.getInputPrice();
arr[1] = labels.get("type").getOrDefault(data.getType(), 0);
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinTimestampList(data.getInboundDeparture()), 0);
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinTimestampList(data.getInboundArrival()), 0);
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0);
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0);
arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(joinList(data.getInboundAirlines()), 0);
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0);
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0);
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0);
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0);
arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0);
arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0);
arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(joinTimestampList(data.getOutboundDeparture()), 0);
arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinTimestampList(data.getOutboundArrival()), 0);
arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0);
arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0);
arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(joinList(data.getOutboundAirlines()), 0);
arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0);
arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0);
arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0);
arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0);
arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0);
arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0);
arr[14] = computeDuration(data.getInboundDeparture(), data.getOutboundDeparture());
arr[15] = computePrebooking(data.getOutboundDeparture(), now);
arr[14] = data.getNights();
arr[15] = computePrebooking(data.getOutDepartureDate(), now);
arr[16] = now.getDayOfMonth();
// getValue() is 1-based, so subtracting 1 yields a 0-based day-of-week
arr[17] = now.getDayOfWeek().getValue() - 1;
arr[18] = data.getOutboundDeparture().get(0).getDayOfMonth();
arr[19] = data.getOutboundDeparture().get(0).getDayOfWeek().getValue() - 1;
arr[18] = data.getOutDepartureDate().getDayOfMonth();
arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1;
return new DMatrix(arr, 1, arr.length);
}
/**
 * Computes the number of whole days from the first outbound departure to the first
 * inbound departure.
 *
 * @param indeparture inbound departure times; when empty the result is 0
 * @param outdeparture outbound departure times; its first element is read only when
 *                     {@code indeparture} is non-empty
 * @return the day count as a float, or 0 when there is no inbound departure
 */
private float computeDuration(final List<ZonedDateTime> indeparture, final List<ZonedDateTime> outdeparture) {
    if (!indeparture.isEmpty()) {
        final ZonedDateTime firstOutbound = outdeparture.get(0);
        final ZonedDateTime firstInbound = indeparture.get(0);
        return firstOutbound.until(firstInbound, ChronoUnit.DAYS);
    }
    return 0;
}
private float computePrebooking(final List<ZonedDateTime> outdeparture, final ZonedDateTime now) {
return ChronoUnit.DAYS.between(now, outdeparture.get(0));
/**
 * Computes the number of whole days from {@code now} to the outbound departure.
 *
 * @param outdeparture outbound departure time
 * @param now reference time the count starts from
 * @return the day count as a float
 */
private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) {
    final long daysAhead = now.until(outdeparture, ChronoUnit.DAYS);
    return daysAhead;
}
private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException {
@ -113,14 +102,6 @@ public class CacheInference {
return mapper.readValue(labels, typeRef);
}
/**
 * Formats every timestamp with {@code ISO_LOCAL_DATE_TIME} and joins the results with "|".
 *
 * @param data timestamps to format, in output order
 * @return the joined string; empty when {@code data} is empty
 */
private String joinTimestampList(final List<ZonedDateTime> data) {
    final StringBuilder joined = new StringBuilder();
    String separator = "";
    for (final ZonedDateTime timestamp : data) {
        joined.append(separator).append(DateTimeFormatter.ISO_LOCAL_DATE_TIME.format(timestamp));
        separator = "|";
    }
    return joined.toString();
}
/**
 * Joins the given strings with "|" in iteration order.
 *
 * @param data strings to join
 * @return the joined string; empty when {@code data} is empty
 */
private String joinList(final List<String> data) {
    return data.stream().collect(Collectors.joining("|"));
}
private int argmax(final float[] data) {
int idx = 0;
float max = Float.MIN_VALUE;

@ -1,32 +1,31 @@
package cz.aprar.bonitoo.inference;
import java.time.ZonedDateTime;
import java.util.Collections;
import java.util.List;
public class FlightData {
private final String type;
private final List<ZonedDateTime> inboundDeparture;
private final List<ZonedDateTime> inboundArrival;
private final List<String> inboundOrigin;
private final List<String> inboundDestination;
private final List<String> inboundAirlines;
private final String inboundDeparture;
private final String inboundArrival;
private final String inboundOrigin;
private final String inboundDestination;
private final String inboundAirlines;
private final String inboundMCXAirlines;
private final List<ZonedDateTime> outboundDeparture;
private final List<ZonedDateTime> outboundArrival;
private final List<String> outboundOrigin;
private final List<String> outboundDestination;
private final List<String> outboundAirlines;
private final String outboundDeparture;
private final String outboundArrival;
private final String outboundOrigin;
private final String outboundDestination;
private final String outboundAirlines;
private final String outboundMCXAirlines;
private final Double inputPrice;
public FlightData(final String type, final List<ZonedDateTime> inboundDeparture,
final List<ZonedDateTime> inboundArrival, final List<String> inboundOrigin,
final List<String> inboundDestination, final List<String> inboundAirlines,
final String inboundMCXAirlines, final List<ZonedDateTime> outboundDeparture,
final List<ZonedDateTime> outboundArrival, final List<String> outboundOrigin,
final List<String> outboundDestination, final List<String> outboundAirlines,
final String outboundMCXAirlines, final Double inputPrice) {
private final double inputPrice;
private final int nights;
private final ZonedDateTime outDepartureDate;
public FlightData(final String type, final String inboundDeparture, final String inboundArrival,
final String inboundOrigin, final String inboundDestination, final String inboundAirlines,
final String inboundMCXAirlines, final String outboundDeparture, final String outboundArrival,
final String outboundOrigin, final String outboundDestination, final String outboundAirlines,
final String outboundMCXAirlines, final double inputPrice, final int nights,
final ZonedDateTime outDepartureDate) {
this.type = type;
this.inboundDeparture = inboundDeparture;
this.inboundArrival = inboundArrival;
@ -41,61 +40,71 @@ public class FlightData {
this.outboundAirlines = outboundAirlines;
this.outboundMCXAirlines = outboundMCXAirlines;
this.inputPrice = inputPrice;
this.nights = nights;
this.outDepartureDate = outDepartureDate;
}
/**
 * Returns the type value given to the constructor.
 *
 * @return the {@code type} field
 */
public String getType() {
    return this.type;
}
public List<ZonedDateTime> getInboundDeparture() {
return Collections.unmodifiableList(inboundDeparture);
/**
 * Returns the inbound departure value given to the constructor.
 *
 * @return the {@code inboundDeparture} field
 */
public String getInboundDeparture() {
    return this.inboundDeparture;
}
public List<ZonedDateTime> getInboundArrival() {
return Collections.unmodifiableList(inboundArrival);
/**
 * Returns the inbound arrival value given to the constructor.
 *
 * @return the {@code inboundArrival} field
 */
public String getInboundArrival() {
    return this.inboundArrival;
}
public List<String> getInboundOrigin() {
return Collections.unmodifiableList(inboundOrigin);
/**
 * Returns the inbound origin value held by this object.
 *
 * @return the {@code inboundOrigin} field
 */
public String getInboundOrigin() {
    return this.inboundOrigin;
}
public List<String> getInboundDestination() {
return Collections.unmodifiableList(inboundDestination);
/**
 * Returns the inbound destination value held by this object.
 *
 * @return the {@code inboundDestination} field
 */
public String getInboundDestination() {
    return this.inboundDestination;
}
public List<ZonedDateTime> getOutboundDeparture() {
return Collections.unmodifiableList(outboundDeparture);
/**
 * Returns the inbound airlines value held by this object.
 *
 * @return the {@code inboundAirlines} field
 */
public String getInboundAirlines() {
    return this.inboundAirlines;
}
public List<ZonedDateTime> getOutboundArrival() {
return Collections.unmodifiableList(outboundArrival);
/**
 * Returns the inbound MCX airlines value held by this object.
 *
 * @return the {@code inboundMCXAirlines} field
 */
public String getInboundMCXAirlines() {
    return this.inboundMCXAirlines;
}
public List<String> getOutboundOrigin() {
return Collections.unmodifiableList(outboundOrigin);
/**
 * Returns the outbound departure value held by this object.
 *
 * @return the {@code outboundDeparture} field
 */
public String getOutboundDeparture() {
    return this.outboundDeparture;
}
public List<String> getOutboundDestination() {
return Collections.unmodifiableList(outboundDestination);
/**
 * Returns the outbound arrival value held by this object.
 *
 * @return the {@code outboundArrival} field
 */
public String getOutboundArrival() {
    return this.outboundArrival;
}
public Double getInputPrice() {
return inputPrice;
/**
 * Returns the outbound origin value held by this object.
 *
 * @return the {@code outboundOrigin} field
 */
public String getOutboundOrigin() {
    return this.outboundOrigin;
}
public List<String> getInboundAirlines() {
return Collections.unmodifiableList(inboundAirlines);
/**
 * Returns the outbound destination value held by this object.
 *
 * @return the {@code outboundDestination} field
 */
public String getOutboundDestination() {
    return this.outboundDestination;
}
/**
 * Returns a read-only view of the outbound airlines list.
 *
 * @return an unmodifiable list wrapping the {@code outboundAirlines} field
 */
public List<String> getOutboundAirlines() {
    final List<String> readOnlyView = Collections.unmodifiableList(this.outboundAirlines);
    return readOnlyView;
}
public String getInboundMCXAirlines() {
return inboundMCXAirlines;
/**
 * Returns the outbound airlines value given to the constructor.
 *
 * @return the {@code outboundAirlines} field
 */
public String getOutboundAirlines() {
    return this.outboundAirlines;
}
/**
 * Returns the outbound MCX airlines value given to the constructor.
 *
 * @return the {@code outboundMCXAirlines} field
 */
public String getOutboundMCXAirlines() {
    return this.outboundMCXAirlines;
}
/**
 * Returns the input price given to the constructor.
 *
 * @return the {@code inputPrice} field as a primitive double
 */
public double getInputPrice() {
    return this.inputPrice;
}
/**
 * Returns the number of nights given to the constructor.
 *
 * @return the {@code nights} field
 */
public int getNights() {
    return this.nights;
}
/**
 * Returns the outbound departure date given to the constructor.
 *
 * @return the {@code outDepartureDate} field
 */
public ZonedDateTime getOutDepartureDate() {
    return this.outDepartureDate;
}
}

@ -17,7 +17,8 @@ import static java.util.Collections.emptyList;
import static org.testng.Assert.assertEquals;
public class CacheInferenceTest {
private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, ZoneId.of("UTC"));
private static final ZoneId UTC = ZoneId.of("UTC");
private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, UTC);
private CacheInference ci;
@ -39,94 +40,98 @@ public class CacheInferenceTest {
return new Object[][]{
{new FlightData(
"WS",
toTimestampList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
toTimestampList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
toList("MCO", "FRA"),
toList("FRA", "PRG"),
toList("LH", "LH"),
"LH",
toTimestampList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
toTimestampList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
toList("PRG", "FRA"),
toList("FRA", "MCO"),
toList("LH", "LH"),
"LH",
39766.0
39766.0,
7,
ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC)
), TTL.D3},
{new FlightData(
"PYTON",
emptyList(),
emptyList(),
emptyList(),
emptyList(),
emptyList(),
"",
toTimestampList("2019-12-18T05:45:00"),
toTimestampList("2019-12-18T08:05:00"),
"",
"",
"",
"",
"",
toList("2019-12-18T05:45:00"),
toList("2019-12-18T08:05:00"),
toList("KRK"),
toList("BVA"),
toList("FR"),
"FR",
336.258
336.258,
0,
ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC)
), TTL.D14},
{new FlightData(
"AVIA",
toTimestampList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
toTimestampList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
toList("LAX", "LHR"),
toList("LHR", "PRG"),
toList("AA", "BA"),
"AA",
toTimestampList("2020-01-28T10:35:00", "2020-01-28T14:40:00"),
toTimestampList("2020-01-28T12:45:00", "2020-01-29T01:45:00"),
toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"),
toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"),
toList("PRG", "HEL"),
toList("HEL", "LAX"),
toList("AY", "AY"),
"AY",
5971.77978
5971.77978,
4,
ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC)
), TTL.D2},
{new FlightData(
"HH",
toTimestampList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
toTimestampList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
toList("YVR", "YUL"),
toList("YUL", "VIE"),
toList("LH", "LH"),
"LH",
toTimestampList("2019-10-18T08:10:00", "2019-10-18T11:30:00"),
toTimestampList("2019-10-18T09:40:00", "2019-10-18T21:25:00"),
toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"),
toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"),
toList("VIE", "FRA"),
toList("FRA", "YVR"),
toList("LH", "LH"),
"LH",
17723.0
17723.0,
12,
ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC)
), TTL.D2},
{new FlightData(
"unknown",
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList("unknown"),
toList("unknown"),
toList("unknown"),
"unknown",
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList("unknown"),
toList("unknown"),
toList("unknown"),
"unknown",
0.0
0.0,
0,
ZonedDateTime.now()
), TTL.D1}
};
}
/**
 * Wraps the given strings in a fixed-size list backed by the argument array.
 *
 * @param data values to expose as a list
 * @return the list view produced by {@code Arrays.asList}
 */
private List<String> toList(final String... data) {
    final List<String> values = Arrays.asList(data);
    return values;
}
private List<ZonedDateTime> toTimestampList(final String... data) {
return Arrays.stream(data)
.map(x -> x + "Z")
.map(x -> ZonedDateTime.parse(x, DateTimeFormatter.ISO_ZONED_DATE_TIME)).collect(Collectors.toList());
/**
 * Joins the given strings with "|" in argument order.
 *
 * @param data values to join
 * @return the joined string; empty when no values are given
 */
private String toList(final String... data) {
    final StringBuilder joined = new StringBuilder();
    String separator = "";
    for (final String value : data) {
        joined.append(separator).append(value);
        separator = "|";
    }
    return joined.toString();
}
}

Loading…
Cancel
Save