parent
76645f273d
commit
d675ca673a
10 changed files with 544 additions and 154 deletions
@ -1,115 +1,76 @@ |
||||
from pymongo import MongoClient |
||||
from pprint import pprint |
||||
from datetime import datetime |
||||
import csv |
||||
|
||||
client = MongoClient() |
||||
db=client.bonitoo |
||||
db = client.bonitoo |
||||
|
||||
fieldnames = [ |
||||
'timestamp', |
||||
'client.channel', |
||||
'type', |
||||
'flight.inboundSegments.departure', |
||||
'flight.inboundSegments.arrival', |
||||
'flight.inboundSegments.origin.airportCode', |
||||
'flight.inboundSegments.destination.airportCode', |
||||
'flight.inboundSegments.flightNumber', |
||||
'flight.inboundSegments.travelClass', |
||||
'flight.inboundSegments.bookingCode', |
||||
'flight.inboundSegments.availability', |
||||
'flight.inboundSegments.elapsedFlyingTime', |
||||
'flight.outboundSegments.departure', |
||||
'flight.outboundSegments.arrival', |
||||
'flight.outboundSegments.origin.airportCode', |
||||
'flight.outboundSegments.destination.airportCode', |
||||
'flight.outboundSegments.flightNumber', |
||||
'flight.outboundSegments.travelClass', |
||||
'flight.outboundSegments.bookingCode', |
||||
'flight.outboundSegments.availability', |
||||
'flight.outboundSegments.elapsedFlyingTime', |
||||
'flight.inboundEFT', # elapsed flying time |
||||
'flight.outboundEFT', |
||||
'oneWay', |
||||
'adults', # pocet osob = (adults + children) |
||||
'children', |
||||
'infants', |
||||
'input.price', |
||||
'input.tax', |
||||
'input.currency', |
||||
'success', |
||||
'status', |
||||
'output.price', |
||||
'output.tax', |
||||
'output.currency', |
||||
'duration' # delka volani do nadrazeneho systemu |
||||
'timestamp', |
||||
'client.channel', |
||||
'type', |
||||
'flight.inboundSegments.departure', |
||||
'flight.inboundSegments.arrival', |
||||
'flight.inboundSegments.origin.airportCode', |
||||
'flight.inboundSegments.destination.airportCode', |
||||
'flight.outboundSegments.departure', |
||||
'flight.outboundSegments.arrival', |
||||
'flight.outboundSegments.origin.airportCode', |
||||
'flight.outboundSegments.destination.airportCode', |
||||
'input.price', |
||||
'input.tax', |
||||
'input.currency', |
||||
'success', |
||||
'status', |
||||
'output.price', |
||||
'output.tax', |
||||
'output.currency', |
||||
'duration' |
||||
] |
||||
|
||||
# 5% nebo 200 kc rozdil nahoru |
||||
# -200 kc dolu |
||||
# abs(+-10kc) ignorovat |
||||
|
||||
# timestamp + ok price - ma byt v cache od cacheat |
||||
# timestamp + notok price - nema byt v cache od cacheat |
||||
|
||||
# delka pobytu prilet-odlet |
||||
# delka letu ? |
||||
# pokud je chyba tak nocache (= chybi priceout) |
||||
|
||||
# brat v uvahu in/out kody aerolinek (mcx ?) - mirek jeste zjisti |
||||
|
||||
# vypocitat uspesnost je/neni v cache v % |
||||
|
||||
counter = 0 |
||||
with open('export.csv', mode='w') as ef: |
||||
writer = csv.DictWriter(ef, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) |
||||
# do not write header for s3 files |
||||
# writer.writeheader() |
||||
writer = csv.DictWriter(ef, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) |
||||
# do not write header for s3 files |
||||
# writer.writeheader() |
||||
|
||||
for it in db.pricing_audit.find(): |
||||
counter += 1 |
||||
if counter % 1000 == 0: |
||||
print('Iterace %d' % counter) |
||||
d = { |
||||
'timestamp': datetime.fromtimestamp(it['timestamp']/1000).isoformat(), |
||||
'client.channel': it['client']['channel'], |
||||
'type': it['type'], |
||||
'flight.outboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.bookingCode': '|'.join([x.get('bookingCode','') for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['outboundSegments']]), |
||||
'flight.inboundEFT': it['flight'].get('inboundEFT',''), |
||||
'flight.outboundEFT': it['flight'].get('outboundEFT',''), |
||||
'oneWay': it['oneWay'], |
||||
'adults': it['adults'], |
||||
'children': it['children'], |
||||
'infants': it['infants'], |
||||
'input.price': it['input']['price'], |
||||
'input.tax': it['input']['tax'], |
||||
'input.currency': it['input']['currency'], |
||||
'success': it['success'], |
||||
'status': it.get('status',''), |
||||
'output.price': it.get('output', {'price': 0})['price'], |
||||
'output.tax': it.get('output', {'tax': 0})['tax'], |
||||
'output.currency': it.get('output', {'currency': 0})['currency'], |
||||
'duration': it['duration'] |
||||
} |
||||
for it in db.pricing_audit.find(): |
||||
counter += 1 |
||||
if counter % 1000 == 0: |
||||
print('Iterace %d' % counter) |
||||
d = { |
||||
'timestamp': datetime.fromtimestamp(it['timestamp'] / 1000).isoformat(), |
||||
'client.channel': it['client']['channel'], |
||||
'type': it['type'], |
||||
'flight.outboundSegments.departure': '|'.join( |
||||
[x['departure'].isoformat() for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.arrival': '|'.join( |
||||
[x['arrival'].isoformat() for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.origin.airportCode': '|'.join( |
||||
[x['origin']['airportCode'] for x in it['flight']['outboundSegments']]), |
||||
'flight.outboundSegments.destination.airportCode': '|'.join( |
||||
[x['destination']['airportCode'] for x in it['flight']['outboundSegments']]), |
||||
'input.price': it['input']['price'], |
||||
'input.tax': it['input']['tax'], |
||||
'input.currency': it['input']['currency'], |
||||
'success': it['success'], |
||||
'status': it.get('status', ''), |
||||
'output.price': it.get('output', {'price': 0})['price'], |
||||
'output.tax': it.get('output', {'tax': 0})['tax'], |
||||
'output.currency': it.get('output', {'currency': 0})['currency'], |
||||
'duration': it['duration'] |
||||
} |
||||
|
||||
if 'inboundSegments' in it['flight']: |
||||
inb = { |
||||
'flight.inboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.bookingCode': '|'.join([x.get('bookingCode', '') for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['inboundSegments']]) |
||||
} |
||||
d = {**d, **inb} |
||||
writer.writerow(d) |
||||
if 'inboundSegments' in it['flight']: |
||||
inb = { |
||||
'flight.inboundSegments.departure': '|'.join( |
||||
[x['departure'].isoformat() for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.arrival': '|'.join( |
||||
[x['arrival'].isoformat() for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.origin.airportCode': '|'.join( |
||||
[x['origin']['airportCode'] for x in it['flight']['inboundSegments']]), |
||||
'flight.inboundSegments.destination.airportCode': '|'.join( |
||||
[x['destination']['airportCode'] for x in it['flight']['inboundSegments']]), |
||||
} |
||||
d = {**d, **inb} |
||||
writer.writerow(d) |
||||
|
@ -0,0 +1,58 @@ |
||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |
||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> |
||||
<modelVersion>4.0.0</modelVersion> |
||||
<groupId>cz.aprar.bonitoo</groupId> |
||||
<artifactId>inference</artifactId> |
||||
<packaging>jar</packaging> |
||||
<version>1.0-SNAPSHOT</version> |
||||
<name>inference</name> |
||||
|
||||
<properties> |
||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> |
||||
</properties> |
||||
|
||||
<dependencies> |
||||
<dependency> |
||||
<groupId>ml.dmlc</groupId> |
||||
<artifactId>xgboost4j</artifactId> |
||||
<version>0.90</version> |
||||
</dependency> |
||||
|
||||
<dependency> |
||||
<groupId>com.fasterxml.jackson.core</groupId> |
||||
<artifactId>jackson-databind</artifactId> |
||||
<version>2.9.9</version> |
||||
</dependency> |
||||
|
||||
<dependency> |
||||
<groupId>org.testng</groupId> |
||||
<artifactId>testng</artifactId> |
||||
<version>7.0.0</version> |
||||
<scope>test</scope> |
||||
</dependency> |
||||
</dependencies> |
||||
|
||||
<build> |
||||
<plugins> |
||||
<plugin> |
||||
<groupId>org.apache.maven.plugins</groupId> |
||||
<artifactId>maven-compiler-plugin</artifactId> |
||||
<version>3.8.1</version> |
||||
<configuration> |
||||
<source>8</source> |
||||
<target>8</target> |
||||
</configuration> |
||||
</plugin> |
||||
|
||||
<plugin> |
||||
<groupId>org.apache.maven.plugins</groupId> |
||||
<artifactId>maven-pmd-plugin</artifactId> |
||||
<version>3.12.0</version> |
||||
<configuration> |
||||
<sourceEncoding>utf-8</sourceEncoding> |
||||
<targetJdk>1.8</targetJdk> |
||||
</configuration> |
||||
</plugin> |
||||
</plugins> |
||||
</build> |
||||
</project> |
@ -0,0 +1,81 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
import com.fasterxml.jackson.core.type.TypeReference; |
||||
import com.fasterxml.jackson.databind.ObjectMapper; |
||||
import ml.dmlc.xgboost4j.java.Booster; |
||||
import ml.dmlc.xgboost4j.java.DMatrix; |
||||
import ml.dmlc.xgboost4j.java.XGBoost; |
||||
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||
|
||||
import java.io.BufferedInputStream; |
||||
import java.io.IOException; |
||||
import java.io.InputStream; |
||||
import java.nio.file.Files; |
||||
import java.nio.file.Path; |
||||
import java.util.List; |
||||
import java.util.Map; |
||||
|
||||
public class CacheInference { |
||||
private static final TTL[] TTL_VALUES = TTL.values(); |
||||
|
||||
private final ObjectMapper mapper = new ObjectMapper(); |
||||
private final Booster booster; |
||||
private final Map<String, Map<String, Integer>> labels; |
||||
|
||||
public CacheInference(final Path modelFile, final Path labelFile) throws IOException, XGBoostError { |
||||
try (BufferedInputStream modelIS = new BufferedInputStream(Files.newInputStream(modelFile)); |
||||
BufferedInputStream labelIS = new BufferedInputStream(Files.newInputStream(labelFile)) |
||||
) { |
||||
booster = loadModel(modelIS); |
||||
labels = loadLabels(labelIS); |
||||
} |
||||
} |
||||
|
||||
public CacheInference(InputStream modelIS, InputStream labelIS) throws IOException, XGBoostError { |
||||
booster = loadModel(modelIS); |
||||
labels = loadLabels(labelIS); |
||||
} |
||||
|
||||
public TTL cacheTTL(final FlightData data) throws XGBoostError { |
||||
float[][] predicts = booster.predict(createMatrix(data)); |
||||
return TTL_VALUES[(int) predicts[0][0]]; |
||||
} |
||||
|
||||
private Booster loadModel(InputStream model) throws XGBoostError, IOException { |
||||
return XGBoost.loadModel(model); |
||||
} |
||||
|
||||
private DMatrix createMatrix(final FlightData data) throws XGBoostError { |
||||
final float[] arr = new float[18]; |
||||
arr[0] = labels.get("client.channel").getOrDefault(data.getClientChannel(), 0); |
||||
arr[1] = labels.get("type").getOrDefault(data.getType(), 0); |
||||
arr[2] = labels.get("flight.inboundSegments.departure").getOrDefault(joinList(data.getInboundDeparture()), 0); |
||||
arr[3] = labels.get("flight.inboundSegments.arrival").getOrDefault(joinList(data.getInboundArrival()), 0); |
||||
arr[4] = labels.get("flight.inboundSegments.origin.airportCode").getOrDefault(joinList(data.getInboundOrigin()), 0); |
||||
arr[5] = labels.get("flight.inboundSegments.destination.airportCode").getOrDefault(joinList(data.getInboundDestination()), 0); |
||||
arr[6] = labels.get("flight.outboundSegments.departure").getOrDefault(joinList(data.getOutboundDeparture()), 0); |
||||
arr[7] = labels.get("flight.outboundSegments.arrival").getOrDefault(joinList(data.getOutboundArrival()), 0); |
||||
arr[8] = labels.get("flight.outboundSegments.origin.airportCode").getOrDefault(joinList(data.getOutboundOrigin()), 0); |
||||
arr[9] = labels.get("flight.outboundSegments.destination.airportCode").getOrDefault(joinList(data.getOutboundDestination()), 0); |
||||
arr[10] = data.getInputPrice().floatValue(); |
||||
arr[11] = data.getInputTax().floatValue(); |
||||
arr[12] = labels.get("input.currency").getOrDefault(data.getInputCurrency(), 0); |
||||
arr[13] = data.getStatus().floatValue(); |
||||
arr[14] = data.getOutputPrice().floatValue(); |
||||
arr[15] = data.getOutputTax().floatValue(); |
||||
arr[16] = labels.get("output.currency").getOrDefault(data.getOutputCurrency(), 0); |
||||
arr[17] = data.getDuration().floatValue(); |
||||
|
||||
return new DMatrix(arr, 1, arr.length); |
||||
} |
||||
|
||||
private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException { |
||||
final TypeReference<Map<String, Map<String, Integer>>> typeRef = new TypeReference<Map<String, Map<String, Integer>>>() { |
||||
}; |
||||
return mapper.readValue(labels, typeRef); |
||||
} |
||||
|
||||
private String joinList(final List<String> data) { |
||||
return String.join("|", data); |
||||
} |
||||
} |
@ -0,0 +1,123 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
import java.util.Collections; |
||||
import java.util.List; |
||||
|
||||
public class FlightData { |
||||
private final String clientChannel; |
||||
private final String type; |
||||
private final List<String> inboundDeparture; |
||||
private final List<String> inboundArrival; |
||||
private final List<String> inboundOrigin; |
||||
private final List<String> inboundDestination; |
||||
private final List<String> outboundDeparture; |
||||
private final List<String> outboundArrival; |
||||
private final List<String> outboundOrigin; |
||||
private final List<String> outboundDestination; |
||||
private final Double inputPrice; |
||||
private final Double inputTax; |
||||
private final String inputCurrency; |
||||
private final Integer status; |
||||
private final Double outputPrice; |
||||
private final Double outputTax; |
||||
private final String outputCurrency; |
||||
private final Integer duration; |
||||
|
||||
public FlightData(final String clientChannel, final String type, |
||||
final List<String> inboundDeparture, final List<String> inboundArrival, final List<String> inboundOrigin, |
||||
final List<String> inboundDestination, final List<String> outboundDeparture, final List<String> outboundArrival, |
||||
final List<String> outboundOrigin, final List<String> outboundDestination, final Double inputPrice, |
||||
final Double inputTax, final String inputCurrency, final Integer status, final Double outputPrice, |
||||
final Double outputTax, final String outputCurrency, final Integer duration) { |
||||
this.clientChannel = clientChannel; |
||||
this.type = type; |
||||
this.inboundDeparture = inboundDeparture; |
||||
this.inboundArrival = inboundArrival; |
||||
this.inboundOrigin = inboundOrigin; |
||||
this.inboundDestination = inboundDestination; |
||||
this.outboundDeparture = outboundDeparture; |
||||
this.outboundArrival = outboundArrival; |
||||
this.outboundOrigin = outboundOrigin; |
||||
this.outboundDestination = outboundDestination; |
||||
this.inputPrice = inputPrice; |
||||
this.inputTax = inputTax; |
||||
this.inputCurrency = inputCurrency; |
||||
this.status = status; |
||||
this.outputPrice = outputPrice; |
||||
this.outputTax = outputTax; |
||||
this.outputCurrency = outputCurrency; |
||||
this.duration = duration; |
||||
} |
||||
|
||||
public String getClientChannel() { |
||||
return clientChannel; |
||||
} |
||||
|
||||
public String getType() { |
||||
return type; |
||||
} |
||||
|
||||
public List<String> getInboundDeparture() { |
||||
return Collections.unmodifiableList(inboundDeparture); |
||||
} |
||||
|
||||
public List<String> getInboundArrival() { |
||||
return Collections.unmodifiableList(inboundArrival); |
||||
} |
||||
|
||||
public List<String> getInboundOrigin() { |
||||
return Collections.unmodifiableList(inboundOrigin); |
||||
} |
||||
|
||||
public List<String> getInboundDestination() { |
||||
return Collections.unmodifiableList(inboundDestination); |
||||
} |
||||
|
||||
public List<String> getOutboundDeparture() { |
||||
return Collections.unmodifiableList(outboundDeparture); |
||||
} |
||||
|
||||
public List<String> getOutboundArrival() { |
||||
return Collections.unmodifiableList(outboundArrival); |
||||
} |
||||
|
||||
public List<String> getOutboundOrigin() { |
||||
return Collections.unmodifiableList(outboundOrigin); |
||||
} |
||||
|
||||
public List<String> getOutboundDestination() { |
||||
return Collections.unmodifiableList(outboundDestination); |
||||
} |
||||
|
||||
public Double getInputPrice() { |
||||
return inputPrice; |
||||
} |
||||
|
||||
public Double getInputTax() { |
||||
return inputTax; |
||||
} |
||||
|
||||
public String getInputCurrency() { |
||||
return inputCurrency; |
||||
} |
||||
|
||||
public Integer getStatus() { |
||||
return status; |
||||
} |
||||
|
||||
public Double getOutputPrice() { |
||||
return outputPrice; |
||||
} |
||||
|
||||
public Double getOutputTax() { |
||||
return outputTax; |
||||
} |
||||
|
||||
public String getOutputCurrency() { |
||||
return outputCurrency; |
||||
} |
||||
|
||||
public Integer getDuration() { |
||||
return duration; |
||||
} |
||||
} |
@ -0,0 +1,36 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
public enum TTL { |
||||
/** |
||||
* Very volatile or unknown data |
||||
*/ |
||||
NOCACHE, |
||||
/** |
||||
* Cache for 12 hours |
||||
*/ |
||||
H12, |
||||
/** |
||||
* Cache for 1 day |
||||
*/ |
||||
D1, |
||||
/** |
||||
* Cache for 2 days |
||||
*/ |
||||
D2, |
||||
/** |
||||
* Cache for 3 days |
||||
*/ |
||||
D3, |
||||
/** |
||||
* Cache for 1 week |
||||
*/ |
||||
D7, |
||||
/** |
||||
* Cache for 2 weeks |
||||
*/ |
||||
D14, |
||||
/** |
||||
* Cache for 30 days |
||||
*/ |
||||
D30; |
||||
} |
@ -0,0 +1,140 @@ |
||||
package cz.aprar.bonitoo.inference; |
||||
|
||||
import ml.dmlc.xgboost4j.java.XGBoostError; |
||||
import org.testng.annotations.BeforeClass; |
||||
import org.testng.annotations.DataProvider; |
||||
import org.testng.annotations.Test; |
||||
|
||||
import java.io.IOException; |
||||
import java.util.Arrays; |
||||
import java.util.List; |
||||
|
||||
import static java.util.Collections.emptyList; |
||||
import static org.testng.Assert.assertEquals; |
||||
|
||||
public class CacheInferenceTest { |
||||
private CacheInference ci; |
||||
|
||||
@BeforeClass |
||||
void setUp() throws IOException, XGBoostError { |
||||
final ClassLoader classLoader = getClass().getClassLoader(); |
||||
ci = new CacheInference(classLoader.getResourceAsStream("model/xgboost-model.bin"), |
||||
classLoader.getResourceAsStream("model/encoder.json")); |
||||
} |
||||
|
||||
@Test(dataProvider = "inferenceData") |
||||
void testInference(final FlightData input, final TTL expected) throws XGBoostError { |
||||
final TTL result = ci.cacheTTL(input); |
||||
assertEquals(result, expected); |
||||
} |
||||
|
||||
@DataProvider(name = "inferenceData") |
||||
public Object[][] inferenceData() { |
||||
return new Object[][]{ |
||||
{new FlightData( |
||||
"fly-me-to", |
||||
"WS", |
||||
toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"), |
||||
toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"), |
||||
toList("MCO", "FRA"), |
||||
toList("FRA", "PRG"), |
||||
toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"), |
||||
toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"), |
||||
toList("PRG", "FRA"), |
||||
toList("FRA", "MCO"), |
||||
39766.0, |
||||
19776.0, |
||||
"CZK", |
||||
0, |
||||
39766.0, |
||||
19776.0, |
||||
"CZK", |
||||
427 |
||||
), TTL.D1}, |
||||
{new FlightData( |
||||
"fly-me-to", |
||||
"PYTON", |
||||
emptyList(), |
||||
emptyList(), |
||||
emptyList(), |
||||
emptyList(), |
||||
toList("2019-12-18T05:45:00"), |
||||
toList("2019-12-18T08:05:00"), |
||||
toList("KRK"), |
||||
toList("BVA"), |
||||
336.258, |
||||
0.0, |
||||
"CZK", |
||||
0, |
||||
336.258, |
||||
0.0, |
||||
"CZK", |
||||
2284 |
||||
), TTL.D7}, |
||||
{new FlightData( |
||||
"levne", |
||||
"AVIA", |
||||
toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"), |
||||
toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"), |
||||
toList("LAX", "LHR"), |
||||
toList("LHR", "PRG"), |
||||
toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"), |
||||
toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"), |
||||
toList("PRG", "HEL"), |
||||
toList("HEL", "LAX"), |
||||
5971.77978, |
||||
0.0, |
||||
"CZK", |
||||
0, |
||||
15971.77978, |
||||
0.0, |
||||
"CZK", |
||||
551 |
||||
), TTL.D7}, |
||||
{new FlightData( |
||||
"fly-me-to", |
||||
"HH", |
||||
toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"), |
||||
toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"), |
||||
toList("YVR", "YUL"), |
||||
toList("YUL", "VIE"), |
||||
toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"), |
||||
toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"), |
||||
toList("VIE", "FRA"), |
||||
toList("FRA", "YVR"), |
||||
17723.0, |
||||
7708.0, |
||||
"CZK", |
||||
0, |
||||
17723.0, |
||||
7708.0, |
||||
"CZK", |
||||
1786 |
||||
), TTL.D1}, |
||||
{new FlightData( |
||||
"unknown", |
||||
"unknown", |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
toList("unknown"), |
||||
0.0, |
||||
0.0, |
||||
"unknown", |
||||
0, |
||||
0.0, |
||||
0.0, |
||||
"unknown", |
||||
0 |
||||
), TTL.NOCACHE} |
||||
}; |
||||
} |
||||
|
||||
private List<String> toList(final String... data) { |
||||
return Arrays.asList(data); |
||||
} |
||||
} |
File diff suppressed because one or more lines are too long
Binary file not shown.
@ -0,0 +1,33 @@ |
||||
import sagemaker |
||||
import boto3 |
||||
from sagemaker import get_execution_role |
||||
from sagemaker.xgboost.estimator import XGBoost |
||||
|
||||
boto_session = boto3.Session(profile_name='bonitoo', region_name='eu-central-1') |
||||
|
||||
# local: sagemaker_session = sagemaker.LocalSession(boto_session=boto_session) |
||||
sagemaker_session = sagemaker.Session(boto_session=boto_session) |
||||
|
||||
role = 'Bonitoo_SageMaker_Execution' |
||||
train_input = 's3://customers-bonitoo-cachettl/sagemaker/data/export-reduced.csv' |
||||
|
||||
tf = XGBoost( |
||||
entry_point='train_model.py', |
||||
source_dir='./src', |
||||
train_instance_type='ml.c5.xlarge', |
||||
train_instance_count=1, |
||||
role=role, |
||||
sagemaker_session=sagemaker_session, |
||||
framework_version='0.90-1', |
||||
py_version='py3', |
||||
hyperparameters={ |
||||
'bonitoo_price_limit': 1000, |
||||
'num_round': 15, |
||||
'max_depth': 15, |
||||
'eta': 0.5, |
||||
'num_class': 8, |
||||
'objective': 'multi:softmax', |
||||
'eval_metric': 'mlogloss' |
||||
}) |
||||
|
||||
tf.fit({'training': train_input}) |
Loading…
Reference in new issue