You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
36 lines
1.1 KiB
36 lines
1.1 KiB
"""Launch a SageMaker training job using the managed XGBoost framework container.

Running this module builds an ``XGBoost`` estimator around the user script
``./src/train_model.py`` and immediately starts a training job. The job uses
one ``ml.c5.xlarge`` instance and reads the CSV export from S3 through the
``training`` input channel.
"""
import sagemaker
import boto3
from sagemaker import get_execution_role
from sagemaker.xgboost.estimator import XGBoost

# Credentials come from the local AWS CLI profile "bonitoo"; all API calls
# target the eu-central-1 region.
boto_session = boto3.Session(profile_name='bonitoo', region_name='eu-central-1')

# To run the job in SageMaker local mode instead, swap in:
# sagemaker_session = sagemaker.LocalSession(boto_session=boto_session)
sagemaker_session = sagemaker.Session(boto_session=boto_session)

# IAM role *name* (not a full ARN) that the training job assumes.
# NOTE(review): presumably the SDK resolves the name to an ARN via the
# session; confirm the role exists in this account.
role = 'Bonitoo_SageMaker_Execution'

# Training data, handed to the entry point as the "training" channel below.
train_input = 's3://customers-bonitoo-cachettl/sagemaker/data/export.csv'

# Named `estimator` rather than `tf`: this is an XGBoost estimator, not
# TensorFlow, and the old name was misleading.
estimator = XGBoost(
    entry_point='train_model.py',
    source_dir='./src',
    # NOTE: these are SageMaker Python SDK v1 argument names; SDK v2 renames
    # them to instance_type / instance_count.
    train_instance_type='ml.c5.xlarge',
    train_instance_count=1,
    role=role,
    sagemaker_session=sagemaker_session,
    framework_version='0.90-1',
    py_version='py3',
    hyperparameters={
        # Custom (non-XGBoost) settings; presumably parsed by
        # train_model.py, which is not shown here. TODO confirm their meaning.
        'bonitoo_price_pos_abs': 1000,
        'bonitoo_price_neg_abs': 200,
        'bonitoo_price_pos_perc': 0.05,
        'bonitoo_price_neg_perc': 0.05,
        # Standard XGBoost booster / learning-task parameters.
        'num_round': 10,
        'max_depth': 15,
        'eta': 0.5,
        # multi:softprob needs num_class and outputs one probability per class.
        'num_class': 8,
        'objective': 'multi:softprob',
        'eval_metric': 'mlogloss',
    },
)

# Start the remote training job; the dict key is the input channel name.
estimator.fit({'training': train_input})