From e8477fad03a6b89058f670c18ea3e9aa15919f9e Mon Sep 17 00:00:00 2001 From: Petr Masopust Date: Wed, 30 Oct 2019 14:34:27 +0100 Subject: [PATCH] API performance optimization --- inference/pom.xml | 7 ++ .../bonitoo/inference/CacheInference.java | 53 +++------ .../aprar/bonitoo/inference/FlightData.java | 103 ++++++++++-------- .../bonitoo/inference/CacheInferenceTest.java | 79 +++++++------- 4 files changed, 122 insertions(+), 120 deletions(-) diff --git a/inference/pom.xml b/inference/pom.xml index 56faa9d..6dceeef 100644 --- a/inference/pom.xml +++ b/inference/pom.xml @@ -30,6 +30,13 @@ 7.0.0 test + + + + org.slf4j + slf4j-api + 1.7.28 + diff --git a/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java index 2d19484..1871224 100644 --- a/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java +++ b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java @@ -13,11 +13,8 @@ import java.io.InputStream; import java.nio.file.Files; import java.nio.file.Path; import java.time.ZonedDateTime; -import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; public class CacheInference { private static final TTL[] TTL_VALUES = TTL.values(); @@ -71,40 +68,32 @@ public class CacheInference { private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError { final float[] arr = new float[20]; - arr[0] = data.getInputPrice().floatValue(); + arr[0] = (float) data.getInputPrice(); arr[1] = labels.get("type").getOrDefault(data.getType(), 0); - arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinTimestampList(data.getInboundDeparture()), 0); - arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinTimestampList(data.getInboundArrival()), 0); - arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0); - arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0); - arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(joinList(data.getInboundAirlines()), 0); + arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0); + arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0); + arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0); + arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0); + arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0); arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0); - arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(joinTimestampList(data.getOutboundDeparture()), 0); - arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinTimestampList(data.getOutboundArrival()), 0); - arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0); - arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0); - arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(joinList(data.getOutboundAirlines()), 0); + arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0); + arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0); + arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0); + arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0); + arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0); arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0); - arr[14] = computeDuration(data.getInboundDeparture(), data.getOutboundDeparture()); - arr[15] = computePrebooking(data.getOutboundDeparture(), now); + arr[14] = data.getNights(); + arr[15] = computePrebooking(data.getOutDepartureDate(), now); arr[16] = now.getDayOfMonth(); arr[17] = now.getDayOfWeek().getValue() - 1; - arr[18] = data.getOutboundDeparture().get(0).getDayOfMonth(); - arr[19] = data.getOutboundDeparture().get(0).getDayOfWeek().getValue() - 1; + arr[18] = data.getOutDepartureDate().getDayOfMonth(); + arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1; return new DMatrix(arr, 1, arr.length); } - private float computeDuration(final List indeparture, final List outdeparture) { - if (indeparture.isEmpty()) { - return 0; - } - - return ChronoUnit.DAYS.between(outdeparture.get(0), indeparture.get(0)); - } - - private float computePrebooking(final List outdeparture, final ZonedDateTime now) { - return ChronoUnit.DAYS.between(now, outdeparture.get(0)); + private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) { + return ChronoUnit.DAYS.between(now, outdeparture); } private Map> loadLabels(InputStream labels) throws IOException { @@ -113,14 +102,6 @@ public class CacheInference { return mapper.readValue(labels, typeRef); } - private String joinTimestampList(final List data) { - return data.stream().map(DateTimeFormatter.ISO_LOCAL_DATE_TIME::format).collect(Collectors.joining("|")); - } - - private String joinList(final List data) { - return String.join("|", data); - } - private int argmax(final float[] data) { int idx = 0; float max = Float.MIN_VALUE; diff --git a/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java b/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java index 969bdd1..904c9a5 100644 --- a/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java +++ b/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java @@ -1,32 +1,31 @@ package cz.aprar.bonitoo.inference; import java.time.ZonedDateTime; -import java.util.Collections; -import java.util.List; public class FlightData { private final String type; - private final List inboundDeparture; - private final List inboundArrival; - private final List inboundOrigin; - private final List inboundDestination; - private final List inboundAirlines; + private final String inboundDeparture; + private final String inboundArrival; + private final String inboundOrigin; + private final String inboundDestination; + private final String inboundAirlines; private final String inboundMCXAirlines; - private final List outboundDeparture; - private final List outboundArrival; - private final List outboundOrigin; - private final List outboundDestination; - private final List outboundAirlines; + private final String outboundDeparture; + private final String outboundArrival; + private final String outboundOrigin; + private final String outboundDestination; + private final String outboundAirlines; private final String outboundMCXAirlines; - private final Double inputPrice; - - public FlightData(final String type, final List inboundDeparture, - final List inboundArrival, final List inboundOrigin, - final List inboundDestination, final List inboundAirlines, - final String inboundMCXAirlines, final List outboundDeparture, - final List outboundArrival, final List outboundOrigin, - final List outboundDestination, final List outboundAirlines, - final String outboundMCXAirlines, final Double inputPrice) { + private final double inputPrice; + private final int nights; + private final ZonedDateTime outDepartureDate; + + public FlightData(final String type, final String inboundDeparture, final String inboundArrival, + final String inboundOrigin, final String inboundDestination, final String inboundAirlines, + final String inboundMCXAirlines, final String outboundDeparture, final String outboundArrival, + final String outboundOrigin, final String outboundDestination, final String outboundAirlines, + final String outboundMCXAirlines, final double inputPrice, final int nights, + final ZonedDateTime outDepartureDate) { this.type = type; this.inboundDeparture = inboundDeparture; this.inboundArrival = inboundArrival; @@ -41,61 +40,71 @@ public class FlightData { this.outboundAirlines = outboundAirlines; this.outboundMCXAirlines = outboundMCXAirlines; this.inputPrice = inputPrice; + this.nights = nights; + this.outDepartureDate = outDepartureDate; } public String getType() { return type; } - public List getInboundDeparture() { - return Collections.unmodifiableList(inboundDeparture); + public String getInboundDeparture() { + return inboundDeparture; } - public List getInboundArrival() { - return Collections.unmodifiableList(inboundArrival); + public String getInboundArrival() { + return inboundArrival; } - public List getInboundOrigin() { - return Collections.unmodifiableList(inboundOrigin); + public String getInboundOrigin() { + return inboundOrigin; } - public List getInboundDestination() { - return Collections.unmodifiableList(inboundDestination); + public String getInboundDestination() { + return inboundDestination; } - public List getOutboundDeparture() { - return Collections.unmodifiableList(outboundDeparture); + public String getInboundAirlines() { + return inboundAirlines; } - public List getOutboundArrival() { - return Collections.unmodifiableList(outboundArrival); + public String getInboundMCXAirlines() { + return inboundMCXAirlines; } - public List getOutboundOrigin() { - return Collections.unmodifiableList(outboundOrigin); + public String getOutboundDeparture() { + return outboundDeparture; } - public List getOutboundDestination() { - return Collections.unmodifiableList(outboundDestination); + public String getOutboundArrival() { + return outboundArrival; } - public Double getInputPrice() { - return inputPrice; + public String getOutboundOrigin() { + return outboundOrigin; } - public List getInboundAirlines() { - return Collections.unmodifiableList(inboundAirlines); + public String getOutboundDestination() { + return outboundDestination; } - public List getOutboundAirlines() { - return Collections.unmodifiableList(outboundAirlines); - } - - public String getInboundMCXAirlines() { - return inboundMCXAirlines; + public String getOutboundAirlines() { + return outboundAirlines; } public String getOutboundMCXAirlines() { return outboundMCXAirlines; } + + public double getInputPrice() { + return inputPrice; + } + + public int getNights() { + return nights; + } + + public ZonedDateTime getOutDepartureDate() { + return outDepartureDate; + } } diff --git a/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java b/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java index 6103219..88193e7 100644 --- a/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java +++ b/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java @@ -17,7 +17,8 @@ import static java.util.Collections.emptyList; import static org.testng.Assert.assertEquals; public class CacheInferenceTest { - private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, ZoneId.of("UTC")); + private static final ZoneId UTC = ZoneId.of("UTC"); + private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, UTC); private CacheInference ci; @@ -39,94 +40,98 @@ public class CacheInferenceTest { return new Object[][]{ {new FlightData( "WS", - toTimestampList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), - toTimestampList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), + toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), + toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), toList("MCO", "FRA"), toList("FRA", "PRG"), toList("LH", "LH"), "LH", - toTimestampList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), - toTimestampList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), + toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), + toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), toList("PRG", "FRA"), toList("FRA", "MCO"), toList("LH", "LH"), "LH", - 39766.0 + 39766.0, + 7, + ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC) ), TTL.D3}, {new FlightData( "PYTON", - emptyList(), - emptyList(), - emptyList(), - emptyList(), - emptyList(), "", - toTimestampList("2019-12-18T05:45:00"), - toTimestampList("2019-12-18T08:05:00"), + "", + "", + "", + "", + "", + toList("2019-12-18T05:45:00"), + toList("2019-12-18T08:05:00"), toList("KRK"), toList("BVA"), toList("FR"), "FR", - 336.258 + 336.258, + 0, + ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC) ), TTL.D14}, {new FlightData( "AVIA", - toTimestampList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), - toTimestampList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), + toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), + toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), toList("LAX", "LHR"), toList("LHR", "PRG"), toList("AA", "BA"), "AA", - toTimestampList("2020-01-28T10:35:00", "2020-01-28T14:40:00"), - toTimestampList("2020-01-28T12:45:00", "2020-01-29T01:45:00"), + toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"), + toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"), toList("PRG", "HEL"), toList("HEL", "LAX"), toList("AY", "AY"), "AY", - 5971.77978 + 5971.77978, + 4, + ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC) ), TTL.D2}, {new FlightData( "HH", - toTimestampList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), - toTimestampList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), + toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), + toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), toList("YVR", "YUL"), toList("YUL", "VIE"), toList("LH", "LH"), "LH", - toTimestampList("2019-10-18T08:10:00", "2019-10-18T11:30:00"), - toTimestampList("2019-10-18T09:40:00", "2019-10-18T21:25:00"), + toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"), + toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"), toList("VIE", "FRA"), toList("FRA", "YVR"), toList("LH", "LH"), "LH", - 17723.0 + 17723.0, + 12, + ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC) ), TTL.D2}, {new FlightData( "unknown", - toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), - toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), + toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), + toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList("unknown"), toList("unknown"), toList("unknown"), "unknown", - toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), - toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), + toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), + toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList("unknown"), toList("unknown"), toList("unknown"), "unknown", - 0.0 + 0.0, + 0, + ZonedDateTime.now() ), TTL.D1} }; } - private List toList(final String... data) { - return Arrays.asList(data); - } - - private List toTimestampList(final String... data) { - return Arrays.stream(data) - .map(x -> x + "Z") - .map(x -> ZonedDateTime.parse(x, DateTimeFormatter.ISO_ZONED_DATE_TIME)).collect(Collectors.toList()); + private String toList(final String... data) { + return String.join("|", data); } }