Reduced fields, first fully working version

master
EHP 6 years ago
parent 76645f273d
commit d675ca673a
  1. 75
      export.py
  2. 58
      inference/pom.xml
  3. 81
      inference/src/main/java/cz/aprar/bonitoo/inference/CacheInference.java
  4. 123
      inference/src/main/java/cz/aprar/bonitoo/inference/FlightData.java
  5. 36
      inference/src/main/java/cz/aprar/bonitoo/inference/TTL.java
  6. 140
      inference/src/test/java/cz/aprar/bonitoo/inference/CacheInferenceTest.java
  7. 1
      inference/src/test/resources/model/encoder.json
  8. BIN
      inference/src/test/resources/model/xgboost-model.bin
  9. 33
      runner.py
  10. 45
      src/train_model.py

@ -1,7 +1,7 @@
from pymongo import MongoClient
from pprint import pprint
from datetime import datetime
import csv
client = MongoClient()
db = client.bonitoo
@ -13,26 +13,10 @@ fieldnames = [
'flight.inboundSegments.arrival',
'flight.inboundSegments.origin.airportCode',
'flight.inboundSegments.destination.airportCode',
'flight.inboundSegments.flightNumber',
'flight.inboundSegments.travelClass',
'flight.inboundSegments.bookingCode',
'flight.inboundSegments.availability',
'flight.inboundSegments.elapsedFlyingTime',
'flight.outboundSegments.departure',
'flight.outboundSegments.arrival',
'flight.outboundSegments.origin.airportCode',
'flight.outboundSegments.destination.airportCode',
'flight.outboundSegments.flightNumber',
'flight.outboundSegments.travelClass',
'flight.outboundSegments.bookingCode',
'flight.outboundSegments.availability',
'flight.outboundSegments.elapsedFlyingTime',
'flight.inboundEFT', # elapsed flying time
'flight.outboundEFT',
'oneWay',
'adults', # pocet osob = (adults + children)
'children',
'infants',
'input.price',
'input.tax',
'input.currency',
@ -41,24 +25,9 @@ fieldnames = [
'output.price',
'output.tax',
'output.currency',
'duration' # delka volani do nadrazeneho systemu
'duration'
]
# 5% or 200 CZK difference upward
# -200 CZK downward
# ignore abs(+-10 CZK)
# timestamp + ok price - should be in the cache since cacheat
# timestamp + not-ok price - should not be in the cache since cacheat
# length of stay: arrival - departure
# flight duration ?
# if there is an error then nocache (= priceout is missing)
# take the in/out airline codes into account (mcx ?) - Mirek will find out
# compute the success rate of is/is-not in cache in %
counter = 0
with open('export.csv', mode='w') as ef:
writer = csv.DictWriter(ef, fieldnames=fieldnames, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL)
@ -73,21 +42,14 @@ with open('export.csv', mode='w') as ef:
'timestamp': datetime.fromtimestamp(it['timestamp'] / 1000).isoformat(),
'client.channel': it['client']['channel'],
'type': it['type'],
'flight.outboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.bookingCode': '|'.join([x.get('bookingCode','') for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['outboundSegments']]),
'flight.inboundEFT': it['flight'].get('inboundEFT',''),
'flight.outboundEFT': it['flight'].get('outboundEFT',''),
'oneWay': it['oneWay'],
'adults': it['adults'],
'children': it['children'],
'infants': it['infants'],
'flight.outboundSegments.departure': '|'.join(
[x['departure'].isoformat() for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.arrival': '|'.join(
[x['arrival'].isoformat() for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.origin.airportCode': '|'.join(
[x['origin']['airportCode'] for x in it['flight']['outboundSegments']]),
'flight.outboundSegments.destination.airportCode': '|'.join(
[x['destination']['airportCode'] for x in it['flight']['outboundSegments']]),
'input.price': it['input']['price'],
'input.tax': it['input']['tax'],
'input.currency': it['input']['currency'],
@ -101,15 +63,14 @@ with open('export.csv', mode='w') as ef:
if 'inboundSegments' in it['flight']:
inb = {
'flight.inboundSegments.departure': '|'.join([x['departure'].isoformat() for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.arrival': '|'.join([x['arrival'].isoformat() for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.origin.airportCode': '|'.join([x['origin']['airportCode'] for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.destination.airportCode': '|'.join([x['destination']['airportCode'] for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.flightNumber': '|'.join([x['flightNumber'] for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.travelClass': '|'.join([x['travelClass'] for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.bookingCode': '|'.join([x.get('bookingCode', '') for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.availability': '|'.join([str(x.get('availability','')) for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.elapsedFlyingTime': '|'.join([str(x.get('elapsedFlyingTime','')) for x in it['flight']['inboundSegments']])
'flight.inboundSegments.departure': '|'.join(
[x['departure'].isoformat() for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.arrival': '|'.join(
[x['arrival'].isoformat() for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.origin.airportCode': '|'.join(
[x['origin']['airportCode'] for x in it['flight']['inboundSegments']]),
'flight.inboundSegments.destination.airportCode': '|'.join(
[x['destination']['airportCode'] for x in it['flight']['inboundSegments']]),
}
d = {**d, **inb}
writer.writerow(d)

@ -0,0 +1,58 @@
<!-- Maven build descriptor for the cz.aprar.bonitoo:inference module, packaged as a jar. -->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>cz.aprar.bonitoo</groupId>
<artifactId>inference</artifactId>
<packaging>jar</packaging>
<version>1.0-SNAPSHOT</version>
<name>inference</name>
<properties>
<!-- Source files are read as UTF-8. -->
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<!-- XGBoost JVM binding: provides the Booster, DMatrix and XGBoost classes used by CacheInference. -->
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j</artifactId>
<version>0.90</version>
</dependency>
<!-- Jackson databind: provides the ObjectMapper used to read the JSON label file. -->
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.9.9</version>
</dependency>
<!-- TestNG is only on the test classpath. -->
<dependency>
<groupId>org.testng</groupId>
<artifactId>testng</artifactId>
<version>7.0.0</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<!-- Compile with Java 8 as both source and target level. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.1</version>
<configuration>
<source>8</source>
<target>8</target>
</configuration>
</plugin>
<!-- PMD static analysis, configured for UTF-8 sources and a 1.8 target JDK. -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-pmd-plugin</artifactId>
<version>3.12.0</version>
<configuration>
<sourceEncoding>utf-8</sourceEncoding>
<targetJdk>1.8</targetJdk>
</configuration>
</plugin>
</plugins>
</build>
</project>

@ -0,0 +1,81 @@
package cz.aprar.bonitoo.inference;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import ml.dmlc.xgboost4j.java.Booster;
import ml.dmlc.xgboost4j.java.DMatrix;
import ml.dmlc.xgboost4j.java.XGBoost;
import ml.dmlc.xgboost4j.java.XGBoostError;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
import java.util.Map;
/**
 * Predicts a cache time-to-live class for a flight price record using an XGBoost model.
 *
 * <p>Categorical inputs are turned into numeric codes through a label map read from JSON
 * (column name -&gt; value -&gt; code). A value that is absent from its column's map is
 * encoded as 0.</p>
 */
public class CacheInference {
    /** Number of features in the single-row matrix handed to the booster. */
    private static final int FEATURE_COUNT = 18;

    /** Cached once so that a prediction only needs an array lookup. */
    private static final TTL[] TTL_VALUES = TTL.values();

    private final ObjectMapper mapper = new ObjectMapper();
    private final Booster booster;
    private final Map<String, Map<String, Integer>> labels;

    /**
     * Loads the model and the label map from files; both files are closed before returning.
     *
     * @param modelFile path to the serialized XGBoost model
     * @param labelFile path to the JSON label map
     * @throws IOException  if either file cannot be read or the JSON cannot be parsed
     * @throws XGBoostError if the model cannot be loaded
     */
    public CacheInference(final Path modelFile, final Path labelFile) throws IOException, XGBoostError {
        try (BufferedInputStream modelIS = new BufferedInputStream(Files.newInputStream(modelFile));
             BufferedInputStream labelIS = new BufferedInputStream(Files.newInputStream(labelFile))
        ) {
            booster = loadModel(modelIS);
            labels = loadLabels(labelIS);
        }
    }

    /**
     * Loads the model and the label map from already opened streams.
     * This constructor does not explicitly close the streams it is given.
     *
     * @param modelIS stream with the serialized XGBoost model
     * @param labelIS stream with the JSON label map
     * @throws IOException  if a stream cannot be read or the JSON cannot be parsed
     * @throws XGBoostError if the model cannot be loaded
     */
    public CacheInference(InputStream modelIS, InputStream labelIS) throws IOException, XGBoostError {
        booster = loadModel(modelIS);
        labels = loadLabels(labelIS);
    }

    /**
     * Predicts the cache TTL for one record.
     *
     * @param data the record to score
     * @return the TTL constant whose ordinal equals the predicted class
     * @throws XGBoostError if matrix creation or prediction fails in the native library
     */
    public TTL cacheTTL(final FlightData data) throws XGBoostError {
        final DMatrix matrix = createMatrix(data);
        try {
            final float[][] predicts = booster.predict(matrix);
            return TTL_VALUES[(int) predicts[0][0]];
        } finally {
            // The matrix owns native memory; release it now instead of waiting for finalization.
            matrix.dispose();
        }
    }

    private Booster loadModel(InputStream model) throws XGBoostError, IOException {
        return XGBoost.loadModel(model);
    }

    /**
     * Builds the one-row feature matrix for a record.
     * Presumably the slot order has to match the column order used in training - TODO confirm.
     */
    private DMatrix createMatrix(final FlightData data) throws XGBoostError {
        final float[] arr = new float[FEATURE_COUNT];
        arr[0] = encode("client.channel", data.getClientChannel());
        arr[1] = encode("type", data.getType());
        arr[2] = encode("flight.inboundSegments.departure", joinList(data.getInboundDeparture()));
        arr[3] = encode("flight.inboundSegments.arrival", joinList(data.getInboundArrival()));
        arr[4] = encode("flight.inboundSegments.origin.airportCode", joinList(data.getInboundOrigin()));
        arr[5] = encode("flight.inboundSegments.destination.airportCode", joinList(data.getInboundDestination()));
        arr[6] = encode("flight.outboundSegments.departure", joinList(data.getOutboundDeparture()));
        arr[7] = encode("flight.outboundSegments.arrival", joinList(data.getOutboundArrival()));
        arr[8] = encode("flight.outboundSegments.origin.airportCode", joinList(data.getOutboundOrigin()));
        arr[9] = encode("flight.outboundSegments.destination.airportCode", joinList(data.getOutboundDestination()));
        arr[10] = data.getInputPrice().floatValue();
        arr[11] = data.getInputTax().floatValue();
        arr[12] = encode("input.currency", data.getInputCurrency());
        arr[13] = data.getStatus().floatValue();
        arr[14] = data.getOutputPrice().floatValue();
        arr[15] = data.getOutputTax().floatValue();
        arr[16] = encode("output.currency", data.getOutputCurrency());
        arr[17] = data.getDuration().floatValue();
        // NOTE(review): the three-argument DMatrix constructor may treat 0.0f as the "missing"
        // value - confirm against the xgboost4j documentation before relying on zero features.
        return new DMatrix(arr, 1, arr.length);
    }

    /**
     * Looks up the numeric code of a categorical value.
     *
     * @param column name of the categorical column in the label map
     * @param value  raw value to encode
     * @return the mapped code, or 0 when the value is not in the column's map
     * @throws IllegalStateException if the label map has no entry for the column at all
     */
    private float encode(final String column, final String value) {
        final Map<String, Integer> codes = labels.get(column);
        if (codes == null) {
            throw new IllegalStateException("Label map has no entry for column '" + column + "'");
        }
        return codes.getOrDefault(value, 0);
    }

    /** Parses the JSON label map (column name -&gt; value -&gt; code). */
    private Map<String, Map<String, Integer>> loadLabels(InputStream labels) throws IOException {
        final TypeReference<Map<String, Map<String, Integer>>> typeRef = new TypeReference<Map<String, Map<String, Integer>>>() {
        };
        return mapper.readValue(labels, typeRef);
    }

    /** Joins per-segment values with '|' so they can be looked up as a single label. */
    private String joinList(final List<String> data) {
        return String.join("|", data);
    }
}

@ -0,0 +1,123 @@
package cz.aprar.bonitoo.inference;
import java.util.Collections;
import java.util.List;
/**
 * Immutable value object holding the inputs of one cache-TTL prediction.
 *
 * <p>Each list holds one entry per flight segment. Lists are copied on construction, so later
 * changes to the caller's lists do not leak into this object; a {@code null} list is stored as
 * an empty list.</p>
 */
public class FlightData {
    private final String clientChannel;
    private final String type;
    private final List<String> inboundDeparture;
    private final List<String> inboundArrival;
    private final List<String> inboundOrigin;
    private final List<String> inboundDestination;
    private final List<String> outboundDeparture;
    private final List<String> outboundArrival;
    private final List<String> outboundOrigin;
    private final List<String> outboundDestination;
    private final Double inputPrice;
    private final Double inputTax;
    private final String inputCurrency;
    private final Integer status;
    private final Double outputPrice;
    private final Double outputTax;
    private final String outputCurrency;
    private final Integer duration;

    /**
     * Creates a record. Scalar arguments are stored as given; list arguments are copied.
     */
    public FlightData(final String clientChannel, final String type,
                      final List<String> inboundDeparture, final List<String> inboundArrival, final List<String> inboundOrigin,
                      final List<String> inboundDestination, final List<String> outboundDeparture, final List<String> outboundArrival,
                      final List<String> outboundOrigin, final List<String> outboundDestination, final Double inputPrice,
                      final Double inputTax, final String inputCurrency, final Integer status, final Double outputPrice,
                      final Double outputTax, final String outputCurrency, final Integer duration) {
        this.clientChannel = clientChannel;
        this.type = type;
        this.inboundDeparture = immutableCopy(inboundDeparture);
        this.inboundArrival = immutableCopy(inboundArrival);
        this.inboundOrigin = immutableCopy(inboundOrigin);
        this.inboundDestination = immutableCopy(inboundDestination);
        this.outboundDeparture = immutableCopy(outboundDeparture);
        this.outboundArrival = immutableCopy(outboundArrival);
        this.outboundOrigin = immutableCopy(outboundOrigin);
        this.outboundDestination = immutableCopy(outboundDestination);
        this.inputPrice = inputPrice;
        this.inputTax = inputTax;
        this.inputCurrency = inputCurrency;
        this.status = status;
        this.outputPrice = outputPrice;
        this.outputTax = outputTax;
        this.outputCurrency = outputCurrency;
        this.duration = duration;
    }

    /**
     * Returns an unmodifiable snapshot of the given list, or an empty list for {@code null}.
     * The fully qualified ArrayList avoids needing an extra import in this file.
     */
    private static List<String> immutableCopy(final List<String> source) {
        if (source == null) {
            return Collections.emptyList();
        }
        return Collections.unmodifiableList(new java.util.ArrayList<String>(source));
    }

    public String getClientChannel() {
        return clientChannel;
    }

    public String getType() {
        return type;
    }

    /** @return unmodifiable list of inbound departure values, empty if none were given */
    public List<String> getInboundDeparture() {
        return inboundDeparture;
    }

    /** @return unmodifiable list of inbound arrival values, empty if none were given */
    public List<String> getInboundArrival() {
        return inboundArrival;
    }

    /** @return unmodifiable list of inbound origin values, empty if none were given */
    public List<String> getInboundOrigin() {
        return inboundOrigin;
    }

    /** @return unmodifiable list of inbound destination values, empty if none were given */
    public List<String> getInboundDestination() {
        return inboundDestination;
    }

    /** @return unmodifiable list of outbound departure values, empty if none were given */
    public List<String> getOutboundDeparture() {
        return outboundDeparture;
    }

    /** @return unmodifiable list of outbound arrival values, empty if none were given */
    public List<String> getOutboundArrival() {
        return outboundArrival;
    }

    /** @return unmodifiable list of outbound origin values, empty if none were given */
    public List<String> getOutboundOrigin() {
        return outboundOrigin;
    }

    /** @return unmodifiable list of outbound destination values, empty if none were given */
    public List<String> getOutboundDestination() {
        return outboundDestination;
    }

    public Double getInputPrice() {
        return inputPrice;
    }

    public Double getInputTax() {
        return inputTax;
    }

    public String getInputCurrency() {
        return inputCurrency;
    }

    public Integer getStatus() {
        return status;
    }

    public Double getOutputPrice() {
        return outputPrice;
    }

    public Double getOutputTax() {
        return outputTax;
    }

    public String getOutputCurrency() {
        return outputCurrency;
    }

    public Integer getDuration() {
        return duration;
    }
}

@ -0,0 +1,36 @@
package cz.aprar.bonitoo.inference;
/**
 * Cache time-to-live classes, ranging from "do not cache" to 30 days.
 *
 * <p>{@code CacheInference} indexes {@code TTL.values()} with the class predicted by the model,
 * so the declaration order of the constants is significant.</p>
 */
public enum TTL {
    /** Very volatile or unknown data. */
    NOCACHE,

    /** Cache for 12 hours. */
    H12,

    /** Cache for 1 day. */
    D1,

    /** Cache for 2 days. */
    D2,

    /** Cache for 3 days. */
    D3,

    /** Cache for 1 week. */
    D7,

    /** Cache for 2 weeks. */
    D14,

    /** Cache for 30 days. */
    D30
}

@ -0,0 +1,140 @@
package cz.aprar.bonitoo.inference;
import ml.dmlc.xgboost4j.java.XGBoostError;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import static java.util.Collections.emptyList;
import static org.testng.Assert.assertEquals;
/**
 * Runs {@link CacheInference} against the model and label map bundled under
 * {@code src/test/resources/model} and checks the predicted TTL for a few fixed records.
 */
public class CacheInferenceTest {
    private CacheInference ci;

    /**
     * Loads the bundled model and encoder once for the whole class.
     * The resource streams are closed as soon as the inference object has been built.
     */
    @BeforeClass
    void setUp() throws IOException, XGBoostError {
        final ClassLoader classLoader = getClass().getClassLoader();
        // java.io.InputStream is fully qualified so that no extra import is needed in this file.
        try (java.io.InputStream modelIS = classLoader.getResourceAsStream("model/xgboost-model.bin");
             java.io.InputStream labelIS = classLoader.getResourceAsStream("model/encoder.json")) {
            ci = new CacheInference(modelIS, labelIS);
        }
    }

    @Test(dataProvider = "inferenceData")
    void testInference(final FlightData input, final TTL expected) throws XGBoostError {
        final TTL result = ci.cacheTTL(input);
        assertEquals(result, expected);
    }

    /**
     * Supplies pairs of (input record, expected TTL).
     */
    @DataProvider(name = "inferenceData")
    public Object[][] inferenceData() {
        return new Object[][]{
                // Return trip PRG-FRA-MCO and back, two segments each way.
                {new FlightData(
                        "fly-me-to",
                        "WS",
                        toList("2020-05-09T23:59:00", "2020-05-10T10:30:00"),
                        toList("2020-05-10T08:55:00", "2020-05-10T11:30:00"),
                        toList("MCO", "FRA"),
                        toList("FRA", "PRG"),
                        toList("2020-05-01T09:50:00", "2020-05-01T11:55:00"),
                        toList("2020-05-01T11:00:00", "2020-05-01T21:55:00"),
                        toList("PRG", "FRA"),
                        toList("FRA", "MCO"),
                        39766.0,
                        19776.0,
                        "CZK",
                        0,
                        39766.0,
                        19776.0,
                        "CZK",
                        427
                ), TTL.D1},
                // Single outbound segment KRK-BVA with no inbound segments.
                {new FlightData(
                        "fly-me-to",
                        "PYTON",
                        emptyList(),
                        emptyList(),
                        emptyList(),
                        emptyList(),
                        toList("2019-12-18T05:45:00"),
                        toList("2019-12-18T08:05:00"),
                        toList("KRK"),
                        toList("BVA"),
                        336.258,
                        0.0,
                        "CZK",
                        0,
                        336.258,
                        0.0,
                        "CZK",
                        2284
                ), TTL.D7},
                // Return trip PRG-HEL-LAX and back; input and output prices differ.
                {new FlightData(
                        "levne",
                        "AVIA",
                        toList("2020-02-07T02:25:00", "2020-02-07T14:50:00"),
                        toList("2020-02-07T13:10:00", "2020-02-07T16:55:00"),
                        toList("LAX", "LHR"),
                        toList("LHR", "PRG"),
                        toList("2020-01-28T10:35:00", "2020-01-28T14:40:00"),
                        toList("2020-01-28T12:45:00", "2020-01-29T01:45:00"),
                        toList("PRG", "HEL"),
                        toList("HEL", "LAX"),
                        5971.77978,
                        0.0,
                        "CZK",
                        0,
                        15971.77978,
                        0.0,
                        "CZK",
                        551
                ), TTL.D7},
                // Return trip VIE-FRA-YVR and back via YUL.
                {new FlightData(
                        "fly-me-to",
                        "HH",
                        toList("2019-11-01T16:30:00", "2019-11-01T23:35:00"),
                        toList("2019-11-01T21:12:00", "2019-11-02T07:45:00"),
                        toList("YVR", "YUL"),
                        toList("YUL", "VIE"),
                        toList("2019-10-18T08:10:00", "2019-10-18T11:30:00"),
                        toList("2019-10-18T09:40:00", "2019-10-18T21:25:00"),
                        toList("VIE", "FRA"),
                        toList("FRA", "YVR"),
                        17723.0,
                        7708.0,
                        "CZK",
                        0,
                        17723.0,
                        7708.0,
                        "CZK",
                        1786
                ), TTL.D1},
                // Every categorical value is the literal "unknown" and every number is zero.
                {new FlightData(
                        "unknown",
                        "unknown",
                        toList("unknown"),
                        toList("unknown"),
                        toList("unknown"),
                        toList("unknown"),
                        toList("unknown"),
                        toList("unknown"),
                        toList("unknown"),
                        toList("unknown"),
                        0.0,
                        0.0,
                        "unknown",
                        0,
                        0.0,
                        0.0,
                        "unknown",
                        0
                ), TTL.NOCACHE}
        };
    }

    /** Shorthand for building a per-segment value list. */
    private List<String> toList(final String... data) {
        return Arrays.asList(data);
    }
}

File diff suppressed because one or more lines are too long

@ -0,0 +1,33 @@
import sagemaker
import boto3
from sagemaker import get_execution_role
from sagemaker.xgboost.estimator import XGBoost
# AWS session bound to the 'bonitoo' credentials profile in eu-central-1.
boto_session = boto3.Session(
    region_name='eu-central-1',
    profile_name='bonitoo',
)

# local: sagemaker_session = sagemaker.LocalSession(boto_session=boto_session)
sagemaker_session = sagemaker.Session(boto_session=boto_session)

# Execution role name and S3 location of the training CSV.
role = 'Bonitoo_SageMaker_Execution'
train_input = 's3://customers-bonitoo-cachettl/sagemaker/data/export-reduced.csv'

# Hyperparameters handed to the estimator. 'bonitoo_price_limit' is not a standard
# XGBoost parameter - presumably it is consumed by train_model.py (verify there).
xgb_hyperparameters = {
    'bonitoo_price_limit': 1000,
    'num_round': 15,
    'max_depth': 15,
    'eta': 0.5,
    'num_class': 8,
    'objective': 'multi:softmax',
    'eval_metric': 'mlogloss',
}

# Estimator that runs ./src/train_model.py on a single ml.c5.xlarge instance.
tf = XGBoost(
    entry_point='train_model.py',
    source_dir='./src',
    train_instance_type='ml.c5.xlarge',
    train_instance_count=1,
    role=role,
    sagemaker_session=sagemaker_session,
    framework_version='0.90-1',
    py_version='py3',
    hyperparameters=xgb_hyperparameters,
)

# Start the training job with the CSV bound to the 'training' channel.
tf.fit({'training': train_input})

@ -25,26 +25,10 @@ columns = [
'flight.inboundSegments.arrival',
'flight.inboundSegments.origin.airportCode',
'flight.inboundSegments.destination.airportCode',
'flight.inboundSegments.flightNumber',
'flight.inboundSegments.travelClass',
'flight.inboundSegments.bookingCode',
'flight.inboundSegments.availability',
'flight.inboundSegments.elapsedFlyingTime',
'flight.outboundSegments.departure',
'flight.outboundSegments.arrival',
'flight.outboundSegments.origin.airportCode',
'flight.outboundSegments.destination.airportCode',
'flight.outboundSegments.flightNumber',
'flight.outboundSegments.travelClass',
'flight.outboundSegments.bookingCode',
'flight.outboundSegments.availability',
'flight.outboundSegments.elapsedFlyingTime',
'flight.inboundEFT',
'flight.outboundEFT',
'oneWay',
'adults',
'children',
'infants',
'input.price',
'input.tax',
'input.currency',
@ -63,25 +47,12 @@ catcolumns = [
'flight.inboundSegments.arrival',
'flight.inboundSegments.origin.airportCode',
'flight.inboundSegments.destination.airportCode',
'flight.inboundSegments.flightNumber',
'flight.inboundSegments.travelClass',
'flight.inboundSegments.bookingCode',
'flight.inboundSegments.availability',
'flight.inboundSegments.elapsedFlyingTime',
'flight.outboundSegments.departure',
'flight.outboundSegments.arrival',
'flight.outboundSegments.origin.airportCode',
'flight.outboundSegments.destination.airportCode',
'flight.outboundSegments.flightNumber',
'flight.outboundSegments.travelClass',
'flight.outboundSegments.bookingCode',
'flight.outboundSegments.availability',
'flight.outboundSegments.elapsedFlyingTime',
'flight.inboundEFT',
'flight.outboundEFT',
'input.currency',
'output.currency',
'oneWay'
'output.currency'
]
floatcolumns = [
@ -92,9 +63,6 @@ floatcolumns = [
]
intcolumns = [
'adults',
'children',
'infants',
'status',
'duration'
]
@ -157,9 +125,6 @@ def preprocess_data(df):
df.loc[:, 'timestamp'] = df.loc[:, 'timestamp'].apply(lambda x: pd.to_datetime(x))
booleanDictionary = {True: 'TRUE', False: 'FALSE'}
df.loc[:, 'oneWay'] = df.loc[:, 'oneWay'].replace(booleanDictionary)
for cc in catcolumns:
df.loc[:, cc] = df.loc[:, cc].astype('category')
df.loc[:, '%s_codes' % cc] = df[cc].cat.codes
@ -218,7 +183,6 @@ def get_pandas_df(data_path):
return df
def get_df(train_path, validate_path, content_type='text/csv'):
train_files_size = get_size(train_path) if train_path else 0
val_files_size = get_size(validate_path) if validate_path else 0
@ -324,7 +288,6 @@ def sagemaker_train(train_config, data_config, train_path, val_path, model_dir,
else:
raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1")
def train_job(train_cfg, train_dmatrix, val_dmatrix, train_df, model_dir, checkpoint_dir, is_master):
# Parse arguments for train() API
early_stopping_rounds = train_cfg.get('early_stopping_rounds')
@ -361,10 +324,6 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_df, model_dir, checkp
if val_dmatrix:
logging.info("Validation matrix has {} rows".format(val_dmatrix.num_row()))
# TODO remove
#logging.info("cols: %s", str(train_dmatrix.feature_names))
#raise Exception("cols: %s" % str(train_dmatrix.feature_names))
try:
bst = xgb.train(train_cfg, train_dmatrix, num_boost_round=num_round, evals=watchlist, feval=configured_feval,
early_stopping_rounds=early_stopping_rounds, callbacks=callbacks, xgb_model=xgb_model,
@ -389,8 +348,6 @@ def train_job(train_cfg, train_dmatrix, val_dmatrix, train_df, model_dir, checkp
bst.save_model(model_location)
logging.info("Stored trained model at {}".format(model_location))
if __name__ == '__main__':
with open(os.getenv(sm_env_constants.SM_INPUT_TRAINING_CONFIG_FILE), "r") as f:
train_config = json.load(f)

Loading…
Cancel
Save