Array like API

master
EHP 6 years ago
parent e8477fad03
commit 10ca1d2c51
  1. 102
      inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
  2. 68
      inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java

@ -18,6 +18,7 @@ import java.util.Map;
public class CacheInference { public class CacheInference {
private static final TTL[] TTL_VALUES = TTL.values(); private static final TTL[] TTL_VALUES = TTL.values();
private static final int NUM_FEATURES = 20;
private final ObjectMapper mapper = new ObjectMapper(); private final ObjectMapper mapper = new ObjectMapper();
private final Booster booster; private final Booster booster;
@ -37,59 +38,82 @@ public class CacheInference {
labels = loadLabels(labelIS); labels = loadLabels(labelIS);
} }
public TTL cacheTTL(final FlightData data) throws XGBoostError { public TTL[] cacheTTL(final FlightData[] data) throws XGBoostError {
return cacheTTL(data, ZonedDateTime.now()); return cacheTTL(data, ZonedDateTime.now());
} }
/** /**
* Method for backtesting * Method for backtesting
*/ */
public TTL cacheTTL(final FlightData data, final ZonedDateTime now) throws XGBoostError { public TTL[] cacheTTL(final FlightData[] data, final ZonedDateTime now) throws XGBoostError {
final float[] predicts = cacheTTLProbability(data, now); final float[][] predicts = cacheTTLProbability(data, now);
return TTL_VALUES[argmax(predicts)];
if (predicts.length == 0) {
return new TTL[] {};
}
final TTL[] ttl = new TTL[predicts.length];
for (int i = 0; i < predicts.length; i++) {
int idx = 0;
float max = Float.MIN_VALUE;
for (int j = 0; j < predicts[i].length; j++) {
if (predicts[i][j] > max) {
max = predicts[i][j];
idx = j;
}
}
ttl[i] = TTL_VALUES[idx];
}
return ttl;
} }
public float[] cacheTTLProbability(final FlightData data) throws XGBoostError { public float[][] cacheTTLProbability(final FlightData[] data) throws XGBoostError {
return cacheTTLProbability(data, ZonedDateTime.now()); return cacheTTLProbability(data, ZonedDateTime.now());
} }
/** /**
* Method for backtesting * Method for backtesting
*/ */
public float[] cacheTTLProbability(final FlightData data, final ZonedDateTime now) throws XGBoostError { public float[][] cacheTTLProbability(final FlightData[] data, final ZonedDateTime now) throws XGBoostError {
final DMatrix matrix = createMatrix(data, now); final DMatrix matrix = createMatrix(data, now);
float[][] predicts = booster.predict(matrix); return booster.predict(matrix);
return predicts[0];
} }
private Booster loadModel(InputStream model) throws XGBoostError, IOException { private Booster loadModel(InputStream model) throws XGBoostError, IOException {
return XGBoost.loadModel(model); return XGBoost.loadModel(model);
} }
private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError { private DMatrix createMatrix(final FlightData[] data, final ZonedDateTime now) throws XGBoostError {
final float[] arr = new float[20]; final float[] arr = new float[data.length * NUM_FEATURES];
arr[0] = (float) data.getInputPrice(); for (int i = 0; i < data.length; i++) {
arr[1] = labels.get("type").getOrDefault(data.getType(), 0); final FlightData current = data[i];
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0); final int start = NUM_FEATURES * i;
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0); arr[start] = (float) current.getInputPrice();
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0); arr[start + 1] = labels.get("type").getOrDefault(current.getType(), 0);
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0); arr[start + 2] = labels.get("flight.inboundSegments.departure").getOrDefault(current.getInboundDeparture(), 0);
arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0); arr[start + 3] = labels.get("flight.inboundSegments.arrival").getOrDefault(current.getInboundArrival(), 0);
arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0); arr[start + 4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(current.getInboundOrigin(), 0);
arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0); arr[start + 5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(current.getInboundDestination(), 0);
arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0); arr[start + 6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(current.getInboundAirlines(), 0);
arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0); arr[start + 7] = labels.get("flight.inboundMCX.code").getOrDefault(current.getInboundMCXAirlines(), 0);
arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0); arr[start + 8] = labels.get("flight.outboundSegments.departure").getOrDefault(current.getOutboundDeparture(), 0);
arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0); arr[start + 9] = labels.get("flight.outboundSegments.arrival").getOrDefault(current.getOutboundArrival(), 0);
arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0); arr[start + 10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(current.getOutboundOrigin(), 0);
arr[14] = data.getNights(); arr[start + 11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(current.getOutboundDestination(), 0);
arr[15] = computePrebooking(data.getOutDepartureDate(), now); arr[start + 12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(current.getOutboundAirlines(), 0);
arr[16] = now.getDayOfMonth(); arr[start + 13] = labels.get("flight.outboundMCX.code").getOrDefault(current.getOutboundMCXAirlines(), 0);
arr[17] = now.getDayOfWeek().getValue() - 1; arr[start + 14] = current.getNights();
arr[18] = data.getOutDepartureDate().getDayOfMonth(); arr[start + 15] = computePrebooking(current.getOutDepartureDate(), now);
arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1; arr[start + 16] = now.getDayOfMonth();
arr[start + 17] = now.getDayOfWeek().getValue() - 1;
return new DMatrix(arr, 1, arr.length); arr[start + 18] = current.getOutDepartureDate().getDayOfMonth();
arr[start + 19] = current.getOutDepartureDate().getDayOfWeek().getValue() - 1;
}
return new DMatrix(arr, data.length, NUM_FEATURES);
} }
private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) { private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) {
@ -101,18 +125,4 @@ public class CacheInference {
}; };
return mapper.readValue(labels, typeRef); return mapper.readValue(labels, typeRef);
} }
private int argmax(final float[] data) {
int idx = 0;
float max = Float.MIN_VALUE;
for (int i = 0; i < data.length; i++) {
if (data[i] > max) {
max = data[i];
idx = i;
}
}
return idx;
}
} }

