parent
76645f273d
commit
d675ca673a
10 changed files with 544 additions and 154 deletions
@ -0,0 +1,58 @@ |
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <!-- Coordinates of the inference library (packaged as a plain jar). -->
    <groupId>cz.aprar.bonitoo</groupId>
    <artifactId>inference</artifactId>
    <packaging>jar</packaging>
    <version>1.0-SNAPSHOT</version>
    <name>inference</name>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>

    <dependencies>
        <!-- XGBoost JVM bindings: CacheInference loads the model and predicts through them. -->
        <dependency>
            <groupId>ml.dmlc</groupId>
            <artifactId>xgboost4j</artifactId>
            <version>0.90</version>
        </dependency>

        <!-- JSON parsing of the label (encoder) file read by CacheInference. -->
        <dependency>
            <groupId>com.fasterxml.jackson.core</groupId>
            <artifactId>jackson-databind</artifactId>
            <version>2.9.9</version>
        </dependency>

        <!-- Test framework; only on the test classpath. -->
        <dependency>
            <groupId>org.testng</groupId>
            <artifactId>testng</artifactId>
            <version>7.0.0</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <!-- Compile sources as Java 8 and emit Java 8 bytecode. -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>8</source>
                    <target>8</target>
                </configuration>
            </plugin>

            <!--
                PMD static analysis, configured for UTF-8 sources and a 1.8 target JDK.
                NOTE(review): no <executions> are declared, so this presumably only runs
                when a pmd goal is invoked explicitly - confirm.
            -->
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-pmd-plugin</artifactId>
                <version>3.12.0</version>
                <configuration>
                    <sourceEncoding>utf-8</sourceEncoding>
                    <targetJdk>1.8</targetJdk>
                </configuration>
            </plugin>
        </plugins>
    </build>
</project>
@ -0,0 +1,81 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference; |
||||
import com.fasterxml.jackson.databind.ObjectMapper; |
||||
import ml.dmlc.xgboost4j.java.Booster; |
||||
import ml.dmlc.xgboost4j.java.DMatrix; |
||||
import ml.dmlc.xgboost4j.java.XGBoost; |
||||
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||
|
||||
import java.io.BufferedInputStream; |
||||
import java.io.IOException; |
||||
import java.io.InputStream; |
||||
import java.nio.file.Files; |
||||
import java.nio.file.Path; |
||||
import java.util.List; |
||||
import java.util.Map; |
||||
|
||||
public class CacheInference { |
||||
private static final TTL[] TTL_VALUES = TTL.values(); |
||||
|
||||
private final ObjectMapper mapper = new ObjectMapper(); |
||||
private final Booster booster; |
||||
private final Map<String, Map<String, Integer>> labels; |
||||
|
||||
public CacheInference(final Path modelFile, final Path labelFile) throws IOException, XGBoostError { |
||||
try (BufferedInputStream modelIS = new BufferedInputStream(Files.newInputStream(modelFile)); |
||||
BufferedInputStream labelIS = new BufferedInputStream(Files.newInputStream(labelFile)) |
||||
) { |
||||
booster = loadModel(modelIS); |
||||
labels = loadLabels(labelIS); |
||||
} |
||||
} |
||||
|
||||
public CacheInference(InputStream modelIS, InputStream labelIS) throws IOException, XGBoostError { |
||||
booster = loadModel(modelIS); |
||||
labels = loadLabels(labelIS); |
||||
} |
||||
|
||||
public TTL cacheTTL(final FlightData data) throws XGBoostError { |
||||
float[][] predicts = booster.predict(createMatrix(data)); |
||||
return TTL_VALUES[(int) predicts[0][0]]; |
||||
} |
||||
|
||||
private Booster loadModel(InputStream model) throws XGBoostError, IOException { |
||||
return XGBoost.loadModel(model); |
||||
} |
||||
|
||||
private DMatrix createMatrix(final FlightData data) throws XGBoostError { |
||||
final float[] arr = new float[18]; |
||||
arr[0] = labels.get("client.channel").getOrDefault(data.getClientChannel(), 0); |
||||
arr[1] = labels.get("type").getOrDefault(data.getType(), 0); |
||||
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinList(data.getInboundDeparture()), 0); |
||||
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinList(data.getInboundArrival()), 0); |
||||
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0); |
||||
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0); |
||||
arr[6] = labels.get("flight.outboundSegments.departure").getOrDefault(joinList(data.getOutboundDeparture()), 0); |
||||
arr[7] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinList(data.getOutboundArrival()), 0); |
||||
arr[8] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0); |
||||
arr[9] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0); |
||||
arr[10] = data.getInputPrice().floatValue(); |
||||
arr[11] = data.getInputTax().floatValue(); |
||||
arr[12] = labels.get("input.currency").getOrDefault(data.getInputCurrency(), 0); |
||||
arr[13] = data.getStatus().floatValue(); |
||||
arr[14] = data.getOutputPrice().floatValue(); |
||||
arr[15] = data.getOutputTax().floatValue(); |
||||
arr[16] = labels.get("output.currency").getOrDefault(data.getOutputCurrency(), 0); |
||||
arr[17] = data.getDuration().floatValue(); |
||||
|
||||
return new DMatrix(arr, 1, arr.length); |
||||
} |
||||
|
||||
private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException { |
||||
final TypeReference<Map<String, Map<String, Integer>>> typeRef = new TypeReference<Map<String, Map<String, Integer>>>() { |
||||
}; |
||||
return mapper.readValue(labels, typeRef); |
||||
} |
||||
|
||||
private String joinList(final List<String> data) { |
||||
return String.join("|", data); |
||||
} |
||||
} |
@ -0,0 +1,123 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
||||
|
||||
/**
 * Immutable value object carrying the attributes of one flight record.
 *
 * <p>All list arguments are defensively copied on construction, so later changes to the caller's
 * lists are not visible through this object. A {@code null} list is stored as an empty list.
 * Lists returned by the getters are unmodifiable.
 */
