diff --git a/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java index 1871224..06ab8b0 100644 --- a/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java +++ b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java @@ -18,6 +18,7 @@ import java.util.Map; public class CacheInference { private static final TTL[] TTL_VALUES = TTL.values(); + private static final int NUM_FEATURES = 20; private final ObjectMapper mapper = new ObjectMapper(); private final Booster booster; @@ -37,59 +38,82 @@ public class CacheInference { labels = loadLabels(labelIS); } - public TTL cacheTTL(final FlightData data) throws XGBoostError { + public TTL[] cacheTTL(final FlightData[] data) throws XGBoostError { return cacheTTL(data, ZonedDateTime.now()); } /** * Method for backtesting */ - public TTL cacheTTL(final FlightData data, final ZonedDateTime now) throws XGBoostError { - final float[] predicts = cacheTTLProbability(data, now); - return TTL_VALUES[argmax(predicts)]; + public TTL[] cacheTTL(final FlightData[] data, final ZonedDateTime now) throws XGBoostError { + final float[][] predicts = cacheTTLProbability(data, now); + + if (predicts.length == 0) { + return new TTL[] {}; + } + + final TTL[] ttl = new TTL[predicts.length]; + for (int i = 0; i < predicts.length; i++) { + int idx = 0; + float max = Float.MIN_VALUE; + + for (int j = 0; j < predicts[i].length; j++) { + if (predicts[i][j] > max) { + max = predicts[i][j]; + idx = j; + } + } + + ttl[i] = TTL_VALUES[idx]; + } + + return ttl; } - public float[] cacheTTLProbability(final FlightData data) throws XGBoostError { + public float[][] cacheTTLProbability(final FlightData[] data) throws XGBoostError { return cacheTTLProbability(data, ZonedDateTime.now()); } /** * Method for backtesting */ - public float[] cacheTTLProbability(final FlightData data, final ZonedDateTime now) throws XGBoostError { + public float[][] cacheTTLProbability(final FlightData[] data, final ZonedDateTime now) throws XGBoostError { final DMatrix matrix = createMatrix(data, now); - float[][] predicts = booster.predict(matrix); - return predicts[0]; + return booster.predict(matrix); } private Booster loadModel(InputStream model) throws XGBoostError, IOException { return XGBoost.loadModel(model); } - private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError { - final float[] arr = new float[20]; - arr[0] = (float) data.getInputPrice(); - arr[1] = labels.get("type").getOrDefault(data.getType(), 0); - arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0); - arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0); - arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0); - arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0); - arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0); - arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0); - arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0); - arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0); - arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0); - arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0); - arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0); - arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0); - arr[14] = data.getNights(); - arr[15] = computePrebooking(data.getOutDepartureDate(), now); - arr[16] = now.getDayOfMonth(); - arr[17] = now.getDayOfWeek().getValue() - 1; - arr[18] = data.getOutDepartureDate().getDayOfMonth(); - arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1; - - return new DMatrix(arr, 1, arr.length); + private DMatrix createMatrix(final FlightData[] data, final ZonedDateTime now) throws XGBoostError { + final float[] arr = new float[data.length * NUM_FEATURES]; + for (int i = 0; i < data.length; i++) { + final FlightData current = data[i]; + final int start = NUM_FEATURES * i; + arr[start] = (float) current.getInputPrice(); + arr[start + 1] = labels.get("type").getOrDefault(current.getType(), 0); + arr[start + 2] = labels.get("flight.inboundSegments.departure").getOrDefault(current.getInboundDeparture(), 0); + arr[start + 3] = labels.get("flight.inboundSegments.arrival").getOrDefault(current.getInboundArrival(), 0); + arr[start + 4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(current.getInboundOrigin(), 0); + arr[start + 5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(current.getInboundDestination(), 0); + arr[start + 6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(current.getInboundAirlines(), 0); + arr[start + 7] = labels.get("flight.inboundMCX.code").getOrDefault(current.getInboundMCXAirlines(), 0); + arr[start + 8] = labels.get("flight.outboundSegments.departure").getOrDefault(current.getOutboundDeparture(), 0); + arr[start + 9] = labels.get("flight.outboundSegments.arrival").getOrDefault(current.getOutboundArrival(), 0); + arr[start + 10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(current.getOutboundOrigin(), 0); + arr[start + 11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(current.getOutboundDestination(), 0); + arr[start + 12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(current.getOutboundAirlines(), 0); + arr[start + 13] = labels.get("flight.outboundMCX.code").getOrDefault(current.getOutboundMCXAirlines(), 0); + arr[start + 14] = current.getNights(); + arr[start + 15] = computePrebooking(current.getOutDepartureDate(), now); + arr[start + 16] = now.getDayOfMonth(); + arr[start + 17] = now.getDayOfWeek().getValue() - 1; + arr[start + 18] = current.getOutDepartureDate().getDayOfMonth(); + arr[start + 19] = current.getOutDepartureDate().getDayOfWeek().getValue() - 1; + } + + return new DMatrix(arr, data.length, NUM_FEATURES); } private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) { @@ -101,18 +125,4 @@ public class CacheInference { }; return mapper.readValue(labels, typeRef); } - - private int argmax(final float[] data) { - int idx = 0; - float max = Float.MIN_VALUE; - - for (int i = 0; i < data.length; i++) { - if (data[i] > max) { - max = data[i]; - idx = i; - } - } - - return idx; - } } diff --git a/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java b/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java index 88193e7..81e7303 100644 --- a/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java +++ b/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java @@ -29,34 +29,27 @@ public class CacheInferenceTest { classLoader.getResourceAsStream("model/encoder.json")); } - @Test(dataProvider = "inferenceData") - void testInference(final FlightData input, final TTL expected) throws XGBoostError { - final TTL result = ci.cacheTTL(input, TEST_DATE); - assertEquals(result, expected); - } - - @DataProvider(name = "inferenceData") - public Object[][] inferenceData() { - return new Object[][]{ - {new FlightData( - "WS", - toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), - toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), - toList("MCO", "FRA"), - toList("FRA", "PRG"), - toList("LH", "LH"), - "LH", - toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), - toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), - toList("PRG", "FRA"), - toList("FRA", "MCO"), - toList("LH", "LH"), - "LH", - 39766.0, - 7, - ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC) - ), TTL.D3}, - {new FlightData( + @Test + void testInference() throws XGBoostError { + final FlightData[] input = {new FlightData( + "WS", + toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), + toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), + toList("MCO", "FRA"), + toList("FRA", "PRG"), + toList("LH", "LH"), + "LH", + toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), + toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), + toList("PRG", "FRA"), + toList("FRA", "MCO"), + toList("LH", "LH"), + "LH", + 39766.0, + 7, + ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC) + ), + new FlightData( "PYTON", "", "", @@ -73,8 +66,8 @@ public class CacheInferenceTest { 336.258, 0, ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC) - ), TTL.D14}, - {new FlightData( + ), + new FlightData( "AVIA", toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), @@ -91,8 +84,8 @@ public class CacheInferenceTest { 5971.77978, 4, ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC) - ), TTL.D2}, - {new FlightData( + ), + new FlightData( "HH", toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), @@ -109,8 +102,8 @@ public class CacheInferenceTest { 17723.0, 12, ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC) - ), TTL.D2}, - {new FlightData( + ), + new FlightData( "unknown", toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), @@ -127,8 +120,11 @@ public class CacheInferenceTest { 0.0, 0, ZonedDateTime.now() - ), TTL.D1} - }; + )}; + final TTL[] expected = {TTL.D3, TTL.D14, TTL.D2, TTL.D2, TTL.D1}; + + final TTL[] result = ci.cacheTTL(input, TEST_DATE); + assertEquals(result, expected); } private String toList(final String... data) {