diff --git a/inference/pom.xml b/inference/pom.xml
index 56faa9d..6dceeef 100644
--- a/inference/pom.xml
+++ b/inference/pom.xml
@@ -30,6 +30,13 @@
7.0.0
test
+
+
+
+ org.slf4j
+ slf4j-api
+ 1.7.28
+
diff --git a/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
index 2d19484..1871224 100644
--- a/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
+++ b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
@@ -13,11 +13,8 @@ import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.time.ZonedDateTime;
-import java.time.format.DateTimeFormatter;
import java.time.temporal.ChronoUnit;
-import java.util.List;
import java.util.Map;
-import java.util.stream.Collectors;
public class CacheInference {
private static final TTL[] TTL_VALUES = TTL.values();
@@ -71,40 +68,32 @@ public class CacheInference {
private DMatrix createMatrix(final FlightData data, final ZonedDateTime now) throws XGBoostError {
final float[] arr = new float[20];
- arr[0] = data.getInputPrice().floatValue();
+ arr[0] = (float) data.getInputPrice();
arr[1] = labels.get("type").getOrDefault(data.getType(), 0);
- arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinTimestampList(data.getInboundDeparture()), 0);
- arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinTimestampList(data.getInboundArrival()), 0);
- arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0);
- arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0);
- arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(joinList(data.getInboundAirlines()), 0);
+ arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(data.getInboundDeparture(), 0);
+ arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(data.getInboundArrival(), 0);
+ arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(data.getInboundOrigin(), 0);
+ arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(data.getInboundDestination(), 0);
+ arr[6] = labels.get("flight.inboundSegments.airline.code").getOrDefault(data.getInboundAirlines(), 0);
arr[7] = labels.get("flight.inboundMCX.code").getOrDefault(data.getInboundMCXAirlines(), 0);
- arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(joinTimestampList(data.getOutboundDeparture()), 0);
- arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinTimestampList(data.getOutboundArrival()), 0);
- arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0);
- arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0);
- arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(joinList(data.getOutboundAirlines()), 0);
+ arr[8] = labels.get("flight.outboundSegments.departure").getOrDefault(data.getOutboundDeparture(), 0);
+ arr[9] = labels.get("flight.outboundSegments.arrival").getOrDefault(data.getOutboundArrival(), 0);
+ arr[10] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(data.getOutboundOrigin(), 0);
+ arr[11] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(data.getOutboundDestination(), 0);
+ arr[12] = labels.get("flight.outboundSegments.airline.code").getOrDefault(data.getOutboundAirlines(), 0);
arr[13] = labels.get("flight.outboundMCX.code").getOrDefault(data.getOutboundMCXAirlines(), 0);
- arr[14] = computeDuration(data.getInboundDeparture(), data.getOutboundDeparture());
- arr[15] = computePrebooking(data.getOutboundDeparture(), now);
+ arr[14] = data.getNights();
+ arr[15] = computePrebooking(data.getOutDepartureDate(), now);
arr[16] = now.getDayOfMonth();
arr[17] = now.getDayOfWeek().getValue() - 1;
- arr[18] = data.getOutboundDeparture().get(0).getDayOfMonth();
- arr[19] = data.getOutboundDeparture().get(0).getDayOfWeek().getValue() - 1;
+ arr[18] = data.getOutDepartureDate().getDayOfMonth();
+ arr[19] = data.getOutDepartureDate().getDayOfWeek().getValue() - 1;
return new DMatrix(arr, 1, arr.length);
}
- private float computeDuration(final List indeparture, final List outdeparture) {
- if (indeparture.isEmpty()) {
- return 0;
- }
-
- return ChronoUnit.DAYS.between(outdeparture.get(0), indeparture.get(0));
- }
-
- private float computePrebooking(final List outdeparture, final ZonedDateTime now) {
- return ChronoUnit.DAYS.between(now, outdeparture.get(0));
+ private float computePrebooking(final ZonedDateTime outdeparture, final ZonedDateTime now) {
+ return ChronoUnit.DAYS.between(now, outdeparture);
}
private Map> loadLabels(InputStream labels) throws IOException {
@@ -113,14 +102,6 @@ public class CacheInference {
return mapper.readValue(labels, typeRef);
}
- private String joinTimestampList(final List data) {
- return data.stream().map(DateTimeFormatter.ISO_LOCAL_DATE_TIME::format).collect(Collectors.joining("|"));
- }
-
- private String joinList(final List data) {
- return String.join("|", data);
- }
-
private int argmax(final float[] data) {
int idx = 0;
float max = Float.MIN_VALUE;
diff --git a/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java b/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java
index 969bdd1..904c9a5 100644
--- a/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java
+++ b/inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java
@@ -1,32 +1,31 @@
package cz.aprar.bonitoo.inference;
import java.time.ZonedDateTime;
-import java.util.Collections;
-import java.util.List;
public class FlightData {
private final String type;
- private final List inboundDeparture;
- private final List inboundArrival;
- private final List inboundOrigin;
- private final List inboundDestination;
- private final List inboundAirlines;
+ private final String inboundDeparture;
+ private final String inboundArrival;
+ private final String inboundOrigin;
+ private final String inboundDestination;
+ private final String inboundAirlines;
private final String inboundMCXAirlines;
- private final List outboundDeparture;
- private final List outboundArrival;
- private final List outboundOrigin;
- private final List outboundDestination;
- private final List outboundAirlines;
+ private final String outboundDeparture;
+ private final String outboundArrival;
+ private final String outboundOrigin;
+ private final String outboundDestination;
+ private final String outboundAirlines;
private final String outboundMCXAirlines;
- private final Double inputPrice;
-
- public FlightData(final String type, final List inboundDeparture,
- final List inboundArrival, final List inboundOrigin,
- final List inboundDestination, final List inboundAirlines,
- final String inboundMCXAirlines, final List outboundDeparture,
- final List outboundArrival, final List outboundOrigin,
- final List outboundDestination, final List outboundAirlines,
- final String outboundMCXAirlines, final Double inputPrice) {
+ private final double inputPrice;
+ private final int nights;
+ private final ZonedDateTime outDepartureDate;
+
+ public FlightData(final String type, final String inboundDeparture, final String inboundArrival,
+ final String inboundOrigin, final String inboundDestination, final String inboundAirlines,
+ final String inboundMCXAirlines, final String outboundDeparture, final String outboundArrival,
+ final String outboundOrigin, final String outboundDestination, final String outboundAirlines,
+ final String outboundMCXAirlines, final double inputPrice, final int nights,
+ final ZonedDateTime outDepartureDate) {
this.type = type;
this.inboundDeparture = inboundDeparture;
this.inboundArrival = inboundArrival;
@@ -41,61 +40,71 @@ public class FlightData {
this.outboundAirlines = outboundAirlines;
this.outboundMCXAirlines = outboundMCXAirlines;
this.inputPrice = inputPrice;
+ this.nights = nights;
+ this.outDepartureDate = outDepartureDate;
}
public String getType() {
return type;
}
- public List getInboundDeparture() {
- return Collections.unmodifiableList(inboundDeparture);
+ public String getInboundDeparture() {
+ return inboundDeparture;
}
- public List getInboundArrival() {
- return Collections.unmodifiableList(inboundArrival);
+ public String getInboundArrival() {
+ return inboundArrival;
}
- public List getInboundOrigin() {
- return Collections.unmodifiableList(inboundOrigin);
+ public String getInboundOrigin() {
+ return inboundOrigin;
}
- public List getInboundDestination() {
- return Collections.unmodifiableList(inboundDestination);
+ public String getInboundDestination() {
+ return inboundDestination;
}
- public List getOutboundDeparture() {
- return Collections.unmodifiableList(outboundDeparture);
+ public String getInboundAirlines() {
+ return inboundAirlines;
}
- public List getOutboundArrival() {
- return Collections.unmodifiableList(outboundArrival);
+ public String getInboundMCXAirlines() {
+ return inboundMCXAirlines;
}
- public List getOutboundOrigin() {
- return Collections.unmodifiableList(outboundOrigin);
+ public String getOutboundDeparture() {
+ return outboundDeparture;
}
- public List getOutboundDestination() {
- return Collections.unmodifiableList(outboundDestination);
+ public String getOutboundArrival() {
+ return outboundArrival;
}
- public Double getInputPrice() {
- return inputPrice;
+ public String getOutboundOrigin() {
+ return outboundOrigin;
}
- public List getInboundAirlines() {
- return Collections.unmodifiableList(inboundAirlines);
+ public String getOutboundDestination() {
+ return outboundDestination;
}
- public List getOutboundAirlines() {
- return Collections.unmodifiableList(outboundAirlines);
- }
-
- public String getInboundMCXAirlines() {
- return inboundMCXAirlines;
+ public String getOutboundAirlines() {
+ return outboundAirlines;
}
public String getOutboundMCXAirlines() {
return outboundMCXAirlines;
}
+
+ public double getInputPrice() {
+ return inputPrice;
+ }
+
+ public int getNights() {
+ return nights;
+ }
+
+ public ZonedDateTime getOutDepartureDate() {
+ return outDepartureDate;
+ }
}
diff --git a/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java b/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java
index 6103219..88193e7 100644
--- a/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java
+++ b/inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java
@@ -17,7 +17,8 @@ import static java.util.Collections.emptyList;
import static org.testng.Assert.assertEquals;
public class CacheInferenceTest {
- private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, ZoneId.of("UTC"));
+ private static final ZoneId UTC = ZoneId.of("UTC");
+ private static final ZonedDateTime TEST_DATE = ZonedDateTime.of(2019, 10, 10, 0, 0, 0, 0, UTC);
private CacheInference ci;
@@ -39,94 +40,98 @@ public class CacheInferenceTest {
return new Object[][]{
{new FlightData(
"WS",
- toTimestampList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
- toTimestampList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
+ toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
+ toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
toList("MCO", "FRA"),
toList("FRA", "PRG"),
toList("LH", "LH"),
"LH",
- toTimestampList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
- toTimestampList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
+ toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
+ toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
toList("PRG", "FRA"),
toList("FRA", "MCO"),
toList("LH", "LH"),
"LH",
- 39766.0
+ 39766.0,
+ 7,
+ ZonedDateTime.of(2020, 5, 1, 9, 50, 0, 0, UTC)
), TTL.D3},
{new FlightData(
"PYTON",
- emptyList(),
- emptyList(),
- emptyList(),
- emptyList(),
- emptyList(),
"",
- toTimestampList("2019-12-18T05:45:00"),
- toTimestampList("2019-12-18T08:05:00"),
+ "",
+ "",
+ "",
+ "",
+ "",
+ toList("2019-12-18T05:45:00"),
+ toList("2019-12-18T08:05:00"),
toList("KRK"),
toList("BVA"),
toList("FR"),
"FR",
- 336.258
+ 336.258,
+ 0,
+ ZonedDateTime.of(2019, 12, 18, 5, 45, 0, 0, UTC)
), TTL.D14},
{new FlightData(
"AVIA",
- toTimestampList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
- toTimestampList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
+ toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
+ toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
toList("LAX", "LHR"),
toList("LHR", "PRG"),
toList("AA", "BA"),
"AA",
- toTimestampList("2020-01-28T10:35:00", "2020-01-28T14:40:00"),
- toTimestampList("2020-01-28T12:45:00", "2020-01-29T01:45:00"),
+ toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"),
+ toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"),
toList("PRG", "HEL"),
toList("HEL", "LAX"),
toList("AY", "AY"),
"AY",
- 5971.77978
+ 5971.77978,
+ 4,
+ ZonedDateTime.of(2020, 1, 28, 10, 35, 0, 0, UTC)
), TTL.D2},
{new FlightData(
"HH",
- toTimestampList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
- toTimestampList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
+ toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
+ toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
toList("YVR", "YUL"),
toList("YUL", "VIE"),
toList("LH", "LH"),
"LH",
- toTimestampList("2019-10-18T08:10:00", "2019-10-18T11:30:00"),
- toTimestampList("2019-10-18T09:40:00", "2019-10-18T21:25:00"),
+ toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"),
+ toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"),
toList("VIE", "FRA"),
toList("FRA", "YVR"),
toList("LH", "LH"),
"LH",
- 17723.0
+ 17723.0,
+ 12,
+ ZonedDateTime.of(2019, 10, 18, 8, 10, 0, 0, UTC)
), TTL.D2},
{new FlightData(
"unknown",
- toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
- toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
+ toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
+ toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList("unknown"),
toList("unknown"),
toList("unknown"),
"unknown",
- toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
- toTimestampList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
+ toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
+ toList(ZonedDateTime.now().format(DateTimeFormatter.ISO_LOCAL_DATE_TIME)),
toList("unknown"),
toList("unknown"),
toList("unknown"),
"unknown",
- 0.0
+ 0.0,
+ 0,
+ ZonedDateTime.now()
), TTL.D1}
};
}
- private List toList(final String... data) {
- return Arrays.asList(data);
- }
-
- private List toTimestampList(final String... data) {
- return Arrays.stream(data)
- .map(x -> x + "Z")
- .map(x -> ZonedDateTime.parse(x, DateTimeFormatter.ISO_ZONED_DATE_TIME)).collect(Collectors.toList());
+ private String toList(final String... data) {
+ return String.join("|", data);
}
}