@ -29,34 +29,27 @@ public class CacheInferenceTest {
classLoader.getResourceAsStream("model/encoder.json")); classLoader.getResourceAsStream("model/encoder.json"));
} }
@Test(dataProvider = "inferenceData") @Test
void testInference(final FlightData input, final TTL expected) throws XGBoostError { void testInference() throws XGBoostError {
final TTL result = ci.cacheTTL(input, TEST_DATE); final FlightData[] input = {new FlightData(
assertEquals(result, expected); "WS",
} toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
@DataProvider(name = "inferenceData") toList("MCO", "FRA"),
public Object[][] inferenceData() { toList("FRA", "PRG"),
return new Object[][]{ toList("LH", "LH"),
{new FlightData( "LH",
"WS", toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), toList("PRG", "FRA"),
toList("MCO", "FRA"), toList("FRA", "MCO"),
toList("FRA", "PRG"), toList("LH", "LH"),
toList("LH", "LH"), "LH",
"LH", 39766.0,
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), 7,
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC)
toList("PRG", "FRA"), ),
toList("FRA", "MCO"), new FlightData(
toList("LH", "LH"),
"LH",
39766.0,
7,
ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC)
), TTL.D3},
{new FlightData(
"PYTON", "PYTON",
"", "",
"", "",
@ -73,8 +66,8 @@ public class CacheInferenceTest {
336.258, 336.258,
0, 0,
ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC) ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC)
), TTL.D14}, ),
{new FlightData( new FlightData(
"AVIA", "AVIA",
toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
@ -91,8 +84,8 @@ public class CacheInferenceTest {
5971.77978, 5971.77978,
4, 4,
ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC) ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC)
), TTL.D2}, ),
{new FlightData( new FlightData(
"HH", "HH",
toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
@ -109,8 +102,8 @@ public class CacheInferenceTest {
17723.0, 17723.0,
12, 12,
ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC) ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC)
), TTL.D2}, ),
{new FlightData( new FlightData(
"unknown", "unknown",
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)), toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
@ -127,8 +120,11 @@ public class CacheInferenceTest {
0.0, 0.0,
0, 0,
ZonedDateTime.now() ZonedDateTime.now()
), TTL.D1} )};
}; final TTL[] expected = {TTL.D3, TTL.D14, TTL.D2, TTL.D2, TTL.D1};
final TTL[] result = ci.cacheTTL(input, TEST_DATE);
assertEquals(result, expected);
} }
private String toList(final String... data) { private String toList(final String... data) {

Loading…
Cancel
Save