diff --git a/examples/clustering/fit.py b/examples/clustering/fit.py new file mode 100644 index 0000000..ee9cd74 --- /dev/null +++ b/examples/clustering/fit.py @@ -0,0 +1,24 @@ +from igel import Igel + +""" +The goal of igel is to use ML without writing code. Therefore, the right and simplest way to use igel is from terminal. +You can run ` igel fit -dp path_to_dataset -yml path_to_yaml_file`. + +Alternatively, you can write code if you want. This example below demonstrates how to use igel if you want to write code. +However, I suggest you try and use the igel CLI. Type igel -h in your terminal to know more. + +=============================================================================================================== + +This example fits a clustering model on the clustering dataset (../data/clustering-data/train.csv) + +- the model here is KMeans and its configuration is provided in the igel.yaml file +- You can tune the clustering by changing the arguments (e.g. n_clusters) in the igel.yaml file + +""" + +mock_fit_params = {'data_path': '../data/clustering-data/train.csv', + 'yaml_path': './igel.yaml', + 'cmd': 'fit'} + +Igel(**mock_fit_params) + diff --git a/examples/clustering/igel.yaml b/examples/clustering/igel.yaml new file mode 100644 index 0000000..4e8afb2 --- /dev/null +++ b/examples/clustering/igel.yaml @@ -0,0 +1,18 @@ + +dataset: + type: csv + +# model definition +model: + type: clustering + algorithm: KMeans + arguments: # if you don't provide these arguments, then default will be used + n_clusters: 3 + init: random + n_init: 10 + max_iter: 300 + tol: 0.0004 + random_state: 0 + +# target you want to predict +target: # you can keep this empty since you want to use clustering diff --git a/examples/clustering/predict.py b/examples/clustering/predict.py new file mode 100644 index 0000000..afa6d0f --- /dev/null +++ b/examples/clustering/predict.py @@ -0,0 +1,21 @@ +from igel import Igel + + +""" +The goal of igel is to use ML without writing 
code. Therefore, the right and simplest way to use igel is from terminal. +You can run ` igel predict -dp path_to_dataset`. + +Alternatively, you can write code if you want. This example below demonstrates how to use igel if you want to write code. +However, I suggest you try and use the igel CLI. Type igel -h in your terminal to know more. + + +=============================================================================================================== + +This example uses the pre-fitted machine learning model to generate predictions + +""" + +mock_pred_params = {'data_path': '../data/clustering-data/train.csv', + 'cmd': 'predict'} + +Igel(**mock_pred_params) diff --git a/examples/data/clustering-data/raw_data.csv b/examples/data/clustering-data/raw_data.csv new file mode 100644 index 0000000..47af889 --- /dev/null +++ b/examples/data/clustering-data/raw_data.csv @@ -0,0 +1,151 @@ +feature1,feature2,result +2.6050973193764344,1.2252955252992357,1 +0.5323772047314387,3.3133890933364265,0 +0.8023140038834188,4.38196181200038,0 +0.5285367979496572,4.497238576378021,0 +2.6185854824861314,0.3576979057562252,1 +1.5914154189103553,4.904977251840595,0 +1.742659685725724,5.038466712398533,0 +2.375333284481674,0.08918563778251964,1 +-2.1213336421139197,2.664474084183779,2 +1.720396175444295,5.251731915463681,0 +3.136885496073223,1.5659276346561328,1 +-0.3749456643799345,2.387874349972349,2 +-1.84562252599802,2.7192463541687237,2 +0.7214439876706683,4.084750176642797,0 +0.16117090506347276,4.535178455211277,0 +-1.9991271411810305,2.7128574147318485,2 +-1.4780415293644775,3.2093591012097695,2 +1.8706766024616561,0.7779740711499736,1 +-1.5933443020153832,2.7689868216322586,2 +2.03562611231913,0.31361691106733813,1 +0.6400398546585195,4.1240107466781195,0 +2.441162797039712,1.3094157369198025,1 +1.1328039293719456,3.8767394577975276,0 +1.048291864126934,5.030924080929878,0 +-1.2663715749955262,2.6299882764265896,2 +2.316905851698755,0.8118904943268128,1 
+2.362307206605918,1.3587669957212003,1 +1.209101298411725,3.535665484309778,0 +-2.5422462471126526,3.9501286920127825,2 +1.4815331952273267,0.6787536375657198,1 +-1.5948788635610658,3.48632794263447,2 +-1.8255620477045866,2.7989213964651194,2 +-1.1337400321674174,2.6846727129651513,2 +-1.7587020005449525,3.1586229982198537,2 +0.34987239852153196,4.692532505364345,0 +1.6854860232372213,1.6691709576413047,1 +2.989047001646163,1.3506859890756295,1 +1.7373444822434374,1.2358803074111866,1 +0.6591090317060133,4.122416744454821,0 +1.154453277133696,4.657073911544364,0 +-1.3273808404018403,1.5315858831197977,2 +-1.6814104977454665,2.0798803581681344,2 +0.34102757930362826,4.788485681527396,0 +1.878270565806136,0.21018801322892744,1 +2.1386042691191425,1.21517937838399,1 +2.4836828273842233,0.5721508632878634,1 +-1.1811346376868121,3.265256833161126,2 +2.1111473905402987,3.5766044901490077,0 +-1.1937124722059482,2.687522367638846,2 +1.4513142873092897,4.228108723299541,0 +1.8376907455720592,1.8222955241776078,1 +0.4408937677912238,4.831013190913958,0 +1.0804075675849252,4.792106845690246,0 +1.8484580310530043,0.523936254217558,1 +2.391414899939096,1.1013945780584922,1 +-1.4486507442392176,3.0339727794332605,2 +0.7208675097620585,3.7134712353871837,0 +3.0167385346730704,1.637921055655149,1 +-1.1819949309545112,3.568805376115622,2 +1.3408153596352634,4.368278782827096,0 +-2.3183732118894462,3.2230719508867254,2 +-0.5489478590965583,3.112928922677086,2 +-1.6823470990711413,2.9665823444675143,2 +-1.5354142201526257,3.1074581291106638,2 +1.0649831496733715,4.102896859344259,0 +-0.3972495378764793,2.8967536855028015,2 +1.039726124898305,4.504782009170741,0 +1.6246546789055267,1.8526961364874537,1 +-0.3002248293705443,4.63059662516857,0 +0.12313498323398875,5.279175025064285,0 +1.5459704208181453,3.686374417271564,0 +1.4425497620177938,1.3198451481387103,1 +2.528893505319752,0.8201586133925197,1 +0.3897083759894151,5.275597920273036,0 +1.5381461005406456,1.2384609190787503,1 
+0.8204938124828087,4.3318699985632625,0 +1.5656598641263204,4.213824909542215,0 +-1.9335861428240069,2.184670097431866,2 +-1.3837321687757655,3.222304178570848,2 +0.9621789643771677,4.517953262713599,0 +1.71810119110419,0.9135789390751125,1 +1.6535626893459574,0.5528887710628365,1 +0.45199359601294875,3.5937783588589025,0 +1.1982016949192078,4.470624491135523,0 +2.204386608535906,1.5608566082814521,1 +3.2468399088648487,1.3699034034331437,1 +2.5156969333232873,1.0570274864094475,1 +-1.7983347512302643,3.1259072844638354,2 +-2.049530696295383,3.5234549061744733,2 +2.3678832469464717,0.09663483213456303,1 +2.2434802870106925,0.3479632646458396,1 +0.9991493371972181,4.210195402435473,0 +1.3096387250800752,1.1173595105702052,1 +0.7746816050599052,4.915009862639604,0 +1.7079835915671953,0.8228463897741014,1 +1.9178454270746221,3.6299077968771667,0 +-2.004876513468294,2.7448913734834295,2 +-2.1049952291328617,3.3084813121881194,2 +1.3973138161771175,0.6668713575305824,1 +2.0211467187705625,1.7543350207626203,1 +1.6703094842102197,1.1672882555838455,1 +2.5299779248957583,0.9414392806305323,1 +-2.180167439089564,3.746947601142217,2 +2.006041259220162,0.5659245167568832,1 +1.5030758517186527,0.9237461995684241,1 +1.0537437913949532,4.492868587249477,0 +-1.726628527853343,3.102910205256882,2 +1.7233096151252982,4.201208195565489,0 +0.9246606526497161,4.509086578417576,0 +0.3936951581548168,4.754200570925484,0 +-1.313774647832855,3.2563362788482455,2 +0.7826066698425189,4.152635952160722,0 +1.8275012696742063,0.9064032394504652,1 +-1.2649585013048026,2.9620933048554536,2 +0.9815200889069053,5.196722574401307,0 +-2.4950439161451587,3.012271559730037,2 +1.009528689738079,4.455023276318281,0 +1.408488177976248,3.932704817245169,0 +-1.280033124547311,2.8598302918159395,2 +-1.8250610324471594,2.8915986131983495,2 +0.5408715039555542,4.014362495066182,0 +2.6492824176090286,1.0561349659003616,1 +0.522620896354874,4.329760025346459,0 +0.1693211547675193,4.197417187341405,0 
+1.8062512960867254,1.8624296868464296,1 +1.9212658359571877,1.2988918578361344,1 +-1.539067075421373,2.5488668067302784,2 +1.682890110408658,0.48444439060842964,1 +-2.2973025204442177,2.9495132584332886,2 +-1.4559274315615156,2.7582180527605753,2 +-1.3869417137002862,2.8688070665116436,2 +-1.0718145591756094,3.0764913689287736,2 +1.4088390665711537,1.0311890946125284,1 +-1.5859860357861904,2.5777931593346977,2 +-1.5821743418381418,3.4279686171910226,2 +-0.779661740976221,1.8828897488263565,2 +0.569696937524268,3.4406460262825513,0 +-1.8531083044154153,2.722405573739322,2 +1.5988564087107986,1.4561718039858633,1 +-1.840947793042695,2.677368702102729,2 +1.3567889411199918,4.364624835694804,0 +1.1774408991352696,3.961382281978233,0 +1.7334583200164326,-0.21403791617427648,1 +2.343562929740348,0.7935142821489394,1 +-0.9507382308303454,3.4576915573515334,2 +-2.2389344677131713,2.671222319652026,2 +-1.872928937114101,3.6860707884560218,2 +-1.8897027024536976,2.226200283635595,2 +2.2532708777637005,0.35113290557268395,1 +1.5551598477380955,0.12527811154913104,1 diff --git a/examples/data/clustering-data/train.csv b/examples/data/clustering-data/train.csv new file mode 100644 index 0000000..5bdfe98 --- /dev/null +++ b/examples/data/clustering-data/train.csv @@ -0,0 +1,151 @@ +feature1,feature2 +2.6050973193764344,1.2252955252992357 +0.5323772047314387,3.3133890933364265 +0.8023140038834188,4.38196181200038 +0.5285367979496572,4.497238576378021 +2.6185854824861314,0.3576979057562252 +1.5914154189103553,4.904977251840595 +1.742659685725724,5.038466712398533 +2.375333284481674,0.08918563778251964 +-2.1213336421139197,2.664474084183779 +1.720396175444295,5.251731915463681 +3.136885496073223,1.5659276346561328 +-0.3749456643799345,2.387874349972349 +-1.84562252599802,2.7192463541687237 +0.7214439876706683,4.084750176642797 +0.16117090506347276,4.535178455211277 +-1.9991271411810305,2.7128574147318485 +-1.4780415293644775,3.2093591012097695 
+1.8706766024616561,0.7779740711499736 +-1.5933443020153832,2.7689868216322586 +2.03562611231913,0.31361691106733813 +0.6400398546585195,4.1240107466781195 +2.441162797039712,1.3094157369198025 +1.1328039293719456,3.8767394577975276 +1.048291864126934,5.030924080929878 +-1.2663715749955262,2.6299882764265896 +2.316905851698755,0.8118904943268128 +2.362307206605918,1.3587669957212003 +1.209101298411725,3.535665484309778 +-2.5422462471126526,3.9501286920127825 +1.4815331952273267,0.6787536375657198 +-1.5948788635610658,3.48632794263447 +-1.8255620477045866,2.7989213964651194 +-1.1337400321674174,2.6846727129651513 +-1.7587020005449525,3.1586229982198537 +0.34987239852153196,4.692532505364345 +1.6854860232372213,1.6691709576413047 +2.989047001646163,1.3506859890756295 +1.7373444822434374,1.2358803074111866 +0.6591090317060133,4.122416744454821 +1.154453277133696,4.657073911544364 +-1.3273808404018403,1.5315858831197977 +-1.6814104977454665,2.0798803581681344 +0.34102757930362826,4.788485681527396 +1.878270565806136,0.21018801322892744 +2.1386042691191425,1.21517937838399 +2.4836828273842233,0.5721508632878634 +-1.1811346376868121,3.265256833161126 +2.1111473905402987,3.5766044901490077 +-1.1937124722059482,2.687522367638846 +1.4513142873092897,4.228108723299541 +1.8376907455720592,1.8222955241776078 +0.4408937677912238,4.831013190913958 +1.0804075675849252,4.792106845690246 +1.8484580310530043,0.523936254217558 +2.391414899939096,1.1013945780584922 +-1.4486507442392176,3.0339727794332605 +0.7208675097620585,3.7134712353871837 +3.0167385346730704,1.637921055655149 +-1.1819949309545112,3.568805376115622 +1.3408153596352634,4.368278782827096 +-2.3183732118894462,3.2230719508867254 +-0.5489478590965583,3.112928922677086 +-1.6823470990711413,2.9665823444675143 +-1.5354142201526257,3.1074581291106638 +1.0649831496733715,4.102896859344259 +-0.3972495378764793,2.8967536855028015 +1.039726124898305,4.504782009170741 +1.6246546789055267,1.8526961364874537 
+-0.3002248293705443,4.63059662516857 +0.12313498323398875,5.279175025064285 +1.5459704208181453,3.686374417271564 +1.4425497620177938,1.3198451481387103 +2.528893505319752,0.8201586133925197 +0.3897083759894151,5.275597920273036 +1.5381461005406456,1.2384609190787503 +0.8204938124828087,4.3318699985632625 +1.5656598641263204,4.213824909542215 +-1.9335861428240069,2.184670097431866 +-1.3837321687757655,3.222304178570848 +0.9621789643771677,4.517953262713599 +1.71810119110419,0.9135789390751125 +1.6535626893459574,0.5528887710628365 +0.45199359601294875,3.5937783588589025 +1.1982016949192078,4.470624491135523 +2.204386608535906,1.5608566082814521 +3.2468399088648487,1.3699034034331437 +2.5156969333232873,1.0570274864094475 +-1.7983347512302643,3.1259072844638354 +-2.049530696295383,3.5234549061744733 +2.3678832469464717,0.09663483213456303 +2.2434802870106925,0.3479632646458396 +0.9991493371972181,4.210195402435473 +1.3096387250800752,1.1173595105702052 +0.7746816050599052,4.915009862639604 +1.7079835915671953,0.8228463897741014 +1.9178454270746221,3.6299077968771667 +-2.004876513468294,2.7448913734834295 +-2.1049952291328617,3.3084813121881194 +1.3973138161771175,0.6668713575305824 +2.0211467187705625,1.7543350207626203 +1.6703094842102197,1.1672882555838455 +2.5299779248957583,0.9414392806305323 +-2.180167439089564,3.746947601142217 +2.006041259220162,0.5659245167568832 +1.5030758517186527,0.9237461995684241 +1.0537437913949532,4.492868587249477 +-1.726628527853343,3.102910205256882 +1.7233096151252982,4.201208195565489 +0.9246606526497161,4.509086578417576 +0.3936951581548168,4.754200570925484 +-1.313774647832855,3.2563362788482455 +0.7826066698425189,4.152635952160722 +1.8275012696742063,0.9064032394504652 +-1.2649585013048026,2.9620933048554536 +0.9815200889069053,5.196722574401307 +-2.4950439161451587,3.012271559730037 +1.009528689738079,4.455023276318281 +1.408488177976248,3.932704817245169 +-1.280033124547311,2.8598302918159395 
+-1.8250610324471594,2.8915986131983495 +0.5408715039555542,4.014362495066182 +2.6492824176090286,1.0561349659003616 +0.522620896354874,4.329760025346459 +0.1693211547675193,4.197417187341405 +1.8062512960867254,1.8624296868464296 +1.9212658359571877,1.2988918578361344 +-1.539067075421373,2.5488668067302784 +1.682890110408658,0.48444439060842964 +-2.2973025204442177,2.9495132584332886 +-1.4559274315615156,2.7582180527605753 +-1.3869417137002862,2.8688070665116436 +-1.0718145591756094,3.0764913689287736 +1.4088390665711537,1.0311890946125284 +-1.5859860357861904,2.5777931593346977 +-1.5821743418381418,3.4279686171910226 +-0.779661740976221,1.8828897488263565 +0.569696937524268,3.4406460262825513 +-1.8531083044154153,2.722405573739322 +1.5988564087107986,1.4561718039858633 +-1.840947793042695,2.677368702102729 +1.3567889411199918,4.364624835694804 +1.1774408991352696,3.961382281978233 +1.7334583200164326,-0.21403791617427648 +2.343562929740348,0.7935142821489394 +-0.9507382308303454,3.4576915573515334 +-2.2389344677131713,2.671222319652026 +-1.872928937114101,3.6860707884560218 +-1.8897027024536976,2.226200283635595 +2.2532708777637005,0.35113290557268395 +1.5551598477380955,0.12527811154913104 diff --git a/examples/indian-diabetes-example/model_results/description.json b/examples/indian-diabetes-example/model_results/description.json deleted file mode 100644 index 8cfa392..0000000 --- a/examples/indian-diabetes-example/model_results/description.json +++ /dev/null @@ -1,46 +0,0 @@ -{ - "model": "MLPClassifier", - "arguments": "default", - "type": "classification", - "algorithm": "NeuralNetwork", - "dataset_props": { - "type": "csv", - "split": { - "test_size": 0.2, - "shuffle": true - }, - "preprocess": { - "missing_values": "mean", - "scale": { - "method": "standard", - "target": "inputs" - } - } - }, - "model_props": { - "type": "classification", - "algorithm": "NeuralNetwork" - }, - "data_path": "../data/indian-diabetes/train-indians-diabetes.csv", - 
"train_data_shape": [ - 614, - 8 - ], - "test_data_shape": [ - 154, - 8 - ], - "train_data_size": 614, - "test_data_size": 154, - "results_path": "/home/nidhal/projects/igel/examples/indian-diabetes-example/model_results", - "model_path": "/home/nidhal/projects/igel/examples/indian-diabetes-example/model_results/model.sav", - "target": [ - "sick" - ], - "results_on_test_data": { - "accuracy_score": 0.7142857142857143, - "f1_score": 0.5217391304347826, - "precision_score": 0.6, - "recall_score": 0.46153846153846156 - } -} \ No newline at end of file diff --git a/examples/indian-diabetes-example/model_results/model.sav b/examples/indian-diabetes-example/model_results/model.sav deleted file mode 100644 index a1758c2..0000000 Binary files a/examples/indian-diabetes-example/model_results/model.sav and /dev/null differ diff --git a/igel/configs.py b/igel/configs.py index 2cbc853..1c81ad6 100644 --- a/igel/configs.py +++ b/igel/configs.py @@ -25,7 +25,7 @@ "type": "csv", "split": { - "test_size": 0.2, + "test_size": 0.1, "shuffle": True }, "preprocess": { diff --git a/igel/igel.py b/igel/igel.py index bb94fff..8428336 100644 --- a/igel/igel.py +++ b/igel/igel.py @@ -86,7 +86,7 @@ def __init__(self, **cli_args): with open(self.description_file, 'r') as f: dic = json.load(f) self.target: list = dic.get("target") # target to predict as a list - self.model_type: str = dic.get("type") # type of the model -> regression or classification + self.model_type: str = dic.get("type") # type of the model -> regression, classification or clustering self.dataset_props: dict = dic.get('dataset_props') # dataset props entered while fitting getattr(self, self.command)() @@ -183,8 +183,9 @@ def _process_data(self, target='fit'): read and return data as x and y @return: list of separate x and y """ - assert isinstance(self.target, list), "provide target(s) as a list in the yaml file" + if self.model_type != "clustering": + assert isinstance(self.target, list), "provide target(s) as a list in 
the yaml file" assert len(self.target) > 0, "please provide at least a target to predict" try: @@ -411,8 +412,8 @@ def fit(self, **kwargs): } if self.model_type == 'clustering': clustering_res = { - "cluster_centers": self.model.cluster_centers_, - "cluster_labels": self.model.labels_ + "cluster_centers": self.model.cluster_centers_.tolist(), + "cluster_labels": self.model.labels_.tolist() } fit_description['clustering_results'] = clustering_res @@ -475,11 +476,13 @@ def predict(self): y_pred = _reshape(y_pred) logger.info(f"predictions shape: {y_pred.shape} | shape len: {len(y_pred.shape)}") logger.info(f"predict on targets: {self.target}") + if not self.target: + self.target = ['result'] df_pred = pd.DataFrame.from_dict( {self.target[i]: y_pred[:, i] if len(y_pred.shape) > 1 else y_pred for i in range(len(self.target))}) logger.info(f"saving the predictions to {self.prediction_file}") - df_pred.to_csv(self.prediction_file) + df_pred.to_csv(self.prediction_file, index=False) except Exception as e: logger.exception(f"Error while preparing predictions: {e}")