|
|
|
|
@ -18,6 +18,7 @@ import java.util.Map; |
|
|
|
|
|
|
|
|
|
public class CacheInference { |
|
|
|
|
private static final TTL[] TTL_VALUES = TTL.values(); |
|
|
|
|
private static final int NUM_FEATURES = 20; |
|
|
|
|
|
|
|
|
|
private final ObjectMapper mapper = new ObjectMapper(); |
|
|
|
|
private final Booster booster; |
|
|
|
|
@ -37,59 +38,82 @@ public class CacheInference { |
|
|
|
|
labels = loadLabels(labelIS); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public TTL cacheTTL(final FlightData data) throws XGBoostError { |
|
|
|
|
public TTL[] cacheTTL(final FlightData[] data) throws XGBoostError { |
|
|
|
|
return cacheTTL(data, ZonedDateTime.now()); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* Method for backtesting |
|
|
|
|
*/ |
|
|
|
|
public TTL cacheTTL(final FlightData data, final ZonedDateTime now) throws XGBoostError { |
|
|
|
|
final float[] predicts = cacheTTLProbability(data, now); |
|
|
|
|
return TTL_VALUES[argmax(predicts)]; |
|
|
|
|
public TTL[] cacheTTL(final FlightData[] data, final ZonedDateTime now) throws XGBoostError { |
|
|
|
|
final float[][] predicts = cacheTTLProbability(data, now); |
|
|
|
|
|
|
|
|
|
if (predicts.length == 0) { |
|
|
|
|
return new TTL[] {}; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
final TTL[] ttl = new TTL[predicts.length]; |
|
|
|
|
for (int i = 0; i < predicts.length; i++) { |
|
|
|
|
int idx = 0; |
|
|
|
|
float max = Float.MIN_VALUE; |
|
|
|
|
|
|
|
|
|
for (int j = 0; j < predicts[i].length; j++) { |
|
|
|
|
if (predicts[i][j] > max) { |
|
|
|
|
max = predicts[i][j]; |
|
|
|
|
idx = j; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ttl[i] = TTL_VALUES[idx]; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return ttl; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
public float[] cacheTTLProbability(final FlightData data) throws XGBoostError { |
|
|
|
|
public float[][] cacheTTLProbability(final FlightData[] data) throws XGBoostError { |
|
|
|
|
return cacheTTLProbability(data, ZonedDateTime.now()); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
/** |
|
|
|
|
* Method for backtesting |
|
|
|
|
*/ |
|
|
|
|
public float[] cacheTTLProbability(final FlightData data, final ZonedDateTime now) throws XGBoostError { |
|
|
|
|
public float[][] cacheTTLProbability(final FlightData[] data, final ZonedDateTime now) throws XGBoostError { |
|
|
|
|
final DMatrix matrix = createMatrix(data, now); |
|
|
|
|
float[][] predicts = booster.predict(matrix); |
|
|
|
|
return predicts[0]; |
|
|
|
|
return booster.predict(matrix); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private Booster loadModel(InputStream model) throws XGBoostError, IOException { |
|
|
|
|
return XGBoost.loadModel(model); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError { |
|
|
|
|
final float[] arr = new float[20]; |
|
|
|
|
arr[0] = (float) data.getInputPrice(); |
|
|
|
|
arr[1] = labels.get("type").getOrDefault(data.getType(), 0); |
|
|
|
|
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0); |
|
|
|
|
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0); |
|
|
|
|
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0); |
|
|
|
|
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0); |
|
|
|
|
arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0); |
|
|
|
|
arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0); |
|
|
|
|
arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0); |
|
|
|
|
arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0); |
|
|
|
|
arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0); |
|
|
|
|
arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0); |
|
|
|
|
arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0); |
|
|
|
|
arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0); |
|
|
|
|
arr[14] = data.getNights(); |
|
|
|
|
arr[15] = computePrebooking(data.getOutDepartureDate(), now); |
|
|
|
|
arr[16] = now.getDayOfMonth(); |
|
|
|
|
arr[17] = now.getDayOfWeek().getValue() - 1; |
|
|
|
|
arr[18] = data.getOutDepartureDate().getDayOfMonth(); |
|
|
|
|
arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1; |
|
|
|
|
|
|
|
|
|
return new DMatrix(arr, 1, arr.length); |
|
|
|
|
private DMatrix createMatrix(final FlightData[] data, final ZonedDateTime now) throws XGBoostError { |
|
|
|
|
final float[] arr = new float[data.length * NUM_FEATURES]; |
|
|
|
|
for (int i = 0; i < data.length; i++) { |
|
|
|
|
final FlightData current = data[i]; |
|
|
|
|
final int start = NUM_FEATURES * i; |
|
|
|
|
arr[start] = (float) current.getInputPrice(); |
|
|
|
|
arr[start + 1] = labels.get("type").getOrDefault(current.getType(), 0); |
|
|
|
|
arr[start + 2] = labels.get("flight.inboundSegments.departure").getOrDefault(current.getInboundDeparture(), 0); |
|
|
|
|
arr[start + 3] = labels.get("flight.inboundSegments.arrival").getOrDefault(current.getInboundArrival(), 0); |
|
|
|
|
arr[start + 4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(current.getInboundOrigin(), 0); |
|
|
|
|
arr[start + 5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(current.getInboundDestination(), 0); |
|
|
|
|
arr[start + 6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(current.getInboundAirlines(), 0); |
|
|
|
|
arr[start + 7] = labels.get("flight.inboundMCX.code").getOrDefault(current.getInboundMCXAirlines(), 0); |
|
|
|
|
arr[start + 8] = labels.get("flight.outboundSegments.departure").getOrDefault(current.getOutboundDeparture(), 0); |
|
|
|
|
arr[start + 9] = labels.get("flight.outboundSegments.arrival").getOrDefault(current.getOutboundArrival(), 0); |
|
|
|
|
arr[start + 10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(current.getOutboundOrigin(), 0); |
|
|
|
|
arr[start + 11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(current.getOutboundDestination(), 0); |
|
|
|
|
arr[start + 12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(current.getOutboundAirlines(), 0); |
|
|
|
|
arr[start + 13] = labels.get("flight.outboundMCX.code").getOrDefault(current.getOutboundMCXAirlines(), 0); |
|
|
|
|
arr[start + 14] = current.getNights(); |
|
|
|
|
arr[start + 15] = computePrebooking(current.getOutDepartureDate(), now); |
|
|
|
|
arr[start + 16] = now.getDayOfMonth(); |
|
|
|
|
arr[start + 17] = now.getDayOfWeek().getValue() - 1; |
|
|
|
|
arr[start + 18] = current.getOutDepartureDate().getDayOfMonth(); |
|
|
|
|
arr[start + 19] = current.getOutDepartureDate().getDayOfWeek().getValue() - 1; |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return new DMatrix(arr, data.length, NUM_FEATURES); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) { |
|
|
|
|
@ -101,18 +125,4 @@ public class CacheInference { |
|
|
|
|
}; |
|
|
|
|
return mapper.readValue(labels, typeRef); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
private int argmax(final float[] data) { |
|
|
|
|
int idx = 0; |
|
|
|
|
float max = Float.MIN_VALUE; |
|
|
|
|
|
|
|
|
|
for (int i = 0; i < data.length; i++) { |
|
|
|
|
if (data[i] > max) { |
|
|
|
|
max = data[i]; |
|
|
|
|
idx = i; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return idx; |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|