public class FlightData {
    private final String clientChannel;
    private final String type;
    private final List<String> inboundDeparture;
    private final List<String> inboundArrival;
    private final List<String> inboundOrigin;
    private final List<String> inboundDestination;
    private final List<String> outboundDeparture;
    private final List<String> outboundArrival;
    private final List<String> outboundOrigin;
    private final List<String> outboundDestination;
    private final Double inputPrice;
    private final Double inputTax;
    private final String inputCurrency;
    private final Integer status;
    private final Double outputPrice;
    private final Double outputTax;
    private final String outputCurrency;
    private final Integer duration;

    /**
     * Creates a flight record.
     *
     * @param clientChannel       client channel identifier
     * @param type                record type
     * @param inboundDeparture    departure values of the inbound segments; {@code null} means none
     * @param inboundArrival      arrival values of the inbound segments; {@code null} means none
     * @param inboundOrigin       origin values of the inbound segments; {@code null} means none
     * @param inboundDestination  destination values of the inbound segments; {@code null} means none
     * @param outboundDeparture   departure values of the outbound segments; {@code null} means none
     * @param outboundArrival     arrival values of the outbound segments; {@code null} means none
     * @param outboundOrigin      origin values of the outbound segments; {@code null} means none
     * @param outboundDestination destination values of the outbound segments; {@code null} means none
     * @param inputPrice          input price
     * @param inputTax            input tax
     * @param inputCurrency       currency of the input amounts
     * @param status              status value
     * @param outputPrice         output price
     * @param outputTax           output tax
     * @param outputCurrency      currency of the output amounts
     * @param duration            duration value
     */
    public FlightData(final String clientChannel, final String type,
                      final List<String> inboundDeparture, final List<String> inboundArrival, final List<String> inboundOrigin,
                      final List<String> inboundDestination, final List<String> outboundDeparture, final List<String> outboundArrival,
                      final List<String> outboundOrigin, final List<String> outboundDestination, final Double inputPrice,
                      final Double inputTax, final String inputCurrency, final Integer status, final Double outputPrice,
                      final Double outputTax, final String outputCurrency, final Integer duration) {
        this.clientChannel = clientChannel;
        this.type = type;
        this.inboundDeparture = immutableCopy(inboundDeparture);
        this.inboundArrival = immutableCopy(inboundArrival);
        this.inboundOrigin = immutableCopy(inboundOrigin);
        this.inboundDestination = immutableCopy(inboundDestination);
        this.outboundDeparture = immutableCopy(outboundDeparture);
        this.outboundArrival = immutableCopy(outboundArrival);
        this.outboundOrigin = immutableCopy(outboundOrigin);
        this.outboundDestination = immutableCopy(outboundDestination);
        this.inputPrice = inputPrice;
        this.inputTax = inputTax;
        this.inputCurrency = inputCurrency;
        this.status = status;
        this.outputPrice = outputPrice;
        this.outputTax = outputTax;
        this.outputCurrency = outputCurrency;
        this.duration = duration;
    }

