"""Bonitoo cache TTL estimation: SageMaker XGBoost training script."""
import os
import json
import logging
import pandas as pd
import xgboost as xgb
import numpy as np
import datetime as dt
from sagemaker_algorithm_toolkit import exceptions as exc
from sagemaker_xgboost_container.constants import sm_env_constants
from sagemaker_xgboost_container.data_utils import get_size, validate_data_file_path
from sagemaker_xgboost_container import distributed
from sagemaker_xgboost_container import checkpointing
from sagemaker_xgboost_container.algorithm_mode import channel_validation as cv
from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv
from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod
from sagemaker_xgboost_container.algorithm_mode import train_utils
from sagemaker_xgboost_container.constants.xgb_constants import CUSTOMER_ERRORS
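# Schema of the raw request-log CSV (the input files carry no header row;
# these names are assigned in get_csv_pandas below).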
columns = [
'timestamp',
'type',
'flight.inboundSegments.departure',
'flight.inboundSegments.arrival',
'flight.inboundSegments.origin.airportCode',
'flight.inboundSegments.destination.airportCode',
'flight.inboundSegments.airline.code',
'flight.inboundMCX.code',
'flight.outboundSegments.departure',
'flight.outboundSegments.arrival',
'flight.outboundSegments.origin.airportCode',
'flight.outboundSegments.destination.airportCode',
'flight.outboundSegments.airline.code',
'flight.outboundMCX.code',
'input.price',
'success',
'output.price',
'cacheAt',
'cacheExp'
]
catcolumns = [
'type',
'flight.inboundSegments.departure',
'flight.inboundSegments.arrival',
'flight.inboundSegments.origin.airportCode',
'flight.inboundSegments.destination.airportCode',
'flight.inboundSegments.airline.code',
'flight.inboundMCX.code',
'flight.outboundSegments.departure',
'flight.outboundSegments.arrival',
'flight.outboundSegments.origin.airportCode',
'flight.outboundSegments.destination.airportCode',
'flight.outboundSegments.airline.code',
'flight.outboundMCX.code'
]
floatcolumns = [
'input.price',
'output.price'
]
timestampcolumns = [
'timestamp',
'cacheAt',
'cacheExp'
]
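# Price-drift heuristic: the cached (output) price is considered stale when it
# deviates from the live (input) price beyond the given absolute or relative
# thresholds, in either direction.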
def excessive_price(inprice, outprice, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc):
return outprice - inprice > price_pos_abs or \
inprice - outprice > price_neg_abs or \
outprice > inprice * (1.0 + price_pos_perc) or \
outprice < inprice * (1.0 - price_neg_perc)
def equal_price(inprice, outprice):
return abs(inprice - outprice) < 10
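# Target label: an ordinal cache-TTL class. 0 means "do not cache" (failed
# requests); 1-7 map to growing cache-expiration windows (<=12h, <=1d, <=2d,
# <=3d, <=7d, <=14d, longer). The class is nudged down when the cached price
# drifted too far and up when it still closely matches the live price; rows
# never served from cache default to class 4 (price match) or class 2.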
def expected_value(row, price_pos_abs=200, price_neg_abs=100, price_pos_perc=0.05, price_neg_perc=0.05):
# do not cache errors
success = row['success']
if success == 0:
return 0
inprice, outprice = row['input.price'], row['output.price']
tstamp, cacheAt, cacheExp = row['timestamp'], row['cacheAt'], row['cacheExp']
if not pd.isnull(cacheAt):
incachetime = tstamp - cacheAt
expcachetime = cacheExp - cacheAt
    else:
        # Request was not served from the cache.
        incachetime = np.timedelta64('NaT')
        expcachetime = np.timedelta64('NaT')
    modifier = 0
    if excessive_price(inprice, outprice, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc):
        # Cached price drifted too far: demote one bucket and cap the TTL at
        # the time the entry has already spent in the cache.
        modifier = -1
        expcachetime = incachetime
if equal_price(inprice, outprice):
if pd.isnull(incachetime):
return 4
modifier = 1
    if pd.isnull(incachetime):
        return 2
    # Map the cache-expiration window to TTL buckets 1..7, shifted by the
    # price modifier.
    if expcachetime <= np.timedelta64(12, 'h'):
        return 1 + modifier
    if expcachetime <= np.timedelta64(1, 'D'):
        return 2 + modifier
    if expcachetime <= np.timedelta64(2, 'D'):
        return 3 + modifier
    if expcachetime <= np.timedelta64(3, 'D'):
        return 4 + modifier
    if expcachetime <= np.timedelta64(7, 'D'):
        return 5 + modifier
    if expcachetime <= np.timedelta64(14, 'D'):
        return 6 + modifier
    return min(7, 7 + modifier)
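# Segment fields pack one '%Y-%m-%dT%H:%M:%S' timestamp per flight segment,
# separated by '|'; only the first segment's value is used below.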
def fromisoformat(s):
return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S')
def compute_duration(row):
    # Trip length in days; one-way trips (no inbound segment) count as 0.
    indeparture, outdeparture = row['flight.inboundSegments.departure'], row['flight.outboundSegments.departure']
    if pd.isna(indeparture):
        return 0
    indt = fromisoformat(indeparture.split('|')[0])
    outdt = fromisoformat(outdeparture.split('|')[0])
    return (indt - outdt).days
def compute_prebooking(row):
    # Lead time in days between the order timestamp and outbound departure.
    tstamp, outdeparture = row['timestamp'], row['flight.outboundSegments.departure']
    outdt = fromisoformat(outdeparture.split('|')[0])
    return (outdt - tstamp).days
def compute_dom(outdeparture):
outdt = fromisoformat(outdeparture.split('|')[0])
return outdt.day
def compute_dow(outdeparture):
outdt = fromisoformat(outdeparture.split('|')[0])
return outdt.weekday()
def preprocess_data(df):
    logging.info('Preprocessing start')
    # Encode the boolean success flag as 0/1.
    booleanDictionary = {True: 1, False: 0}
    df.loc[:, 'success'] = df.loc[:, 'success'].replace(booleanDictionary)
    for ct in timestampcolumns:
        df.loc[:, ct] = pd.to_datetime(df.loc[:, ct])
    # Categorical columns are label-encoded; the category <-> code mapping is
    # persisted by save_encoders for reuse at inference time.
    for cc in catcolumns:
        df.loc[:, cc] = df.loc[:, cc].astype('category')
        df.loc[:, '%s_codes' % cc] = df[cc].cat.codes
    df.loc[:, floatcolumns] = df.loc[:, floatcolumns].astype('float64')
    # Derived features: trip duration, booking lead time, and day-of-month /
    # day-of-week for both the order and the outbound flight.
    df.loc[:, 'duration'] = df.apply(compute_duration, axis=1)
    df.loc[:, 'prebooking'] = df.apply(compute_prebooking, axis=1)
    df.loc[:, 'order_dom'] = df.loc[:, 'timestamp'].apply(lambda x: x.day)
    df.loc[:, 'order_dow'] = df.loc[:, 'timestamp'].apply(lambda x: x.dayofweek)
    df.loc[:, 'flight_dom'] = df.loc[:, 'flight.outboundSegments.departure'].apply(compute_dom)
    df.loc[:, 'flight_dow'] = df.loc[:, 'flight.outboundSegments.departure'].apply(compute_dow)
    return df
def remove_non_features(df):
    # Drop raw categoricals (their *_codes variants remain), timestamps, and
    # label-only columns; return the feature frame plus the original frame.
    return df.drop(catcolumns + timestampcolumns + ['output.price', 'success'], axis=1), df
def train_test_split(df, label, ratio):
    logging.info('Splitting dataset with ratio %f', ratio)
    msk = np.random.rand(len(df)) < ratio
    train_data = df[msk]
    test_data = df[~msk]
    train_label = label[msk]
    test_label = label[~msk]
    # drop=True keeps the old index from leaking into the feature matrix.
    train_data = train_data.reset_index(drop=True)
    test_data = test_data.reset_index(drop=True)
    train_label = train_label.reset_index(drop=True)
    test_label = test_label.reset_index(drop=True)
    return train_data, test_data, train_label, test_label
def get_csv_pandas(files_path):
    # Accept either a single csv file or a directory containing one.
    if os.path.isfile(files_path):
        csv_file = files_path
    else:
        csv_file = os.path.join(files_path, [
            f for f in os.listdir(files_path) if os.path.isfile(os.path.join(files_path, f))][0])
    try:
        logging.info('Loading csv file %s', csv_file)
        df = pd.read_csv(csv_file, header=None)
        df.columns = columns
        return df
    except Exception as e:
        raise exc.UserError("Failed to load csv data with exception:\n{}".format(e))
def get_pandas_df(data_path):
    if not os.path.exists(data_path):
        return None
    if os.path.isfile(data_path):
        files_path = data_path
    else:
        # Descend to the first leaf directory, where the channel files live.
        for root, dirs, files in os.walk(data_path):
            if dirs == []:
                files_path = root
                break
    return get_csv_pandas(files_path)
def get_df(train_path, validate_path, content_type='text/csv'):
train_files_size = get_size(train_path) if train_path else 0
val_files_size = get_size(validate_path) if validate_path else 0
logging.debug("File size need to be processed in the node: {}mb.".format(
round((train_files_size + val_files_size) / (1024 * 1024), 2)))
if train_files_size > 0:
validate_data_file_path(train_path, content_type)
if val_files_size > 0:
validate_data_file_path(validate_path, content_type)
train_pandas = get_pandas_df(train_path) if train_files_size > 0 else None
val_pandas = get_pandas_df(validate_path) if val_files_size > 0 else None
return train_pandas, val_pandas
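# Build train/validation DMatrices: use the validation channel when provided,
# otherwise carve a random hold-out split (default 80/20) from the training set.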
def get_dmatrices(train_pandas, train_label_pandas, val_pandas, val_label_pandas, ratio=0.8):
    if val_pandas is not None:
train_dmatrix = xgb.DMatrix(train_pandas, label=train_label_pandas.loc[:, 'label'])
val_dmatrix = xgb.DMatrix(val_pandas, label=val_label_pandas.loc[:, 'label'])
else:
train_data, test_data, train_label, test_label = train_test_split(train_pandas, train_label_pandas, ratio)
train_dmatrix = xgb.DMatrix(train_data, label=train_label.loc[:, 'label'])
val_dmatrix = xgb.DMatrix(test_data, label=test_label.loc[:, 'label'])
return train_dmatrix, val_dmatrix
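# Persist the category -> integer code mapping alongside the model so the same
# encoding can be replayed on requests at inference time.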
def save_encoders(encoder_location, df):
logging.info('Saving encoders')
jsondata = {}
for cc in catcolumns:
jsondata[cc] = {cat: idx for idx, cat in enumerate(df[cc].cat.categories)}
with open(encoder_location, 'w') as f:
json.dump(jsondata, f)
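# End-to-end driver: validates hyperparameters, loads and labels the data,
# builds DMatrices, then runs single-node training or sets up Rabit for
# distributed training across hosts.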
def sagemaker_train(train_config, data_config, train_path, val_path, model_dir, sm_hosts, sm_current_host,
checkpoint_config):
metrics = metrics_mod.initialize()
hyperparameters = hpv.initialize(metrics)
price_pos_abs = int(train_config.get('bonitoo_price_pos_abs', 200))
price_neg_abs = int(train_config.get('bonitoo_price_neg_abs', 200))
price_pos_perc = float(train_config.get('bonitoo_price_pos_perc', 0.05))
price_neg_perc = float(train_config.get('bonitoo_price_neg_perc', 0.05))
train_config = {k:v.replace('"', '') for k,v in train_config.items() if not k.startswith('sagemaker_') and not k.startswith('bonitoo_')}
train_config = hyperparameters.validate(train_config)
if train_config.get("updater"):
train_config["updater"] = ",".join(train_config["updater"])
logging.info("hyperparameters {}".format(train_config))
logging.info("channels {}".format(data_config))
# Get Training and Validation Data Matrices
validation_channel = data_config.get('validation', None)
checkpoint_dir = checkpoint_config.get("LocalPath", None)
train_df, val_df = get_df(train_path, val_path)
train_df = preprocess_data(train_df)
train_label_df = train_df.apply(lambda x: expected_value(x, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc), axis=1).to_frame(name='label')
train_df, train_df_orig = remove_non_features(train_df)
val_label_df = None
    if val_df is not None:
        val_df = preprocess_data(val_df)
        val_label_df = val_df.apply(lambda x: expected_value(x, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc), axis=1).to_frame(name='label')
        val_df, val_df_orig = remove_non_features(val_df)
train_dmatrix, val_dmatrix = get_dmatrices(train_df, train_label_df, val_df, val_label_df)
train_args = dict(
train_cfg=train_config,
train_dmatrix=train_dmatrix,
train_df=train_df_orig,
val_dmatrix=val_dmatrix,
model_dir=model_dir,
checkpoint_dir=checkpoint_dir)
# Obtain information about training resources to determine whether to set up Rabit or not
num_hosts = len(sm_hosts)
if num_hosts > 1:
# Wait for hosts to find each other
logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts))
distributed.wait_hostname_resolution(sm_hosts)
        if train_dmatrix is None:
logging.warning("Host {} does not have data. Will broadcast to cluster and will not be used in distributed"
" training.".format(sm_current_host))
distributed.rabit_run(exec_fun=train_job, args=train_args, include_in_training=(train_dmatrix is not None),
hosts=sm_hosts, current_host=sm_current_host, update_rabit_args=True)
elif num_hosts == 1:
        if train_dmatrix is not None:
if validation_channel:
if not val_dmatrix:
raise exc.UserError("No data in validation channel path {}".format(val_path))
logging.info("Single node training.")
train_args.update({'is_master': True})
train_job(**train_args)
else:
raise exc.UserError("No data in training channel path {}".format(train_path))
else:
raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1")
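# Single xgb.train run with checkpoint resume, optional early stopping, and
# (on the master host) persistence of the model and categorical encoders.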
def train_job(train_cfg, train_dmatrix, val_dmatrix, train_df, model_dir, checkpoint_dir, is_master):
# Parse arguments for train() API
early_stopping_rounds = train_cfg.get('early_stopping_rounds')
num_round = int(train_cfg["num_round"])
# Evaluation metrics to use with train() API
tuning_objective_metric_param = train_cfg.get("_tuning_objective_metric")
eval_metric = train_cfg.get("eval_metric")
cleaned_eval_metric, configured_feval = train_utils.get_eval_metrics_and_feval(
tuning_objective_metric_param, eval_metric)
if cleaned_eval_metric:
train_cfg['eval_metric'] = cleaned_eval_metric
else:
train_cfg.pop('eval_metric', None)
# Set callback evals
watchlist = [(train_dmatrix, 'train')]
if val_dmatrix is not None:
watchlist.append((val_dmatrix, 'validation'))
xgb_model, iteration = checkpointing.load_checkpoint(checkpoint_dir)
num_round -= iteration
if xgb_model is not None:
logging.info("Checkpoint loaded from %s", xgb_model)
logging.info("Resuming from iteration %s", iteration)
callbacks = []
callbacks.append(checkpointing.print_checkpointed_evaluation(start_iteration=iteration))
if checkpoint_dir:
save_checkpoint = checkpointing.save_checkpoint(checkpoint_dir, start_iteration=iteration)
callbacks.append(save_checkpoint)
logging.info("Train matrix has {} rows".format(train_dmatrix.num_row()))
if val_dmatrix:
logging.info("Validation matrix has {} rows".format(val_dmatrix.num_row()))
try:
bst = xgb.train(train_cfg, train_dmatrix, num_boost_round=num_round, evals=watchlist, feval=configured_feval,
early_stopping_rounds=early_stopping_rounds, callbacks=callbacks, xgb_model=xgb_model,
verbose_eval=False)
except Exception as e:
for customer_error_message in CUSTOMER_ERRORS:
if customer_error_message in str(e):
raise exc.UserError(str(e))
exception_prefix = "XGB train call failed with exception"
raise exc.AlgorithmError("{}:\n {}".format(exception_prefix, str(e)))
    os.makedirs(model_dir, exist_ok=True)
    if is_master:
        encoder_location = os.path.join(model_dir, 'encoder.json')
        save_encoders(encoder_location, train_df)
        logging.info("Stored encoders at {}".format(encoder_location))
        model_location = os.path.join(model_dir, 'xgboost-model.bin')
        bst.save_model(model_location)
        logging.info("Stored trained model at {}".format(model_location))
if __name__ == '__main__':
with open(os.getenv(sm_env_constants.SM_INPUT_TRAINING_CONFIG_FILE), "r") as f:
train_config = json.load(f)
with open(os.getenv(sm_env_constants.SM_INPUT_DATA_CONFIG_FILE), "r") as f:
data_config = json.load(f)
checkpoint_config_file = os.getenv(sm_env_constants.SM_CHECKPOINT_CONFIG_FILE)
if os.path.exists(checkpoint_config_file):
with open(checkpoint_config_file, "r") as f:
checkpoint_config = json.load(f)
else:
checkpoint_config = {}
train_path = os.environ['SM_CHANNEL_TRAINING']
val_path = os.environ.get(sm_env_constants.SM_CHANNEL_VALIDATION)
sm_hosts = json.loads(os.environ[sm_env_constants.SM_HOSTS])
sm_current_host = os.environ[sm_env_constants.SM_CURRENT_HOST]
model_dir = os.getenv(sm_env_constants.SM_MODEL_DIR)
sagemaker_train(
train_config=train_config, data_config=data_config,
train_path=train_path, val_path=val_path, model_dir=model_dir,
sm_hosts=sm_hosts, sm_current_host=sm_current_host,
checkpoint_config=checkpoint_config
)