diff --git a/export.py b/export.py
index 0b70866..62b330f 100644
--- a/export.py
+++ b/export.py
@@ -1,115 +1,76 @@
from pymongo import MongoClient
-from pprint import pprint
from datetime import datetime
import csv
+
client = MongoClient()
-db=client.bonitoo
+db = client.bonitoo  # handle to the "bonitoo" database on the MongoClient created above
fieldnames = [
-'timestamp',
-'client.channel',
-'type',
-'flight.inboundSegments.departure',
-'flight.inboundSegments.arrival',
-'flight.inboundSegments.origin.airportCode',
-'flight.inboundSegments.destination.airportCode',
-'flight.inboundSegments.flightNumber',
-'flight.inboundSegments.travelClass',
-'flight.inboundSegments.bookingCode',
-'flight.inboundSegments.availability',
-'flight.inboundSegments.elapsedFlyingTime',
-'flight.outboundSegments.departure',
-'flight.outboundSegments.arrival',
-'flight.outboundSegments.origin.airportCode',
-'flight.outboundSegments.destination.airportCode',
-'flight.outboundSegments.flightNumber',
-'flight.outboundSegments.travelClass',
-'flight.outboundSegments.bookingCode',
-'flight.outboundSegments.availability',
-'flight.outboundSegments.elapsedFlyingTime',
-'flight.inboundEFT', # elapsed flying time
-'flight.outboundEFT',
-'oneWay',
-'adults', # pocet osob = (adults + children)
-'children',
-'infants',
-'input.price',
-'input.tax',
-'input.currency',
-'success',
-'status',
-'output.price',
-'output.tax',
-'output.currency',
-'duration' # delka volani do nadrazeneho systemu
+ 'timestamp',  # list order = CSV column order; this script does not write a header row
+ 'client.channel',
+ 'type',
+ 'flight.inboundSegments.departure',  # inbound.* values are set only for records that have flight.inboundSegments
+ 'flight.inboundSegments.arrival',
+ 'flight.inboundSegments.origin.airportCode',
+ 'flight.inboundSegments.destination.airportCode',
+ 'flight.outboundSegments.departure',  # segment columns hold one '|'-joined string per record
+ 'flight.outboundSegments.arrival',
+ 'flight.outboundSegments.origin.airportCode',
+ 'flight.outboundSegments.destination.airportCode',
+ 'input.price',
+ 'input.tax',
+ 'input.currency',
+ 'success',
+ 'status',
+ 'output.price',  # output.* values fall back to 0 when the record has no 'output'
+ 'output.tax',
+ 'output.currency',
+ 'duration'
]
-# 5% nebo 200 kc rozdil nahoru
-# -200 kc dolu
-# abs(+-10kc) ignorovat
-
-# timestamp + ok price - ma byt v cache od cacheat
-# timestamp + notok price - nema byt v cache od cacheat
-
-# delka pobytu prilet-odlet
-# delka letu ?
-# pokud je chyba tak nocache (= chybi priceout)
-
-# brat v uvahu in/out kody aerolinek (mcx ?) - mirek jeste zjisti
-
-# vypocitat uspesnost je/neni v cache v %
-
counter = 0
with open('export.csv', mode='w') as ef:
- writer = csv.DictWriter(ef, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
- # do not write header for s3 files
- # writer.writeheader()
+ writer = csv.DictWriter(ef, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)  # QUOTE_MINIMAL quotes only fields containing special characters
+ # The header row is intentionally not written because the output is meant for S3 files.
+ # writer.writeheader()
+
+ for it in db.pricing_audit.find():  # find() without a filter iterates every document in pricing_audit
+ counter += 1  # running document count, drives the progress output below
+ if counter % 1000 == 0:  # report progress every 1000 documents
+ print('Iterace %d' % counter)  # "Iterace" is Czech for "iteration"; runtime string kept as-is
+ d = {  # flat CSV row; keys must be names listed in fieldnames
+ 'timestamp': datetime.fromtimestamp(it['timestamp'] / 1000).isoformat(),  # value is divided by 1000, i.e. treated as epoch milliseconds; NOTE(review): no tz is passed, so this is naive local time
+ 'client.channel': it['client']['channel'],
+ 'type': it['type'],
+ 'flight.outboundSegments.departure': '|'.join(  # one entry per outbound segment, joined with '|'
+ [x['departure'].isoformat() for x in it['flight']['outboundSegments']]),  # presumably datetime values (isoformat() is called on them) - TODO confirm
+ 'flight.outboundSegments.arrival': '|'.join(
+ [x['arrival'].isoformat() for x in it['flight']['outboundSegments']]),
+ 'flight.outboundSegments.origin.airportCode': '|'.join(
+ [x['origin']['airportCode'] for x in it['flight']['outboundSegments']]),
+ 'flight.outboundSegments.destination.airportCode': '|'.join(
+ [x['destination']['airportCode'] for x in it['flight']['outboundSegments']]),
+ 'input.price': it['input']['price'],  # 'input' and its three keys are required; a missing one raises KeyError
+ 'input.tax': it['input']['tax'],
+ 'input.currency': it['input']['currency'],
+ 'success': it['success'],
+ 'status': it.get('status', ''),  # 'status' is optional; empty string when absent
+ 'output.price': it.get('output', {'price': 0})['price'],  # 0 when 'output' is absent; an 'output' without 'price' still raises KeyError
+ 'output.tax': it.get('output', {'tax': 0})['tax'],
+ 'output.currency': it.get('output', {'currency': 0})['currency'],  # NOTE(review): fallback is the number 0, not a currency string - confirm this is intended
+ 'duration': it['duration']
+ }
- for it in db.pricing_audit.find():
- counter += 1
- if counter % 1000 == 0:
- print('Iterace %d' % counter)
- d = {
- 'timestamp': datetime.fromtimestamp(it['timestamp']/1000).isoformat(),
- 'client.channel': it['client']['channel'],
- 'type': it['type'],
- 'flight.outboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.bookingCode': '|'.join([x.get('bookingCode','') for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['outboundSegments']]),
- 'flight.outboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['outboundSegments']]),
- 'flight.inboundEFT': it['flight'].get('inboundEFT',''),
- 'flight.outboundEFT': it['flight'].get('outboundEFT',''),
- 'oneWay': it['oneWay'],
- 'adults': it['adults'],
- 'children': it['children'],
- 'infants': it['infants'],
- 'input.price': it['input']['price'],
- 'input.tax': it['input']['tax'],
- 'input.currency': it['input']['currency'],
- 'success': it['success'],
- 'status': it.get('status',''),
- 'output.price': it.get('output', {'price': 0})['price'],
- 'output.tax': it.get('output', {'tax': 0})['tax'],
- 'output.currency': it.get('output', {'currency': 0})['currency'],
- 'duration': it['duration']
- }
-
- if 'inboundSegments' in it['flight']:
- inb = {
- 'flight.inboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.bookingCode': '|'.join([x.get('bookingCode', '') for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['inboundSegments']]),
- 'flight.inboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['inboundSegments']])
- }
- d = {**d, **inb}
- writer.writerow(d)
+ if 'inboundSegments' in it['flight']:  # inbound segments are optional; their columns are added only when present
+ inb = {  # same '|'-joined layout as the outbound columns
+ 'flight.inboundSegments.departure': '|'.join(
+ [x['departure'].isoformat() for x in it['flight']['inboundSegments']]),
+ 'flight.inboundSegments.arrival': '|'.join(
+ [x['arrival'].isoformat() for x in it['flight']['inboundSegments']]),
+ 'flight.inboundSegments.origin.airportCode': '|'.join(
+ [x['origin']['airportCode'] for x in it['flight']['inboundSegments']]),
+ 'flight.inboundSegments.destination.airportCode': '|'.join(
+ [x['destination']['airportCode'] for x in it['flight']['inboundSegments']]),
+ }
+ d = {**d, **inb}  # merge the inbound columns into the row dict
+ writer.writerow(d)  # fieldnames missing from d are written as DictWriter's restval (empty string by default)
diff --git a/inference/pom.xml b/inference/pom.xml
new file mode 100644
index 0000000..56faa9d
--- /dev/null
+++ b/inference/pom.xml
@@ -0,0 +1,58 @@
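+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+    <modelVersion>4.0.0</modelVersion>
+    <groupId>cz.aprar.bonitoo</groupId>
+    <artifactId>inference</artifactId>
+    <packaging>jar</packaging>
+    <version>1.0-SNAPSHOT</version>
+    <name>inference</name>
+
+    <properties>
+        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+    </properties>
+
+    <dependencies>
+        <dependency>
+            <groupId>ml.dmlc</groupId>
+            <artifactId>xgboost4j</artifactId>
+            <version>0.90</version>
+        </dependency>
+
+        <dependency>
+            <groupId>com.fasterxml.jackson.core</groupId>
+            <artifactId>jackson-databind</artifactId>
+            <version>2.9.9</version>
+        </dependency>
+
+        <dependency>
+            <groupId>org.testng</groupId>
+            <artifactId>testng</artifactId>
+            <version>7.0.0</version>
+            <scope>test</scope>
+        </dependency>
+    </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-compiler-plugin</artifactId>
+                <version>3.8.1</version>
+                <configuration>
+                    <source>8</source>
+                    <target>8</target>
+                </configuration>
+            </plugin>
+
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-pmd-plugin</artifactId>
+                <version>3.12.0</version>
+                <configuration>
+                    <sourceEncoding>utf-8</sourceEncoding>
+                    <targetJdk>1.8</targetJdk>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
+</project>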
diff --git a/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
new file mode 100644
index 0000000..d457684
--- /dev/null
+++ b/inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
@@ -0,0 +1,81 @@
+package cz.aprar.bonitoo.inference;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import ml.dmlc.xgboost4j.java.Booster;
+import ml.dmlc.xgboost4j.java.DMatrix;
+import ml.dmlc.xgboost4j.java.XGBoost;
+import ml.dmlc.xgboost4j.java.XGBoostError;
+
+import java.io.BufferedInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.Map;
+
+public class CacheInference {
+ private static final TTL[] TTL_VALUES = TTL.values(); // cached once; cacheTTL() indexes into this array
+
+ private final ObjectMapper mapper = new ObjectMapper(); // Jackson mapper; presumably used by loadLabels() to parse the label stream
+ private final Booster booster; // assigned in both constructors from loadModel()
+ private final Map<String, Map<String, Integer>> labels; // feature name -> (raw value -> numeric code); type arguments restored after being lost in extraction
+
+ public CacheInference(final Path modelFile, final Path labelFile) throws IOException, XGBoostError { // builds the instance from a model file and a label file on disk
+ try (BufferedInputStream modelIS = new BufferedInputStream(Files.newInputStream(modelFile)); // try-with-resources: both streams are closed when the block exits
+ BufferedInputStream labelIS = new BufferedInputStream(Files.newInputStream(labelFile))
+ ) {
+ booster = loadModel(modelIS); // model is read first
+ labels = loadLabels(labelIS); // then the label table
+ }
+ }
+
+ public CacheInference(InputStream modelIS, InputStream labelIS) throws IOException, XGBoostError { // builds the instance from already opened streams
+ booster = loadModel(modelIS); // the streams are not closed in this constructor
+ labels = loadLabels(labelIS);
+ }
+
+ public TTL cacheTTL(final FlightData data) throws XGBoostError { // runs the booster on the feature row built from data and maps the result to a TTL constant
+ float[][] predicts = booster.predict(createMatrix(data)); // createMatrix() builds a single-row matrix; only predicts[0] is read
+ return TTL_VALUES[(int) predicts[0][0]]; // first output is cast to int and used as an index into TTL.values(); NOTE(review): no bounds check, an out-of-range value throws ArrayIndexOutOfBoundsException
+ }
+
+ private Booster loadModel(InputStream model) throws XGBoostError, IOException { // thin wrapper: loads the booster from the given stream
+ return XGBoost.loadModel(model); // the stream is not closed here
+ }
+
+ private DMatrix createMatrix(final FlightData data) throws XGBoostError { // converts one FlightData into a single-row feature matrix
+ final float[] arr = new float[18]; // 18 features in fixed positions; presumably the column order the model was trained with - TODO confirm
+ arr[0] = labels.get("client.channel").getOrDefault(data.getClientChannel(), 0); // label-table lookup, values not in the table map to 0; NOTE(review): labels.get(...) returns null for a feature missing from the table, which would throw NullPointerException
+ arr[1] = labels.get("type").getOrDefault(data.getType(), 0);
+ arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinList(data.getInboundDeparture()), 0); // the joinList(...) result is the lookup key for segment lists
+ arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinList(data.getInboundArrival()), 0);
+ arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0);
+ arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0);
+ arr[6] = labels.get("flight.outboundSegments.departure").getOrDefault(joinList(data.getOutboundDeparture()), 0);
+ arr[7] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinList(data.getOutboundArrival()), 0);
+ arr[8] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0);
+ arr[9] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0);
+ arr[10] = data.getInputPrice().floatValue(); // numeric features are passed through as float; NOTE(review): floatValue() throws NullPointerException if the getter returns null - confirm FlightData guarantees non-null
+ arr[11] = data.getInputTax().floatValue();
+ arr[12] = labels.get("input.currency").getOrDefault(data.getInputCurrency(), 0);
+ arr[13] = data.getStatus().floatValue();
+ arr[14] = data.getOutputPrice().floatValue();
+ arr[15] = data.getOutputTax().floatValue();
+ arr[16] = labels.get("output.currency").getOrDefault(data.getOutputCurrency(), 0);
+ arr[17] = data.getDuration().floatValue();
+
+ return new DMatrix(arr, 1, arr.length); // dense matrix: 1 row, arr.length (18) columns
+ }
+
+ private Map> loadLabels(InputStream labels) throws IOException {
+ final TypeReference