    /**
     * Returns an unmodifiable snapshot of {@code source}, or an empty list when it is
     * {@code null}. Copying (rather than wrapping) detaches this object from the caller's list.
     */
    private static List<String> immutableCopy(final List<String> source) {
        if (source == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(new ArrayList<>(source));
    }

    /** @return the client channel identifier */
    public String getClientChannel() {
        return clientChannel;
    }

    /** @return the record type */
    public String getType() {
        return type;
    }

    /** @return unmodifiable inbound departure values, never {@code null} */
    public List<String> getInboundDeparture() {
        return inboundDeparture;
    }

    /** @return unmodifiable inbound arrival values, never {@code null} */
    public List<String> getInboundArrival() {
        return inboundArrival;
    }

    /** @return unmodifiable inbound origin values, never {@code null} */
    public List<String> getInboundOrigin() {
        return inboundOrigin;
    }

    /** @return unmodifiable inbound destination values, never {@code null} */
    public List<String> getInboundDestination() {
        return inboundDestination;
    }

    /** @return unmodifiable outbound departure values, never {@code null} */
    public List<String> getOutboundDeparture() {
        return outboundDeparture;
    }

    /** @return unmodifiable outbound arrival values, never {@code null} */
    public List<String> getOutboundArrival() {
        return outboundArrival;
    }

    /** @return unmodifiable outbound origin values, never {@code null} */
    public List<String> getOutboundOrigin() {
        return outboundOrigin;
    }

    /** @return unmodifiable outbound destination values, never {@code null} */
    public List<String> getOutboundDestination() {
        return outboundDestination;
    }

    /** @return the input price */
    public Double getInputPrice() {
        return inputPrice;
    }

    /** @return the input tax */
    public Double getInputTax() {
        return inputTax;
    }

    /** @return the currency of the input amounts */
    public String getInputCurrency() {
        return inputCurrency;
    }

    /** @return the status value */
    public Integer getStatus() {
        return status;
    }

    /** @return the output price */
    public Double getOutputPrice() {
        return outputPrice;
    }

    /** @return the output tax */
    public Double getOutputTax() {
        return outputTax;
    }

    /** @return the currency of the output amounts */
    public String getOutputCurrency() {
        return outputCurrency;
    }

    /** @return the duration value */
    public Integer getDuration() {
        return duration;
    }
}
@ -0,0 +1,36 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
/**
 * Cache time-to-live buckets.
 *
 * <p>The declaration order is significant: {@code CacheInference} indexes {@code TTL.values()}
 * with the class number predicted by the model, so constants must not be reordered.
 */
public enum TTL {
    /** Do not cache: the data is very volatile or unknown. */
    NOCACHE,

    /** Keep cached for 12 hours. */
    H12,

    /** Keep cached for 1 day. */
    D1,

    /** Keep cached for 2 days. */
    D2,

    /** Keep cached for 3 days. */
    D3,

    /** Keep cached for 1 week. */
    D7,

    /** Keep cached for 2 weeks. */
    D14,

