Array like API

master
EHP 6 years ago
parent e8477fad03
commit 10ca1d2c51
  1. 102
      inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
  2. 68
      inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java

@ -18,6 +18,7 @@ import java.util.Map;
public class CacheInference {
private static final TTL[] TTL_VALUES = TTL.values();
private static final int NUM_FEATURES = 20;
private final ObjectMapper mapper = new ObjectMapper();
private final Booster booster;
@ -37,59 +38,82 @@ public class CacheInference {
labels = loadLabels(labelIS);
}
public TTL cacheTTL(final FlightData data) throws XGBoostError {
public TTL[] cacheTTL(final FlightData[] data) throws XGBoostError {
return cacheTTL(data, ZonedDateTime.now());
}
/**
* Method for backtesting
*/
public TTL cacheTTL(final FlightData data, final ZonedDateTime now) throws XGBoostError {
final float[] predicts = cacheTTLProbability(data, now);
return TTL_VALUES[argmax(predicts)];
public TTL[] cacheTTL(final FlightData[] data, final ZonedDateTime now) throws XGBoostError {
final float[][] predicts = cacheTTLProbability(data, now);
if (predicts.length == 0) {
return new TTL[] {};
}
final TTL[] ttl = new TTL[predicts.length];
for (int i = 0; i < predicts.length; i++) {
int idx = 0;
float max = Float.MIN_VALUE;
for (int j = 0; j < predicts[i].length; j++) {
if (predicts[i][j] > max) {
max = predicts[i][j];
idx = j;
}
}
ttl[i] = TTL_VALUES[idx];
}
return ttl;
}
public float[] cacheTTLProbability(final FlightData data) throws XGBoostError {
public float[][] cacheTTLProbability(final FlightData[] data) throws XGBoostError {
return cacheTTLProbability(data, ZonedDateTime.now());
}
/**
* Method for backtesting
*/
public float[] cacheTTLProbability(final FlightData data, final ZonedDateTime now) throws XGBoostError {
public float[][] cacheTTLProbability(final FlightData[] data, final ZonedDateTime now) throws XGBoostError {
final DMatrix matrix = createMatrix(data, now);
float[][] predicts = booster.predict(matrix);
return predicts[0];
return booster.predict(matrix);
}
private Booster loadModel(InputStream model) throws XGBoostError, IOException {
return XGBoost.loadModel(model);
}
private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError {
final float[] arr = new float[20];
arr[0] = (float) data.getInputPrice();
arr[1] = labels.get("type").getOrDefault(data.getType(), 0);
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0);
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0);
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0);
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0);
arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0);
arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0);
arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0);
arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0);
arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0);
arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0);
arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0);
arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0);
arr[14] = data.getNights();
arr[15] = computePrebooking(data.getOutDepartureDate(), now);
arr[16] = now.getDayOfMonth();
arr[17] = now.getDayOfWeek().getValue() - 1;
arr[18] = data.getOutDepartureDate().getDayOfMonth();
arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1;
return new DMatrix(arr, 1, arr.length);
private DMatrix createMatrix(final FlightData[] data, final ZonedDateTime now) throws XGBoostError {
final float[] arr = new float[data.length * NUM_FEATURES];
for (int i = 0; i < data.length; i++) {
final FlightData current = data[i];
final int start = NUM_FEATURES * i;
arr[start] = (float) current.getInputPrice();
arr[start + 1] = labels.get("type").getOrDefault(current.getType(), 0);
arr[start + 2] = labels.get("flight.inboundSegments.departure").getOrDefault(current.getInboundDeparture(), 0);
arr[start + 3] = labels.get("flight.inboundSegments.arrival").getOrDefault(current.getInboundArrival(), 0);
arr[start + 4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(current.getInboundOrigin(), 0);
arr[start + 5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(current.getInboundDestination(), 0);
arr[start + 6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(current.getInboundAirlines(), 0);
arr[start + 7] = labels.get("flight.inboundMCX.code").getOrDefault(current.getInboundMCXAirlines(), 0);
arr[start + 8] = labels.get("flight.outboundSegments.departure").getOrDefault(current.getOutboundDeparture(), 0);
arr[start + 9] = labels.get("flight.outboundSegments.arrival").getOrDefault(current.getOutboundArrival(), 0);
arr[start + 10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(current.getOutboundOrigin(), 0);
arr[start + 11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(current.getOutboundDestination(), 0);
arr[start + 12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(current.getOutboundAirlines(), 0);
arr[start + 13] = labels.get("flight.outboundMCX.code").getOrDefault(current.getOutboundMCXAirlines(), 0);
arr[start + 14] = current.getNights();
arr[start + 15] = computePrebooking(current.getOutDepartureDate(), now);
arr[start + 16] = now.getDayOfMonth();
arr[start + 17] = now.getDayOfWeek().getValue() - 1;
arr[start + 18] = current.getOutDepartureDate().getDayOfMonth();
arr[start + 19] = current.getOutDepartureDate().getDayOfWeek().getValue() - 1;
}
return new DMatrix(arr, data.length, NUM_FEATURES);
}
private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) {
@ -101,18 +125,4 @@ public class CacheInference {
};
return mapper.readValue(labels, typeRef);
}
private int argmax(final float[] data) {
int idx = 0;
float max = Float.MIN_VALUE;
for (int i = 0; i < data.length; i++) {
if (data[i] > max) {
max = data[i];
idx = i;
}
}
return idx;
}
}

@ -29,34 +29,27 @@ public class CacheInferenceTest {
classLoader.getResourceAsStream("model/encoder.json"));
}
@Test(dataProvider = "inferenceData")
void testInference(final FlightData input, final TTL expected) throws XGBoostError {
final TTL result = ci.cacheTTL(input, TEST_DATE);
assertEquals(result, expected);
}
@DataProvider(name = "inferenceData")
public Object[][] inferenceData() {
return new Object[][]{
{new FlightData(
"WS",
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
toList("MCO", "FRA"),
toList("FRA", "PRG"),
toList("LH", "LH"),
"LH",
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
toList("PRG", "FRA"),
toList("FRA", "MCO"),
toList("LH", "LH"),
"LH",
39766.0,
7,
ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC)
), TTL.D3},
{new FlightData(
@Test
void testInference() throws XGBoostError {
final FlightData[] input = {new FlightData(
"WS",
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
toList("MCO", "FRA"),
toList("FRA", "PRG"),
toList("LH", "LH"),
"LH",
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
toList("PRG", "FRA"),
toList("FRA", "MCO"),
toList("LH", "LH"),
"LH",
39766.0,
7,
ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC)
),
new FlightData(
"PYTON",
"",
"",
@ -73,8 +66,8 @@ public class CacheInferenceTest {
336.258,
0,
ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC)
), TTL.D14},
{new FlightData(
),
new FlightData(
"AVIA",
toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
@ -91,8 +84,8 @@ public class CacheInferenceTest {
5971.77978,
4,
ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC)
), TTL.D2},
{new FlightData(
),
new FlightData(
"HH",
toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
@ -109,8 +102,8 @@ public class CacheInferenceTest {
17723.0,
12,
ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC)
), TTL.D2},
{new FlightData(
),
new FlightData(
"unknown",
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
@ -127,8 +120,11 @@ public class CacheInferenceTest {
0.0,
0,
ZonedDateTime.now()
), TTL.D1}
};
)};
final TTL[] expected = {TTL.D3, TTL.D14, TTL.D2, TTL.D2, TTL.D1};
final TTL[] result = ci.cacheTTL(input, TEST_DATE);
assertEquals(result, expected);
}
private String toList(final String... data) {

Loading…
Cancel
Save