parent
76645f273d
commit
d675ca673a
10 changed files with 544 additions and 154 deletions
@ -1,115 +1,76 @@ |
|||||||
from pymongo import MongoClient |
from pymongo import MongoClient |
||||||
from pprint import pprint |
|
||||||
from datetime import datetime |
from datetime import datetime |
||||||
import csv |
import csv |
||||||
|
|
||||||
client = MongoClient() |
client = MongoClient() |
||||||
db=client.bonitoo |
db = client.bonitoo |
||||||
|
|
||||||
fieldnames = [ |
fieldnames = [ |
||||||
'timestamp', |
'timestamp', |
||||||
'client.channel', |
'client.channel', |
||||||
'type', |
'type', |
||||||
'flight.inboundSegments.departure', |
'flight.inboundSegments.departure', |
||||||
'flight.inboundSegments.arrival', |
'flight.inboundSegments.arrival', |
||||||
'flight.inboundSegments.origin.airportCode', |
'flight.inboundSegments.origin.airportCode', |
||||||
'flight.inboundSegments.destination.airportCode', |
'flight.inboundSegments.destination.airportCode', |
||||||
'flight.inboundSegments.flightNumber', |
'flight.outboundSegments.departure', |
||||||
'flight.inboundSegments.travelClass', |
'flight.outboundSegments.arrival', |
||||||
'flight.inboundSegments.bookingCode', |
'flight.outboundSegments.origin.airportCode', |
||||||
'flight.inboundSegments.availability', |
'flight.outboundSegments.destination.airportCode', |
||||||
'flight.inboundSegments.elapsedFlyingTime', |
'input.price', |
||||||
'flight.outboundSegments.departure', |
'input.tax', |
||||||
'flight.outboundSegments.arrival', |
'input.currency', |
||||||
'flight.outboundSegments.origin.airportCode', |
'success', |
||||||
'flight.outboundSegments.destination.airportCode', |
'status', |
||||||
'flight.outboundSegments.flightNumber', |
'output.price', |
||||||
'flight.outboundSegments.travelClass', |
'output.tax', |
||||||
'flight.outboundSegments.bookingCode', |
'output.currency', |
||||||
'flight.outboundSegments.availability', |
'duration' |
||||||
'flight.outboundSegments.elapsedFlyingTime', |
|
||||||
'flight.inboundEFT', # elapsed flying time |
|
||||||
'flight.outboundEFT', |
|
||||||
'oneWay', |
|
||||||
'adults', # pocet osob = (adults + children) |
|
||||||
'children', |
|
||||||
'infants', |
|
||||||
'input.price', |
|
||||||
'input.tax', |
|
||||||
'input.currency', |
|
||||||
'success', |
|
||||||
'status', |
|
||||||
'output.price', |
|
||||||
'output.tax', |
|
||||||
'output.currency', |
|
||||||
'duration' # delka volani do nadrazeneho systemu |
|
||||||
] |
] |
||||||
|
|
||||||
# 5% nebo 200 kc rozdil nahoru |
|
||||||
# -200 kc dolu |
|
||||||
# abs(+-10kc) ignorovat |
|
||||||
|
|
||||||
# timestamp + ok price - ma byt v cache od cacheat |
|
||||||
# timestamp + notok price - nema byt v cache od cacheat |
|
||||||
|
|
||||||
# delka pobytu prilet-odlet |
|
||||||
# delka letu ? |
|
||||||
# pokud je chyba tak nocache (= chybi priceout) |
|
||||||
|
|
||||||
# brat v uvahu in/out kody aerolinek (mcx ?) - mirek jeste zjisti |
|
||||||
|
|
||||||
# vypocitat uspesnost je/neni v cache v % |
|
||||||
|
|
||||||
counter = 0 |
counter = 0 |
||||||
with open('export.csv', mode='w') as ef: |
with open('export.csv', mode='w') as ef: |
||||||
writer = csv.DictWriter(ef, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) |
writer = csv.DictWriter(ef, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) |
||||||
# do not write header for s3 files |
# do not write header for s3 files |
||||||
# writer.writeheader() |
# writer.writeheader() |
||||||
|
|
||||||
|
for it in db.pricing_audit.find(): |
||||||
|
counter += 1 |
||||||
|
if counter % 1000 == 0: |
||||||
|
print('Iterace %d' % counter) |
||||||
|
d = { |
||||||
|
'timestamp': datetime.fromtimestamp(it['timestamp'] / 1000).isoformat(), |
||||||
|
'client.channel': it['client']['channel'], |
||||||
|
'type': it['type'], |
||||||
|
'flight.outboundSegments.departure': '|'.join( |
||||||
|
[x['departure'].isoformat() for x in it['flight']['outboundSegments']]), |
||||||
|
'flight.outboundSegments.arrival': '|'.join( |
||||||
|
[x['arrival'].isoformat() for x in it['flight']['outboundSegments']]), |
||||||
|
'flight.outboundSegments.origin.airportCode': '|'.join( |
||||||
|
[x['origin']['airportCode'] for x in it['flight']['outboundSegments']]), |
||||||
|
'flight.outboundSegments.destination.airportCode': '|'.join( |
||||||
|
[x['destination']['airportCode'] for x in it['flight']['outboundSegments']]), |
||||||
|
'input.price': it['input']['price'], |
||||||
|
'input.tax': it['input']['tax'], |
||||||
|
'input.currency': it['input']['currency'], |
||||||
|
'success': it['success'], |
||||||
|
'status': it.get('status', ''), |
||||||
|
'output.price': it.get('output', {'price': 0})['price'], |
||||||
|
'output.tax': it.get('output', {'tax': 0})['tax'], |
||||||
|
'output.currency': it.get('output', {'currency': 0})['currency'], |
||||||
|
'duration': it['duration'] |
||||||
|
} |
||||||
|
|
||||||
for it in db.pricing_audit.find(): |
if 'inboundSegments' in it['flight']: |
||||||
counter += 1 |
inb = { |
||||||
if counter % 1000 == 0: |
'flight.inboundSegments.departure': '|'.join( |
||||||
print('Iterace %d' % counter) |
[x['departure'].isoformat() for x in it['flight']['inboundSegments']]), |
||||||
d = { |
'flight.inboundSegments.arrival': '|'.join( |
||||||
'timestamp': datetime.fromtimestamp(it['timestamp']/1000).isoformat(), |
[x['arrival'].isoformat() for x in it['flight']['inboundSegments']]), |
||||||
'client.channel': it['client']['channel'], |
'flight.inboundSegments.origin.airportCode': '|'.join( |
||||||
'type': it['type'], |
[x['origin']['airportCode'] for x in it['flight']['inboundSegments']]), |
||||||
'flight.outboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['outboundSegments']]), |
'flight.inboundSegments.destination.airportCode': '|'.join( |
||||||
'flight.outboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['outboundSegments']]), |
[x['destination']['airportCode'] for x in it['flight']['inboundSegments']]), |
||||||
'flight.outboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['outboundSegments']]), |
} |
||||||
'flight.outboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['outboundSegments']]), |
d = {**d, **inb} |
||||||
'flight.outboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['outboundSegments']]), |
writer.writerow(d) |
||||||
'flight.outboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['outboundSegments']]), |
|
||||||
'flight.outboundSegments.bookingCode': '|'.join([x.get('bookingCode','') for x in it['flight']['outboundSegments']]), |
|
||||||
'flight.outboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['outboundSegments']]), |
|
||||||
'flight.outboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['outboundSegments']]), |
|
||||||
'flight.inboundEFT': it['flight'].get('inboundEFT',''), |
|
||||||
'flight.outboundEFT': it['flight'].get('outboundEFT',''), |
|
||||||
'oneWay': it['oneWay'], |
|
||||||
'adults': it['adults'], |
|
||||||
'children': it['children'], |
|
||||||
'infants': it['infants'], |
|
||||||
'input.price': it['input']['price'], |
|
||||||
'input.tax': it['input']['tax'], |
|
||||||
'input.currency': it['input']['currency'], |
|
||||||
'success': it['success'], |
|
||||||
'status': it.get('status',''), |
|
||||||
'output.price': it.get('output', {'price': 0})['price'], |
|
||||||
'output.tax': it.get('output', {'tax': 0})['tax'], |
|
||||||
'output.currency': it.get('output', {'currency': 0})['currency'], |
|
||||||
'duration': it['duration'] |
|
||||||
} |
|
||||||
|
|
||||||
if 'inboundSegments' in it['flight']: |
|
||||||
inb = { |
|
||||||
'flight.inboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.bookingCode': '|'.join([x.get('bookingCode', '') for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['inboundSegments']]), |
|
||||||
'flight.inboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['inboundSegments']]) |
|
||||||
} |
|
||||||
d = {**d, **inb} |
|
||||||
writer.writerow(d) |
|
||||||
|
@ -0,0 +1,58 @@ |
|||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> |
||||||
|
<modelVersion>4.0.0</modelVersion> |
||||||
|
<groupId>cz.aprar.bonitoo</groupId> |
||||||
|
<artifactId>inference</artifactId> |
||||||
|
<packaging>jar</packaging> |
||||||
|
<version>1.0-SNAPSHOT</version> |
||||||
|
<name>inference</name> |
||||||
|
|
||||||
|
<properties> |
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> |
||||||
|
</properties> |
||||||
|
|
||||||
|
<dependencies> |
||||||
|
<dependency> |
||||||
|
<groupId>ml.dmlc</groupId> |
||||||
|
<artifactId>xgboost4j</artifactId> |
||||||
|
<version>0.90</version> |
||||||
|
</dependency> |
||||||
|
|
||||||
|
<dependency> |
||||||
|
<groupId>com.fasterxml.jackson.core</groupId> |
||||||
|
<artifactId>jackson-databind</artifactId> |
||||||
|
<version>2.9.9</version> |
||||||
|
</dependency> |
||||||
|
|
||||||
|
<dependency> |
||||||
|
<groupId>org.testng</groupId> |
||||||
|
<artifactId>testng</artifactId> |
||||||
|
<version>7.0.0</version> |
||||||
|
<scope>test</scope> |
||||||
|
</dependency> |
||||||
|
</dependencies> |
||||||
|
|
||||||
|
<build> |
||||||
|
<plugins> |
||||||
|
<plugin> |
||||||
|
<groupId>org.apache.maven.plugins</groupId> |
||||||
|
<artifactId>maven-compiler-plugin</artifactId> |
||||||
|
<version>3.8.1</version> |
||||||
|
<configuration> |
||||||
|
<source>8</source> |
||||||
|
<target>8</target> |
||||||
|
</configuration> |
||||||
|
</plugin> |
||||||
|
|
||||||
|
<plugin> |
||||||
|
<groupId>org.apache.maven.plugins</groupId> |
||||||
|
<artifactId>maven-pmd-plugin</artifactId> |
||||||
|
<version>3.12.0</version> |
||||||
|
<configuration> |
||||||
|
<sourceEncoding>utf-8</sourceEncoding> |
||||||
|
<targetJdk>1.8</targetJdk> |
||||||
|
</configuration> |
||||||
|
</plugin> |
||||||
|
</plugins> |
||||||
|
</build> |
||||||
|
</project> |
@ -0,0 +1,81 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference; |
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper; |
||||||
|
import ml.dmlc.xgboost4j.java.Booster; |
||||||
|
import ml.dmlc.xgboost4j.java.DMatrix; |
||||||
|
import ml.dmlc.xgboost4j.java.XGBoost; |
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||||
|
|
||||||
|
import java.io.BufferedInputStream; |
||||||
|
import java.io.IOException; |
||||||
|
import java.io.InputStream; |
||||||
|
import java.nio.file.Files; |
||||||
|
import java.nio.file.Path; |
||||||
|
import java.util.List; |
||||||
|
import java.util.Map; |
||||||
|
|
||||||
|
public class CacheInference { |
||||||
|
private static final TTL[] TTL_VALUES = TTL.values(); |
||||||
|
|
||||||
|
private final ObjectMapper mapper = new ObjectMapper(); |
||||||
|
private final Booster booster; |
||||||
|
private final Map<String, Map<String, Integer>> labels; |
||||||
|
|
||||||
|
public CacheInference(final Path modelFile, final Path labelFile) throws IOException, XGBoostError { |
||||||
|
try (BufferedInputStream modelIS = new BufferedInputStream(Files.newInputStream(modelFile)); |
||||||
|
BufferedInputStream labelIS = new BufferedInputStream(Files.newInputStream(labelFile)) |
||||||
|
) { |
||||||
|
booster = loadModel(modelIS); |
||||||
|
labels = loadLabels(labelIS); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
public CacheInference(InputStream modelIS, InputStream labelIS) throws IOException, XGBoostError { |
||||||
|
booster = loadModel(modelIS); |
||||||
|
labels = loadLabels(labelIS); |
||||||
|
} |
||||||
|
|
||||||
|
public TTL cacheTTL(final FlightData data) throws XGBoostError { |
||||||
|
float[][] predicts = booster.predict(createMatrix(data)); |
||||||
|
return TTL_VALUES[(int) predicts[0][0]]; |
||||||
|
} |
||||||
|
|
||||||
|
private Booster loadModel(InputStream model) throws XGBoostError, IOException { |
||||||
|
return XGBoost.loadModel(model); |
||||||
|
} |
||||||
|
|
||||||
|
private DMatrix createMatrix(final FlightData data) throws XGBoostError { |
||||||
|
final float[] arr = new float[18]; |
||||||
|
arr[0] = labels.get("client.channel").getOrDefault(data.getClientChannel(), 0); |
||||||
|
arr[1] = labels.get("type").getOrDefault(data.getType(), 0); |
||||||
|
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinList(data.getInboundDeparture()), 0); |
||||||
|
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinList(data.getInboundArrival()), 0); |
||||||
|
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0); |
||||||
|
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0); |
||||||
|
arr[6] = labels.get("flight.outboundSegments.departure").getOrDefault(joinList(data.getOutboundDeparture()), 0); |
||||||
|
arr[7] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinList(data.getOutboundArrival()), 0); |
||||||
|
arr[8] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0); |
||||||
|
arr[9] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0); |
||||||
|
arr[10] = data.getInputPrice().floatValue(); |
||||||
|
arr[11] = data.getInputTax().floatValue(); |
||||||
|
arr[12] = labels.get("input.currency").getOrDefault(data.getInputCurrency(), 0); |
||||||
|
arr[13] = data.getStatus().floatValue(); |
||||||
|
arr[14] = data.getOutputPrice().floatValue(); |
||||||
|
arr[15] = data.getOutputTax().floatValue(); |
||||||
|
arr[16] = labels.get("output.currency").getOrDefault(data.getOutputCurrency(), 0); |
||||||
|
arr[17] = data.getDuration().floatValue(); |
||||||
|
|
||||||
|
return new DMatrix(arr, 1, arr.length); |
||||||
|
} |
||||||
|
|
||||||
|
private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException { |
||||||
|
final TypeReference<Map<String, Map<String, Integer>>> typeRef = new TypeReference<Map<String, Map<String, Integer>>>() { |
||||||
|
}; |
||||||
|
return mapper.readValue(labels, typeRef); |
||||||
|
} |
||||||
|
|
||||||
|
private String joinList(final List<String> data) { |
||||||
|
return String.join("|", data); |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,123 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
import java.util.Collections; |
||||||
|
import java.util.List; |
||||||
|
|
||||||
|
public class FlightData { |
||||||
|
private final String clientChannel; |
||||||
|
private final String type; |
||||||
|
private final List<String> inboundDeparture; |
||||||
|
private final List<String> inboundArrival; |
||||||
|
private final List<String> inboundOrigin; |
||||||
|
private final List<String> inboundDestination; |
||||||
|
private final List<String> outboundDeparture; |
||||||
|
private final List<String> outboundArrival; |
||||||
|
private final List<String> outboundOrigin; |
||||||
|
private final List<String> outboundDestination; |
||||||
|
private final Double inputPrice; |
||||||
|
private final Double inputTax; |
||||||
|
private final String inputCurrency; |
||||||
|
private final Integer status; |
||||||
|
private final Double outputPrice; |
||||||
|
private final Double outputTax; |
||||||
|
private final String outputCurrency; |
||||||
|
private final Integer duration; |
||||||
|
|
||||||
|
public FlightData(final String clientChannel, final String type, |
||||||
|
final List<String> inboundDeparture, final List<String> inboundArrival, final List<String> inboundOrigin, |
||||||
|
final List<String> inboundDestination, final List<String> outboundDeparture, final List<String> outboundArrival, |
||||||
|
final List<String> outboundOrigin, final List<String> outboundDestination, final Double inputPrice, |
||||||
|
final Double inputTax, final String inputCurrency, final Integer status, final Double outputPrice, |
||||||
|
final Double outputTax, final String outputCurrency, final Integer duration) { |
||||||
|
this.clientChannel = clientChannel; |
||||||
|
this.type = type; |
||||||
|
this.inboundDeparture = inboundDeparture; |
||||||
|
this.inboundArrival = inboundArrival; |
||||||
|
this.inboundOrigin = inboundOrigin; |
||||||
|
this.inboundDestination = inboundDestination; |
||||||
|
this.outboundDeparture = outboundDeparture; |
||||||
|
this.outboundArrival = outboundArrival; |
||||||
|
this.outboundOrigin = outboundOrigin; |
||||||
|
this.outboundDestination = outboundDestination; |
||||||
|
this.inputPrice = inputPrice; |
||||||
|
this.inputTax = inputTax; |
||||||
|
this.inputCurrency = inputCurrency; |
||||||
|
this.status = status; |
||||||
|
this.outputPrice = outputPrice; |
||||||
|
this.outputTax = outputTax; |
||||||
|
this.outputCurrency = outputCurrency; |
||||||
|
this.duration = duration; |
||||||
|
} |
||||||
|
|
||||||
|
public String getClientChannel() { |
||||||
|
return clientChannel; |
||||||
|
} |
||||||
|
|
||||||
|
public String getType() { |
||||||
|
return type; |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundDeparture() { |
||||||
|
return Collections.unmodifiableList(inboundDeparture); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundArrival() { |
||||||
|
return Collections.unmodifiableList(inboundArrival); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundOrigin() { |
||||||
|
return Collections.unmodifiableList(inboundOrigin); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundDestination() { |
||||||
|
return Collections.unmodifiableList(inboundDestination); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundDeparture() { |
||||||
|
return Collections.unmodifiableList(outboundDeparture); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundArrival() { |
||||||
|
return Collections.unmodifiableList(outboundArrival); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundOrigin() { |
||||||
|
return Collections.unmodifiableList(outboundOrigin); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundDestination() { |
||||||
|
return Collections.unmodifiableList(outboundDestination); |
||||||
|
} |
||||||
|
|
||||||
|
public Double getInputPrice() { |
||||||
|
return inputPrice; |
||||||
|
} |
||||||
|
|
||||||
|
public Double getInputTax() { |
||||||
|
return inputTax; |
||||||
|
} |
||||||
|
|
||||||
|
public String getInputCurrency() { |
||||||
|
return inputCurrency; |
||||||
|
} |
||||||
|
|
||||||
|
public Integer getStatus() { |
||||||
|
return status; |
||||||
|
} |
||||||
|
|
||||||
|
public Double getOutputPrice() { |
||||||
|
return outputPrice; |
||||||
|
} |
||||||
|
|
||||||
|
public Double getOutputTax() { |
||||||
|
return outputTax; |
||||||
|
} |
||||||
|
|
||||||
|
public String getOutputCurrency() { |
||||||
|
return outputCurrency; |
||||||
|
} |
||||||
|
|
||||||
|
public Integer getDuration() { |
||||||
|
return duration; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,36 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
public enum TTL { |
||||||
|
/** |
||||||
|
* Very volatile or unknown data |
||||||
|
*/ |
||||||
|
NOCACHE, |
||||||
|
/** |
||||||
|
* Cache for 12 hours |
||||||
|
*/ |
||||||
|
H12, |
||||||
|
/** |
||||||
|
* Cache for 1 day |
||||||
|
*/ |
||||||
|
D1, |
||||||
|
/** |
||||||
|
* Cache for 2 days |
||||||
|
*/ |
||||||
|
D2, |
||||||
|
/** |
||||||
|
* Cache for 3 days |
||||||
|
*/ |
||||||
|
D3, |
||||||
|
/** |
||||||
|
* Cache for 1 week |
||||||
|
*/ |
||||||
|
D7, |
||||||
|
/** |
||||||
|
* Cache for 2 weeks |
||||||
|
*/ |
||||||
|
D14, |
||||||
|
/** |
||||||
|
* Cache for 30 days |
||||||
|
*/ |
||||||
|
D30; |
||||||
|
} |
@ -0,0 +1,140 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||||
|
import org.testng.annotations.BeforeClass; |
||||||
|
import org.testng.annotations.DataProvider; |
||||||
|
import org.testng.annotations.Test; |
||||||
|
|
||||||
|
import java.io.IOException; |
||||||
|
import java.util.Arrays; |
||||||
|
import java.util.List; |
||||||
|
|
||||||
|
import static java.util.Collections.emptyList; |
||||||
|
import static org.testng.Assert.assertEquals; |
||||||
|
|
||||||
|
public class CacheInferenceTest { |
||||||
|
private CacheInference ci; |
||||||
|
|
||||||
|
@BeforeClass |
||||||
|
void setUp() throws IOException, XGBoostError { |
||||||
|
final ClassLoader classLoader = getClass().getClassLoader(); |
||||||
|
ci = new CacheInference(classLoader.getResourceAsStream("model/xgboost-model.bin"), |
||||||
|
classLoader.getResourceAsStream("model/encoder.json")); |
||||||
|
} |
||||||
|
|
||||||
|
@Test(dataProvider = "inferenceData") |
||||||
|
void testInference(final FlightData input, final TTL expected) throws XGBoostError { |
||||||
|
final TTL result = ci.cacheTTL(input); |
||||||
|
assertEquals(result, expected); |
||||||
|
} |
||||||
|
|
||||||
|
@DataProvider(name = "inferenceData") |
||||||
|
public Object[][] inferenceData() { |
||||||
|
return new Object[][]{ |
||||||
|
{new FlightData( |
||||||
|
"fly-me-to", |
||||||
|
"WS", |
||||||
|
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), |
||||||
|
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), |
||||||
|
toList("MCO", "FRA"), |
||||||
|
toList("FRA", "PRG"), |
||||||
|
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), |
||||||
|
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), |
||||||
|
toList("PRG", "FRA"), |
||||||
|
toList("FRA", "MCO"), |
||||||
|
39766.0, |
||||||
|
19776.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
39766.0, |
||||||
|
19776.0, |
||||||
|
"CZK", |
||||||
|
427 |
||||||
|
), TTL.D1}, |
||||||
|
{new FlightData( |
||||||
|
"fly-me-to", |
||||||
|
"PYTON", |
||||||
|
emptyList(), |
||||||
|
emptyList(), |
||||||
|
emptyList(), |
||||||
|
emptyList(), |
||||||
|
toList("2019-12-18T05:45:00"), |
||||||
|
toList("2019-12-18T08:05:00"), |
||||||
|
toList("KRK"), |
||||||
|
toList("BVA"), |
||||||
|
336.258, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
336.258, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
2284 |
||||||
|
), TTL.D7}, |
||||||
|
{new FlightData( |
||||||
|
"levne", |
||||||
|
"AVIA", |
||||||
|
toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), |
||||||
|
toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), |
||||||
|
toList("LAX", "LHR"), |
||||||
|
toList("LHR", "PRG"), |
||||||
|
toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"), |
||||||
|
toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"), |
||||||
|
toList("PRG", "HEL"), |
||||||
|
toList("HEL", "LAX"), |
||||||
|
5971.77978, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
15971.77978, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
551 |
||||||
|
), TTL.D7}, |
||||||
|
{new FlightData( |
||||||
|
"fly-me-to", |
||||||
|
"HH", |
||||||
|
toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), |
||||||
|
toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), |
||||||
|
toList("YVR", "YUL"), |
||||||
|
toList("YUL", "VIE"), |
||||||
|
toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"), |
||||||
|
toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"), |
||||||
|
toList("VIE", "FRA"), |
||||||
|
toList("FRA", "YVR"), |
||||||
|
17723.0, |
||||||
|
7708.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
17723.0, |
||||||
|
7708.0, |
||||||
|
"CZK", |
||||||
|
1786 |
||||||
|
), TTL.D1}, |
||||||
|
{new FlightData( |
||||||
|
"unknown", |
||||||
|
"unknown", |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
0.0, |
||||||
|
0.0, |
||||||
|
"unknown", |
||||||
|
0, |
||||||
|
0.0, |
||||||
|
0.0, |
||||||
|
"unknown", |
||||||
|
0 |
||||||
|
), TTL.NOCACHE} |
||||||
|
}; |
||||||
|
} |
||||||
|
|
||||||
|
private List<String> toList(final String... data) { |
||||||
|
return Arrays.asList(data); |
||||||
|
} |
||||||
|
} |
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -0,0 +1,33 @@ |
|||||||
|
import sagemaker |
||||||
|
import boto3 |
||||||
|
from sagemaker import get_execution_role |
||||||
|
from sagemaker.xgboost.estimator import XGBoost |
||||||
|
|
||||||
|
boto_session = boto3.Session(profile_name='bonitoo', region_name='eu-central-1') |
||||||
|
|
||||||
|
# local: sagemaker_session = sagemaker.LocalSession(boto_session=boto_session) |
||||||
|
sagemaker_session = sagemaker.Session(boto_session=boto_session) |
||||||
|
|
||||||
|
role = 'Bonitoo_SageMaker_Execution' |
||||||
|
train_input = 's3://customers-bonitoo-cachettl/sagemaker/data/export-reduced.csv' |
||||||
|
|
||||||
|
tf = XGBoost( |
||||||
|
entry_point='train_model.py', |
||||||
|
source_dir='./src', |
||||||
|
train_instance_type='ml.c5.xlarge', |
||||||
|
train_instance_count=1, |
||||||
|
role=role, |
||||||
|
sagemaker_session=sagemaker_session, |
||||||
|
framework_version='0.90-1', |
||||||
|
py_version='py3', |
||||||
|
hyperparameters={ |
||||||
|
'bonitoo_price_limit': 1000, |
||||||
|
'num_round': 15, |
||||||
|
'max_depth': 15, |
||||||
|
'eta': 0.5, |
||||||
|
'num_class': 8, |
||||||
|
'objective': 'multi:softmax', |
||||||
|
'eval_metric': 'mlogloss' |
||||||
|
}) |
||||||
|
|
||||||
|
tf.fit({'training': train_input}) |
Loading…
Reference in new issue