    /** Keep cached for 30 days. */
    D30
}
@ -0,0 +1,140 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||
import org.testng.annotations.BeforeClass; |
||||
import org.testng.annotations.DataProvider; |
||||
import org.testng.annotations.Test; |
||||
|
||||
import java.io.IOException; |
||||
import java.util.Arrays; |
||||
import java.util.List; |
||||
|
||||
import static java.util.Collections.emptyList; |
||||
import static org.testng.Assert.assertEquals; |
||||
|
||||
public class CacheInferenceTest { |
||||
private CacheInference ci; |
||||
|
||||
@BeforeClass |
||||
void setUp() throws IOException, XGBoostError { |
||||
final ClassLoader classLoader = getClass().getClassLoader(); |
||||
ci = new CacheInference(classLoader.getResourceAsStream("model/xgboost-model.bin"), |
||||
classLoader.getResourceAsStream("model/encoder.json")); |
||||
} |
||||
|
||||
@Test(dataProvider = "inferenceData") |
||||
void testInference(final FlightData input, final TTL expected) throws XGBoostError { |
||||
final TTL result = ci.cacheTTL(input); |
||||
assertEquals(result, expected); |
||||
} |
||||
|
||||
@DataProvider(name = "inferenceData") |
||||
public Object[][] inferenceData() { |
||||
return new Object[][]{ |
||||
{new FlightData( |
||||
"fly-me-to", |
||||
"WS", |
||||
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), |
||||
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), |
||||
toList("MCO", "FRA"), |
||||
toList("FRA", "PRG"), |
||||
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), |
||||
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), |
||||
toList("PRG", "FRA"), |
||||
toList("FRA", "MCO"), |
||||
39766.0, |
||||
19776.0, |
||||
"CZK", |
||||
0, |
||||
39766.0, |
||||
19776.0, |
||||
"CZK", |
||||
427 |
||||
), TTL.D1}, |
||||
{new FlightData( |
||||
"fly-me-to", |
||||
"PYTON", |
||||
emptyList(), |
||||
emptyList(), |
||||
emptyList(), |
||||
emptyList(), |
||||
toList("2019-12-18T05:45:00"), |
||||
toList("2019-12-18T08:05:00"), |
||||
toList("KRK"), |
||||
toList("BVA"), |
||||
336.258, |
||||
0.0, |
||||
"CZK", |
||||
0, |
||||
336.258, |
||||
0.0, |
||||
"CZK", |
||||
2284 |
||||
), TTL.D7}, |
||||
{new FlightData( |
||||
"levne", |
||||
"AVIA", |
||||
toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), |
||||
toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), |
||||
toList("LAX", "LHR"), |
||||
toList("LHR", "PRG"), |
||||
toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"), |
||||
toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"), |
||||
toList("PRG", "HEL"), |
||||
toList("HEL", "LAX"), |
||||
5971.77978, |
||||
0.0, |
||||
"CZK", |
||||
0, |
||||
15971.77978, |
||||
0.0, |
||||
"CZK", |
||||
551 |
||||
), TTL.D7}, |
||||
{new FlightData( |
||||
"fly-me-to", |
||||
"HH", |
||||
toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), |
||||
toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), |
||||
toList("YVR", "YUL"), |
||||
toList("YUL", "VIE"), |
||||
toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"), |
||||
toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"), |
||||
toList("VIE", "FRA"), |
||||
toList("FRA", "YVR"), |
||||
17723.0, |
||||
7708.0, |
||||
"CZK", |
||||
0, |
||||
17723.0, |
||||
7708.0, |
||||
"CZK", |
||||
1786 |
||||
), TTL.D1}, |
||||
{new FlightData( |
||||
"unknown", |
||||
"unknown", |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
0.0, |
||||
0.0, |
||||
"unknown", |
||||
0, |
||||
0.0, |
||||
0.0, |
||||
"unknown", |
||||
0 |
||||
), TTL.NOCACHE} |
||||
}; |
||||
} |
||||
|
||||
private List<String> toList(final String... data) { |
||||
return Arrays.asList(data); |
||||
} |
||||
} |
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -0,0 +1,33 @@ |
||||
import sagemaker |
||||
import boto3 |
||||
from sagemaker import get_execution_role |
||||
from sagemaker.xgboost.estimator import XGBoost |
||||
|
||||
# AWS session bound to the 'bonitoo' credentials profile in eu-central-1.
boto_session = boto3.Session(profile_name='bonitoo', region_name='eu-central-1')

# local: sagemaker_session = sagemaker.LocalSession(boto_session=boto_session)
sagemaker_session = sagemaker.Session(boto_session=boto_session)

# Execution role passed to the estimator and the S3 location of the training CSV.
role = 'Bonitoo_SageMaker_Execution'
train_input = 's3://customers-bonitoo-cachettl/sagemaker/data/export-reduced.csv'

# Hyperparameters handed to the training entry point.
# NOTE(review): 'bonitoo_price_limit' is not a standard XGBoost parameter; it is presumably
# consumed by train_model.py itself - confirm.
# 'num_class' is 8, the same as the number of constants in the Java TTL enum.
hyperparameters = {
    'bonitoo_price_limit': 1000,
    'num_round': 15,
    'max_depth': 15,
    'eta': 0.5,
    'num_class': 8,
    'objective': 'multi:softmax',
    'eval_metric': 'mlogloss'
}

# XGBoost estimator running ./src/train_model.py on a single ml.c5.xlarge instance.
# (The variable name 'tf' is kept as-is so anything referring to it keeps working.)
tf = XGBoost(
    entry_point='train_model.py',
    source_dir='./src',
    train_instance_type='ml.c5.xlarge',
    train_instance_count=1,
    role=role,
    sagemaker_session=sagemaker_session,
    framework_version='0.90-1',
    py_version='py3',
    hyperparameters=hyperparameters)

# Start the training job with the CSV bound to the 'training' channel.
tf.fit({'training': train_input})
Loading…
Reference in new issue