parent
76645f273d
commit
d675ca673a
10 changed files with 544 additions and 154 deletions
@ -0,0 +1,58 @@ |
|||||||
|
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |
||||||
|
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> |
||||||
|
<modelVersion>4.0.0</modelVersion> |
||||||
|
<groupId>cz.aprar.bonitoo</groupId> |
||||||
|
<artifactId>inference</artifactId> |
||||||
|
<packaging>jar</packaging> |
||||||
|
<version>1.0-SNAPSHOT</version> |
||||||
|
<name>inference</name> |
||||||
|
|
||||||
|
<properties> |
||||||
|
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> |
||||||
|
</properties> |
||||||
|
|
||||||
|
<dependencies> |
||||||
|
<dependency> |
||||||
|
<groupId>ml.dmlc</groupId> |
||||||
|
<artifactId>xgboost4j</artifactId> |
||||||
|
<version>0.90</version> |
||||||
|
</dependency> |
||||||
|
|
||||||
|
<dependency> |
||||||
|
<groupId>com.fasterxml.jackson.core</groupId> |
||||||
|
<artifactId>jackson-databind</artifactId> |
||||||
|
<version>2.9.9</version> |
||||||
|
</dependency> |
||||||
|
|
||||||
|
<dependency> |
||||||
|
<groupId>org.testng</groupId> |
||||||
|
<artifactId>testng</artifactId> |
||||||
|
<version>7.0.0</version> |
||||||
|
<scope>test</scope> |
||||||
|
</dependency> |
||||||
|
</dependencies> |
||||||
|
|
||||||
|
<build> |
||||||
|
<plugins> |
||||||
|
<plugin> |
||||||
|
<groupId>org.apache.maven.plugins</groupId> |
||||||
|
<artifactId>maven-compiler-plugin</artifactId> |
||||||
|
<version>3.8.1</version> |
||||||
|
<configuration> |
||||||
|
<source>8</source> |
||||||
|
<target>8</target> |
||||||
|
</configuration> |
||||||
|
</plugin> |
||||||
|
|
||||||
|
<plugin> |
||||||
|
<groupId>org.apache.maven.plugins</groupId> |
||||||
|
<artifactId>maven-pmd-plugin</artifactId> |
||||||
|
<version>3.12.0</version> |
||||||
|
<configuration> |
||||||
|
<sourceEncoding>utf-8</sourceEncoding> |
||||||
|
<targetJdk>1.8</targetJdk> |
||||||
|
</configuration> |
||||||
|
</plugin> |
||||||
|
</plugins> |
||||||
|
</build> |
||||||
|
</project> |
@ -0,0 +1,81 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.type.TypeReference; |
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper; |
||||||
|
import ml.dmlc.xgboost4j.java.Booster; |
||||||
|
import ml.dmlc.xgboost4j.java.DMatrix; |
||||||
|
import ml.dmlc.xgboost4j.java.XGBoost; |
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||||
|
|
||||||
|
import java.io.BufferedInputStream; |
||||||
|
import java.io.IOException; |
||||||
|
import java.io.InputStream; |
||||||
|
import java.nio.file.Files; |
||||||
|
import java.nio.file.Path; |
||||||
|
import java.util.List; |
||||||
|
import java.util.Map; |
||||||
|
|
||||||
|
public class CacheInference { |
||||||
|
private static final TTL[] TTL_VALUES = TTL.values(); |
||||||
|
|
||||||
|
private final ObjectMapper mapper = new ObjectMapper(); |
||||||
|
private final Booster booster; |
||||||
|
private final Map<String, Map<String, Integer>> labels; |
||||||
|
|
||||||
|
public CacheInference(final Path modelFile, final Path labelFile) throws IOException, XGBoostError { |
||||||
|
try (BufferedInputStream modelIS = new BufferedInputStream(Files.newInputStream(modelFile)); |
||||||
|
BufferedInputStream labelIS = new BufferedInputStream(Files.newInputStream(labelFile)) |
||||||
|
) { |
||||||
|
booster = loadModel(modelIS); |
||||||
|
labels = loadLabels(labelIS); |
||||||
|
} |
||||||
|
} |
||||||
|
|
||||||
|
public CacheInference(InputStream modelIS, InputStream labelIS) throws IOException, XGBoostError { |
||||||
|
booster = loadModel(modelIS); |
||||||
|
labels = loadLabels(labelIS); |
||||||
|
} |
||||||
|
|
||||||
|
public TTL cacheTTL(final FlightData data) throws XGBoostError { |
||||||
|
float[][] predicts = booster.predict(createMatrix(data)); |
||||||
|
return TTL_VALUES[(int) predicts[0][0]]; |
||||||
|
} |
||||||
|
|
||||||
|
private Booster loadModel(InputStream model) throws XGBoostError, IOException { |
||||||
|
return XGBoost.loadModel(model); |
||||||
|
} |
||||||
|
|
||||||
|
private DMatrix createMatrix(final FlightData data) throws XGBoostError { |
||||||
|
final float[] arr = new float[18]; |
||||||
|
arr[0] = labels.get("client.channel").getOrDefault(data.getClientChannel(), 0); |
||||||
|
arr[1] = labels.get("type").getOrDefault(data.getType(), 0); |
||||||
|
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinList(data.getInboundDeparture()), 0); |
||||||
|
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinList(data.getInboundArrival()), 0); |
||||||
|
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0); |
||||||
|
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0); |
||||||
|
arr[6] = labels.get("flight.outboundSegments.departure").getOrDefault(joinList(data.getOutboundDeparture()), 0); |
||||||
|
arr[7] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinList(data.getOutboundArrival()), 0); |
||||||
|
arr[8] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0); |
||||||
|
arr[9] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0); |
||||||
|
arr[10] = data.getInputPrice().floatValue(); |
||||||
|
arr[11] = data.getInputTax().floatValue(); |
||||||
|
arr[12] = labels.get("input.currency").getOrDefault(data.getInputCurrency(), 0); |
||||||
|
arr[13] = data.getStatus().floatValue(); |
||||||
|
arr[14] = data.getOutputPrice().floatValue(); |
||||||
|
arr[15] = data.getOutputTax().floatValue(); |
||||||
|
arr[16] = labels.get("output.currency").getOrDefault(data.getOutputCurrency(), 0); |
||||||
|
arr[17] = data.getDuration().floatValue(); |
||||||
|
|
||||||
|
return new DMatrix(arr, 1, arr.length); |
||||||
|
} |
||||||
|
|
||||||
|
private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException { |
||||||
|
final TypeReference<Map<String, Map<String, Integer>>> typeRef = new TypeReference<Map<String, Map<String, Integer>>>() { |
||||||
|
}; |
||||||
|
return mapper.readValue(labels, typeRef); |
||||||
|
} |
||||||
|
|
||||||
|
private String joinList(final List<String> data) { |
||||||
|
return String.join("|", data); |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,123 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
import java.util.Collections; |
||||||
|
import java.util.List; |
||||||
|
|
||||||
|
public class FlightData { |
||||||
|
private final String clientChannel; |
||||||
|
private final String type; |
||||||
|
private final List<String> inboundDeparture; |
||||||
|
private final List<String> inboundArrival; |
||||||
|
private final List<String> inboundOrigin; |
||||||
|
private final List<String> inboundDestination; |
||||||
|
private final List<String> outboundDeparture; |
||||||
|
private final List<String> outboundArrival; |
||||||
|
private final List<String> outboundOrigin; |
||||||
|
private final List<String> outboundDestination; |
||||||
|
private final Double inputPrice; |
||||||
|
private final Double inputTax; |
||||||
|
private final String inputCurrency; |
||||||
|
private final Integer status; |
||||||
|
private final Double outputPrice; |
||||||
|
private final Double outputTax; |
||||||
|
private final String outputCurrency; |
||||||
|
private final Integer duration; |
||||||
|
|
||||||
|
public FlightData(final String clientChannel, final String type, |
||||||
|
final List<String> inboundDeparture, final List<String> inboundArrival, final List<String> inboundOrigin, |
||||||
|
final List<String> inboundDestination, final List<String> outboundDeparture, final List<String> outboundArrival, |
||||||
|
final List<String> outboundOrigin, final List<String> outboundDestination, final Double inputPrice, |
||||||
|
final Double inputTax, final String inputCurrency, final Integer status, final Double outputPrice, |
||||||
|
final Double outputTax, final String outputCurrency, final Integer duration) { |
||||||
|
this.clientChannel = clientChannel; |
||||||
|
this.type = type; |
||||||
|
this.inboundDeparture = inboundDeparture; |
||||||
|
this.inboundArrival = inboundArrival; |
||||||
|
this.inboundOrigin = inboundOrigin; |
||||||
|
this.inboundDestination = inboundDestination; |
||||||
|
this.outboundDeparture = outboundDeparture; |
||||||
|
this.outboundArrival = outboundArrival; |
||||||
|
this.outboundOrigin = outboundOrigin; |
||||||
|
this.outboundDestination = outboundDestination; |
||||||
|
this.inputPrice = inputPrice; |
||||||
|
this.inputTax = inputTax; |
||||||
|
this.inputCurrency = inputCurrency; |
||||||
|
this.status = status; |
||||||
|
this.outputPrice = outputPrice; |
||||||
|
this.outputTax = outputTax; |
||||||
|
this.outputCurrency = outputCurrency; |
||||||
|
this.duration = duration; |
||||||
|
} |
||||||
|
|
||||||
|
public String getClientChannel() { |
||||||
|
return clientChannel; |
||||||
|
} |
||||||
|
|
||||||
|
public String getType() { |
||||||
|
return type; |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundDeparture() { |
||||||
|
return Collections.unmodifiableList(inboundDeparture); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundArrival() { |
||||||
|
return Collections.unmodifiableList(inboundArrival); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundOrigin() { |
||||||
|
return Collections.unmodifiableList(inboundOrigin); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getInboundDestination() { |
||||||
|
return Collections.unmodifiableList(inboundDestination); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundDeparture() { |
||||||
|
return Collections.unmodifiableList(outboundDeparture); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundArrival() { |
||||||
|
return Collections.unmodifiableList(outboundArrival); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundOrigin() { |
||||||
|
return Collections.unmodifiableList(outboundOrigin); |
||||||
|
} |
||||||
|
|
||||||
|
public List<String> getOutboundDestination() { |
||||||
|
return Collections.unmodifiableList(outboundDestination); |
||||||
|
} |
||||||
|
|
||||||
|
public Double getInputPrice() { |
||||||
|
return inputPrice; |
||||||
|
} |
||||||
|
|
||||||
|
public Double getInputTax() { |
||||||
|
return inputTax; |
||||||
|
} |
||||||
|
|
||||||
|
public String getInputCurrency() { |
||||||
|
return inputCurrency; |
||||||
|
} |
||||||
|
|
||||||
|
public Integer getStatus() { |
||||||
|
return status; |
||||||
|
} |
||||||
|
|
||||||
|
public Double getOutputPrice() { |
||||||
|
return outputPrice; |
||||||
|
} |
||||||
|
|
||||||
|
public Double getOutputTax() { |
||||||
|
return outputTax; |
||||||
|
} |
||||||
|
|
||||||
|
public String getOutputCurrency() { |
||||||
|
return outputCurrency; |
||||||
|
} |
||||||
|
|
||||||
|
public Integer getDuration() { |
||||||
|
return duration; |
||||||
|
} |
||||||
|
} |
@ -0,0 +1,36 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
public enum TTL { |
||||||
|
/** |
||||||
|
* Very volatile or unknown data |
||||||
|
*/ |
||||||
|
NOCACHE, |
||||||
|
/** |
||||||
|
* Cache for 12 hours |
||||||
|
*/ |
||||||
|
H12, |
||||||
|
/** |
||||||
|
* Cache for 1 day |
||||||
|
*/ |
||||||
|
D1, |
||||||
|
/** |
||||||
|
* Cache for 2 days |
||||||
|
*/ |
||||||
|
D2, |
||||||
|
/** |
||||||
|
* Cache for 3 days |
||||||
|
*/ |
||||||
|
D3, |
||||||
|
/** |
||||||
|
* Cache for 1 week |
||||||
|
*/ |
||||||
|
D7, |
||||||
|
/** |
||||||
|
* Cache for 2 weeks |
||||||
|
*/ |
||||||
|
D14, |
||||||
|
/** |
||||||
|
* Cache for 30 days |
||||||
|
*/ |
||||||
|
D30; |
||||||
|
} |
@ -0,0 +1,140 @@ |
|||||||
|
package cz.aprar.bonitoo.inference; |
||||||
|
|
||||||
|
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||||
|
import org.testng.annotations.BeforeClass; |
||||||
|
import org.testng.annotations.DataProvider; |
||||||
|
import org.testng.annotations.Test; |
||||||
|
|
||||||
|
import java.io.IOException; |
||||||
|
import java.util.Arrays; |
||||||
|
import java.util.List; |
||||||
|
|
||||||
|
import static java.util.Collections.emptyList; |
||||||
|
import static org.testng.Assert.assertEquals; |
||||||
|
|
||||||
|
public class CacheInferenceTest { |
||||||
|
private CacheInference ci; |
||||||
|
|
||||||
|
@BeforeClass |
||||||
|
void setUp() throws IOException, XGBoostError { |
||||||
|
final ClassLoader classLoader = getClass().getClassLoader(); |
||||||
|
ci = new CacheInference(classLoader.getResourceAsStream("model/xgboost-model.bin"), |
||||||
|
classLoader.getResourceAsStream("model/encoder.json")); |
||||||
|
} |
||||||
|
|
||||||
|
@Test(dataProvider = "inferenceData") |
||||||
|
void testInference(final FlightData input, final TTL expected) throws XGBoostError { |
||||||
|
final TTL result = ci.cacheTTL(input); |
||||||
|
assertEquals(result, expected); |
||||||
|
} |
||||||
|
|
||||||
|
@DataProvider(name = "inferenceData") |
||||||
|
public Object[][] inferenceData() { |
||||||
|
return new Object[][]{ |
||||||
|
{new FlightData( |
||||||
|
"fly-me-to", |
||||||
|
"WS", |
||||||
|
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), |
||||||
|
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), |
||||||
|
toList("MCO", "FRA"), |
||||||
|
toList("FRA", "PRG"), |
||||||
|
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), |
||||||
|
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), |
||||||
|
toList("PRG", "FRA"), |
||||||
|
toList("FRA", "MCO"), |
||||||
|
39766.0, |
||||||
|
19776.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
39766.0, |
||||||
|
19776.0, |
||||||
|
"CZK", |
||||||
|
427 |
||||||
|
), TTL.D1}, |
||||||
|
{new FlightData( |
||||||
|
"fly-me-to", |
||||||
|
"PYTON", |
||||||
|
emptyList(), |
||||||
|
emptyList(), |
||||||
|
emptyList(), |
||||||
|
emptyList(), |
||||||
|
toList("2019-12-18T05:45:00"), |
||||||
|
toList("2019-12-18T08:05:00"), |
||||||
|
toList("KRK"), |
||||||
|
toList("BVA"), |
||||||
|
336.258, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
336.258, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
2284 |
||||||
|
), TTL.D7}, |
||||||
|
{new FlightData( |
||||||
|
"levne", |
||||||
|
"AVIA", |
||||||
|
toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), |
||||||
|
toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), |
||||||
|
toList("LAX", "LHR"), |
||||||
|
toList("LHR", "PRG"), |
||||||
|
toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"), |
||||||
|
toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"), |
||||||
|
toList("PRG", "HEL"), |
||||||
|
toList("HEL", "LAX"), |
||||||
|
5971.77978, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
15971.77978, |
||||||
|
0.0, |
||||||
|
"CZK", |
||||||
|
551 |
||||||
|
), TTL.D7}, |
||||||
|
{new FlightData( |
||||||
|
"fly-me-to", |
||||||
|
"HH", |
||||||
|
toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), |
||||||
|
toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), |
||||||
|
toList("YVR", "YUL"), |
||||||
|
toList("YUL", "VIE"), |
||||||
|
toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"), |
||||||
|
toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"), |
||||||
|
toList("VIE", "FRA"), |
||||||
|
toList("FRA", "YVR"), |
||||||
|
17723.0, |
||||||
|
7708.0, |
||||||
|
"CZK", |
||||||
|
0, |
||||||
|
17723.0, |
||||||
|
7708.0, |
||||||
|
"CZK", |
||||||
|
1786 |
||||||
|
), TTL.D1}, |
||||||
|
{new FlightData( |
||||||
|
"unknown", |
||||||
|
"unknown", |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
toList("unknown"), |
||||||
|
0.0, |
||||||
|
0.0, |
||||||
|
"unknown", |
||||||
|
0, |
||||||
|
0.0, |
||||||
|
0.0, |
||||||
|
"unknown", |
||||||
|
0 |
||||||
|
), TTL.NOCACHE} |
||||||
|
}; |
||||||
|
} |
||||||
|
|
||||||
|
private List<String> toList(final String... data) { |
||||||
|
return Arrays.asList(data); |
||||||
|
} |
||||||
|
} |
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -0,0 +1,33 @@ |
|||||||
|
import sagemaker |
||||||
|
import boto3 |
||||||
|
from sagemaker import get_execution_role |
||||||
|
from sagemaker.xgboost.estimator import XGBoost |
||||||
|
|
||||||
|
boto_session = boto3.Session(profile_name='bonitoo', region_name='eu-central-1') |
||||||
|
|
||||||
|
# local: sagemaker_session = sagemaker.LocalSession(boto_session=boto_session) |
||||||
|
sagemaker_session = sagemaker.Session(boto_session=boto_session) |
||||||
|
|
||||||
|
role = 'Bonitoo_SageMaker_Execution' |
||||||
|
train_input = 's3://customers-bonitoo-cachettl/sagemaker/data/export-reduced.csv' |
||||||
|
|
||||||
|
tf = XGBoost( |
||||||
|
entry_point='train_model.py', |
||||||
|
source_dir='./src', |
||||||
|
train_instance_type='ml.c5.xlarge', |
||||||
|
train_instance_count=1, |
||||||
|
role=role, |
||||||
|
sagemaker_session=sagemaker_session, |
||||||
|
framework_version='0.90-1', |
||||||
|
py_version='py3', |
||||||
|
hyperparameters={ |
||||||
|
'bonitoo_price_limit': 1000, |
||||||
|
'num_round': 15, |
||||||
|
'max_depth': 15, |
||||||
|
'eta': 0.5, |
||||||
|
'num_class': 8, |
||||||
|
'objective': 'multi:softmax', |
||||||
|
'eval_metric': 'mlogloss' |
||||||
|
}) |
||||||
|
|
||||||
|
tf.fit({'training': train_input}) |
Loading…
Reference in new issue