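"""SageMaker algorithm-mode training entry point.

Trains an XGBoost model that scores how cache-worthy a flight price quote is:
raw request/response rows are labeled 0-7 by expected_value() below, the
encoded features are fed to xgb.train(), and distributed (Rabit) training and
checkpointing are delegated to the SageMaker XGBoost container utilities.
"""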
import argparse
import datetime as dt
import json
import logging
import os

import numpy as np
import pandas as pd
import xgboost as xgb

from sagemaker_algorithm_toolkit import exceptions as exc
from sagemaker_xgboost_container import checkpointing, distributed
from sagemaker_xgboost_container.algorithm_mode import channel_validation as cv
from sagemaker_xgboost_container.algorithm_mode import hyperparameter_validation as hpv
from sagemaker_xgboost_container.algorithm_mode import metrics as metrics_mod
from sagemaker_xgboost_container.algorithm_mode import train_utils
from sagemaker_xgboost_container.constants import sm_env_constants
from sagemaker_xgboost_container.constants.xgb_constants import CUSTOMER_ERRORS
from sagemaker_xgboost_container.data_utils import get_content_type, get_dmatrix, get_size, validate_data_file_path

# Raw CSV schema, in file order (the training CSVs have no header row).
columns = [
    'timestamp',
    'type',
    'flight.inboundSegments.departure',
    'flight.inboundSegments.arrival',
    'flight.inboundSegments.origin.airportCode',
    'flight.inboundSegments.destination.airportCode',
    'flight.inboundSegments.airline.code',
    'flight.inboundMCX.code',
    'flight.outboundSegments.departure',
    'flight.outboundSegments.arrival',
    'flight.outboundSegments.origin.airportCode',
    'flight.outboundSegments.destination.airportCode',
    'flight.outboundSegments.airline.code',
    'flight.outboundMCX.code',
    'input.price',
    'success',
    'output.price',
    'cacheAt',
    'cacheExp'
]

# Columns encoded as pandas categoricals (their integer codes become features).
catcolumns = [
    'type',
    'flight.inboundSegments.departure',
    'flight.inboundSegments.arrival',
    'flight.inboundSegments.origin.airportCode',
    'flight.inboundSegments.destination.airportCode',
    'flight.inboundSegments.airline.code',
    'flight.inboundMCX.code',
    'flight.outboundSegments.departure',
    'flight.outboundSegments.arrival',
    'flight.outboundSegments.origin.airportCode',
    'flight.outboundSegments.destination.airportCode',
    'flight.outboundSegments.airline.code',
    'flight.outboundMCX.code'
]

floatcolumns = [
    'input.price',
    'output.price'
]

timestampcolumns = [
    'timestamp',
    'cacheAt',
    'cacheExp'
]


def excessive_price(inprice, outprice, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc):
    # The cached (output) price drifted too far from the live (input) price,
    # in absolute or relative terms, in either direction.
    return (outprice - inprice > price_pos_abs
            or inprice - outprice > price_neg_abs
            or outprice > inprice * (1.0 + price_pos_perc)
            or outprice < inprice * (1.0 - price_neg_perc))


def equal_price(inprice, outprice):
    # Prices within 10 currency units of each other count as equal.
    return abs(inprice - outprice) < 10


def expected_value(row, price_pos_abs=200, price_neg_abs=100, price_pos_perc=0.05, price_neg_perc=0.05):
    """Derive the training label (0-7) for a row.

    Failed lookups get 0. Otherwise the label grows with the observed cache
    lifetime, shifted down by 1 when the cached price drifted excessively and
    up by 1 when it still matched the live price.
    """
    # Do not cache errors.
    success = row['success']
    if success == 0:
        return 0

    inprice, outprice = row['input.price'], row['output.price']
    tstamp, cacheAt, cacheExp = row['timestamp'], row['cacheAt'], row['cacheExp']

    if not pd.isnull(cacheAt):
        incachetime = tstamp - cacheAt
        expcachetime = cacheExp - cacheAt
    else:
        # Uncached row: both durations are unknown.
        incachetime = np.timedelta64('NaT')
        expcachetime = np.timedelta64('NaT')

    modifier = 0
    if excessive_price(inprice, outprice, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc):
        # Drifted price: penalize and only credit the time it was actually in cache.
        modifier = -1
        expcachetime = incachetime
    if equal_price(inprice, outprice):
        if pd.isnull(incachetime):
            return 4
        modifier = 1

    if pd.isnull(incachetime):
        return 2

    # Bucket the cache lifetime into the 1-7 label scale.
    if expcachetime <= np.timedelta64(12, 'h'):
        return 1 + modifier
    if expcachetime <= np.timedelta64(1, 'D'):
        return 2 + modifier
    if expcachetime <= np.timedelta64(2, 'D'):
        return 3 + modifier
    if expcachetime <= np.timedelta64(3, 'D'):
        return 4 + modifier
    if expcachetime <= np.timedelta64(7, 'D'):
        return 5 + modifier
    if expcachetime <= np.timedelta64(14, 'D'):
        return 6 + modifier

    return min(7, 7 + modifier)
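
# Worked example (hypothetical row): a successful quote whose cached and live
# prices still match (|diff| < 10, so modifier = +1), observed 6 h after
# caching with a 2-day TTL, falls in the "<= 2 days" bucket: label 3 + 1 = 4.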


def fromisoformat(s):
    # Segment timestamps are ISO-8601 without timezone, e.g. '2020-01-31T10:15:00'.
    return dt.datetime.strptime(s, '%Y-%m-%dT%H:%M:%S')


def compute_duration(row):
    # Trip duration in days: inbound departure minus outbound departure.
    # Multi-segment fields are '|'-separated; only the first segment is used.
    indeparture, outdeparture = row['flight.inboundSegments.departure'], row['flight.outboundSegments.departure']
    if pd.isna(indeparture):
        # One-way trip: no inbound leg.
        return 0
    indt = fromisoformat(indeparture.split('|')[0])
    outdt = fromisoformat(outdeparture.split('|')[0])
    return (indt - outdt).days


def compute_prebooking(row):
    # Days between the search and the outbound departure.
    tstamp, outdeparture = row['timestamp'], row['flight.outboundSegments.departure']
    outdt = fromisoformat(outdeparture.split('|')[0])
    return (outdt - tstamp).days


def compute_dom(outdeparture):
    # Day of month of the outbound departure.
    outdt = fromisoformat(outdeparture.split('|')[0])
    return outdt.day


def compute_dow(outdeparture):
    # Day of week (Monday=0) of the outbound departure.
    outdt = fromisoformat(outdeparture.split('|')[0])
    return outdt.weekday()
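
# Hypothetical example of a multi-leg segment field: a two-leg outbound
# departure would arrive as '2020-03-01T08:20:00|2020-03-01T11:45:00'; only
# the first leg's timestamp is parsed above.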


def preprocess_data(df):
    logging.info('Preprocessing start')

    # Map booleans to 0/1 so 'success' can be used numerically.
    booleanDictionary = {True: 1, False: 0}
    df.loc[:, 'success'] = df.loc[:, 'success'].replace(booleanDictionary)

    for ct in timestampcolumns:
        df.loc[:, ct] = pd.to_datetime(df[ct])

    # Encode categoricals and keep the integer codes as model features.
    for cc in catcolumns:
        df.loc[:, cc] = df.loc[:, cc].astype('category')
        df.loc[:, '%s_codes' % cc] = df[cc].cat.codes

    df.loc[:, floatcolumns] = df.loc[:, floatcolumns].astype('float64')

    # Derived features: trip duration, booking lead time, calendar positions.
    df.loc[:, 'duration'] = df.apply(compute_duration, axis=1)
    df.loc[:, 'prebooking'] = df.apply(compute_prebooking, axis=1)
    df.loc[:, 'order_dom'] = df.loc[:, 'timestamp'].apply(lambda x: x.day)
    df.loc[:, 'order_dow'] = df.loc[:, 'timestamp'].apply(lambda x: x.dayofweek)
    df.loc[:, 'flight_dom'] = df.loc[:, 'flight.outboundSegments.departure'].apply(compute_dom)
    df.loc[:, 'flight_dow'] = df.loc[:, 'flight.outboundSegments.departure'].apply(compute_dow)

    return df
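
# After preprocess_data, every categorical column has a numeric twin, e.g.
# 'type' -> 'type_codes' (codes are per-frame category indices); the raw
# columns are dropped again by remove_non_features below.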


def remove_non_features(df):
    # Drop the raw categoricals, the timestamps and the target-related columns;
    # the *_codes and derived numeric columns stay as features. The untouched
    # frame is returned alongside, for the encoder export at save time.
    return df.drop(catcolumns + timestampcolumns + ['output.price', 'success'], axis=1), df


def train_test_split(df, label, ratio):
    logging.info('Splitting dataset with ratio %f', ratio)

    # Random boolean mask: ~ratio of the rows land in the training split.
    msk = np.random.rand(len(df)) < ratio
    train_data = df[msk]
    test_data = df[~msk]
    train_label = label[msk]
    test_label = label[~msk]

    # drop=True so the old index is not re-added as a column (it would
    # otherwise leak into the feature matrix).
    train_data = train_data.reset_index(drop=True)
    test_data = test_data.reset_index(drop=True)
    train_label = train_label.reset_index(drop=True)
    test_label = test_label.reset_index(drop=True)

    return train_data, test_data, train_label, test_label
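
# Note: the mask is drawn fresh each run; seed NumPy first (e.g.
# np.random.seed(0)) if a reproducible split is needed.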


def get_csv_pandas(files_path):
    # Accept either a single CSV file or a directory holding one; in the
    # directory case only the first regular file is read.
    if os.path.isfile(files_path):
        csv_path = files_path
    else:
        csv_file = [f for f in os.listdir(files_path) if os.path.isfile(os.path.join(files_path, f))][0]
        csv_path = os.path.join(files_path, csv_file)

    try:
        logging.info('Loading csv file %s', csv_path)

        df = pd.read_csv(csv_path, header=None)
        df.columns = columns
        return df

    except Exception as e:
        raise exc.UserError("Failed to load csv data with exception:\n{}".format(e))


def get_pandas_df(data_path):
    if not os.path.exists(data_path):
        return None

    if os.path.isfile(data_path):
        files_path = data_path
    else:
        # Descend to the first leaf directory; channel data may be staged
        # under nested prefixes.
        files_path = data_path
        for root, dirs, files in os.walk(data_path):
            if dirs == []:
                files_path = root
                break
    df = get_csv_pandas(files_path)

    return df


def get_df(train_path, validate_path, content_type='text/csv'):
    train_files_size = get_size(train_path) if train_path else 0
    val_files_size = get_size(validate_path) if validate_path else 0

    logging.debug("File size to be processed in the node: {}mb.".format(
        round((train_files_size + val_files_size) / (1024 * 1024), 2)))

    if train_files_size > 0:
        validate_data_file_path(train_path, content_type)
    if val_files_size > 0:
        validate_data_file_path(validate_path, content_type)

    train_pandas = get_pandas_df(train_path) if train_files_size > 0 else None
    val_pandas = get_pandas_df(validate_path) if val_files_size > 0 else None

    return train_pandas, val_pandas


def get_dmatrices(train_pandas, train_label_pandas, val_pandas, val_label_pandas, ratio=0.8):
    # Test for None explicitly: 'if val_pandas:' raises on a DataFrame
    # ("truth value of a DataFrame is ambiguous").
    if val_pandas is not None:
        train_dmatrix = xgb.DMatrix(train_pandas, label=train_label_pandas.loc[:, 'label'])
        val_dmatrix = xgb.DMatrix(val_pandas, label=val_label_pandas.loc[:, 'label'])
    else:
        # No validation channel: carve a hold-out set out of the training data.
        train_data, test_data, train_label, test_label = train_test_split(train_pandas, train_label_pandas, ratio)
        train_dmatrix = xgb.DMatrix(train_data, label=train_label.loc[:, 'label'])
        val_dmatrix = xgb.DMatrix(test_data, label=test_label.loc[:, 'label'])

    return train_dmatrix, val_dmatrix


def save_encoders(encoder_location, df):
    logging.info('Saving encoders')

    # Persist the category -> integer-code mapping of every categorical column
    # so inference can apply the identical encoding.
    jsondata = {}
    for cc in catcolumns:
        jsondata[cc] = {cat: idx for idx, cat in enumerate(df[cc].cat.categories)}

    with open(encoder_location, 'w') as f:
        json.dump(jsondata, f)
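
# Resulting JSON shape (category values hypothetical):
#   {"type": {"oneway": 0, "return": 1},
#    "flight.inboundMCX.code": {"FRA": 0, "LHR": 1}, ...}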


def sagemaker_train(train_config, data_config, train_path, val_path, model_dir, sm_hosts, sm_current_host,
                    checkpoint_config):
    metrics = metrics_mod.initialize()
    hyperparameters = hpv.initialize(metrics)

    # Custom labeling thresholds arrive as 'bonitoo_*' hyperparameters.
    price_pos_abs = int(train_config.get('bonitoo_price_pos_abs', 200))
    price_neg_abs = int(train_config.get('bonitoo_price_neg_abs', 200))
    price_pos_perc = float(train_config.get('bonitoo_price_pos_perc', 0.05))
    price_neg_perc = float(train_config.get('bonitoo_price_neg_perc', 0.05))

    # Strip framework- and labeling-specific keys before XGBoost validation.
    train_config = {k: v.replace('"', '') for k, v in train_config.items()
                    if not k.startswith('sagemaker_') and not k.startswith('bonitoo_')}
    train_config = hyperparameters.validate(train_config)

    if train_config.get("updater"):
        train_config["updater"] = ",".join(train_config["updater"])

    logging.info("hyperparameters {}".format(train_config))
    logging.info("channels {}".format(data_config))

    # Get Training and Validation Data Matrices
    validation_channel = data_config.get('validation', None)
    checkpoint_dir = checkpoint_config.get("LocalPath", None)

    train_df, val_df = get_df(train_path, val_path)

    train_df = preprocess_data(train_df)
    train_label_df = train_df.apply(
        lambda x: expected_value(x, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc),
        axis=1).to_frame(name='label')
    train_df, train_df_orig = remove_non_features(train_df)
    val_label_df = None

    if val_df is not None:
        val_df = preprocess_data(val_df)
        val_label_df = val_df.apply(
            lambda x: expected_value(x, price_pos_abs, price_neg_abs, price_pos_perc, price_neg_perc),
            axis=1).to_frame(name='label')
        val_df, val_df_orig = remove_non_features(val_df)
    train_dmatrix, val_dmatrix = get_dmatrices(train_df, train_label_df, val_df, val_label_df)

    train_args = dict(
        train_cfg=train_config,
        train_dmatrix=train_dmatrix,
        train_df=train_df_orig,
        val_dmatrix=val_dmatrix,
        model_dir=model_dir,
        checkpoint_dir=checkpoint_dir)

    # Obtain information about training resources to determine whether to set up Rabit or not
    num_hosts = len(sm_hosts)

    if num_hosts > 1:
        # Wait for hosts to find each other
        logging.info("Distributed node training with {} hosts: {}".format(num_hosts, sm_hosts))
        distributed.wait_hostname_resolution(sm_hosts)

        if train_dmatrix is None:
            logging.warning("Host {} does not have data. Will broadcast to cluster and will not be used in distributed"
                            " training.".format(sm_current_host))
        distributed.rabit_run(exec_fun=train_job, args=train_args, include_in_training=(train_dmatrix is not None),
                              hosts=sm_hosts, current_host=sm_current_host, update_rabit_args=True)
    elif num_hosts == 1:
        if train_dmatrix:
            if validation_channel:
                if not val_dmatrix:
                    raise exc.UserError("No data in validation channel path {}".format(val_path))
            logging.info("Single node training.")
            train_args.update({'is_master': True})
            train_job(**train_args)
        else:
            raise exc.UserError("No data in training channel path {}".format(train_path))
    else:
        raise exc.PlatformError("Number of hosts should be an int greater than or equal to 1")
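
# Example training-config payload (values hypothetical): plain XGBoost keys go
# through hyperparameter validation, 'bonitoo_*' keys only tune the labeling:
#   {"num_round": "100", "max_depth": "6",
#    "bonitoo_price_pos_abs": "200", "bonitoo_price_neg_perc": "0.05"}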


def train_job(train_cfg, train_dmatrix, val_dmatrix, train_df, model_dir, checkpoint_dir, is_master):
    # Parse arguments for train() API
    early_stopping_rounds = train_cfg.get('early_stopping_rounds')
    num_round = int(train_cfg["num_round"])

    # Evaluation metrics to use with train() API
    tuning_objective_metric_param = train_cfg.get("_tuning_objective_metric")
    eval_metric = train_cfg.get("eval_metric")
    cleaned_eval_metric, configured_feval = train_utils.get_eval_metrics_and_feval(
        tuning_objective_metric_param, eval_metric)
    if cleaned_eval_metric:
        train_cfg['eval_metric'] = cleaned_eval_metric
    else:
        train_cfg.pop('eval_metric', None)

    # Set callback evals
    watchlist = [(train_dmatrix, 'train')]
    if val_dmatrix is not None:
        watchlist.append((val_dmatrix, 'validation'))

    # Resume from the latest checkpoint, if any, and train only the remaining rounds.
    xgb_model, iteration = checkpointing.load_checkpoint(checkpoint_dir)
    num_round -= iteration
    if xgb_model is not None:
        logging.info("Checkpoint loaded from %s", xgb_model)
        logging.info("Resuming from iteration %s", iteration)

    callbacks = []
    callbacks.append(checkpointing.print_checkpointed_evaluation(start_iteration=iteration))
    if checkpoint_dir:
        save_checkpoint = checkpointing.save_checkpoint(checkpoint_dir, start_iteration=iteration)
        callbacks.append(save_checkpoint)

    logging.info("Train matrix has {} rows".format(train_dmatrix.num_row()))
    if val_dmatrix is not None:
        logging.info("Validation matrix has {} rows".format(val_dmatrix.num_row()))

    try:
        bst = xgb.train(train_cfg, train_dmatrix, num_boost_round=num_round, evals=watchlist, feval=configured_feval,
                        early_stopping_rounds=early_stopping_rounds, callbacks=callbacks, xgb_model=xgb_model,
                        verbose_eval=False)
    except Exception as e:
        # Surface known user-caused failures as UserError; everything else is
        # treated as an algorithm/platform error.
        for customer_error_message in CUSTOMER_ERRORS:
            if customer_error_message in str(e):
                raise exc.UserError(str(e))

        exception_prefix = "XGB train call failed with exception"
        raise exc.AlgorithmError("{}:\n {}".format(exception_prefix, str(e)))

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    # Only the master host persists artifacts in distributed training.
    if is_master:
        encoder_location = model_dir + '/encoder.json'
        save_encoders(encoder_location, train_df)
        logging.info("Stored encoders at {}".format(encoder_location))

        model_location = model_dir + '/xgboost-model.bin'
        bst.save_model(model_location)
        logging.info("Stored trained model at {}".format(model_location))


if __name__ == '__main__':
    with open(os.getenv(sm_env_constants.SM_INPUT_TRAINING_CONFIG_FILE), "r") as f:
        train_config = json.load(f)
    with open(os.getenv(sm_env_constants.SM_INPUT_DATA_CONFIG_FILE), "r") as f:
        data_config = json.load(f)

    # The checkpoint config is optional; guard against a missing env var
    # (os.path.exists(None) would raise) and fall back to an empty dict.
    checkpoint_config_file = os.getenv(sm_env_constants.SM_CHECKPOINT_CONFIG_FILE)
    if checkpoint_config_file and os.path.exists(checkpoint_config_file):
        with open(checkpoint_config_file, "r") as f:
            checkpoint_config = json.load(f)
    else:
        checkpoint_config = {}

    train_path = os.environ['SM_CHANNEL_TRAINING']
    val_path = os.environ.get(sm_env_constants.SM_CHANNEL_VALIDATION)

    sm_hosts = json.loads(os.environ[sm_env_constants.SM_HOSTS])
    sm_current_host = os.environ[sm_env_constants.SM_CURRENT_HOST]

    model_dir = os.getenv(sm_env_constants.SM_MODEL_DIR)

    sagemaker_train(
        train_config=train_config, data_config=data_config,
        train_path=train_path, val_path=val_path, model_dir=model_dir,
        sm_hosts=sm_hosts, sm_current_host=sm_current_host,
        checkpoint_config=checkpoint_config
    )
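
# To exercise this entry point outside SageMaker (a sketch; variable names as
# referenced above via sm_env_constants), the environment must provide at least
# SM_INPUT_TRAINING_CONFIG_FILE, SM_INPUT_DATA_CONFIG_FILE, SM_CHANNEL_TRAINING,
# SM_HOSTS (a JSON list), SM_CURRENT_HOST and SM_MODEL_DIR;
# SM_CHANNEL_VALIDATION and SM_CHECKPOINT_CONFIG_FILE are optional.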