2 回答

TA貢獻1843條經驗 獲得超7個贊
import sklearn
import pandas as pd
def tree_to_df(reg_tree, feature_names):
    """Flatten a fitted regression tree into a table of decision rules.

    The tree is walked depth-first, left branch before right branch. Each
    output row ends at one leaf and holds only the rules added since the
    previous leaf was reached, so rows after the first start at the point
    where the walk branched off rather than repeating the path from the root.

    Args:
        reg_tree: Fitted estimator exposing a ``tree_`` attribute with the
            arrays ``children_left``, ``children_right``, ``feature``,
            ``threshold`` and ``value``.
        feature_names: Sequence indexed by feature number, giving the label
            to print for each feature.

    Returns:
        A pandas DataFrame with one row per leaf. The integer-labelled
        columns hold the rule strings (``"<name> <= <threshold>"`` or
        ``"<name> > <threshold>"``), padded with missing values, and the
        ``Return`` column holds ``"return <value>"`` for that leaf.
    """
    tree_ = reg_tree.tree_

    def recurse(node, row, ret):
        left = tree_.children_left[node]
        right = tree_.children_right[node]
        # A split node always has two distinct children, while a leaf stores
        # the same sentinel in both slots. Comparing them needs only the
        # public arrays, not the private sklearn.tree._tree constants.
        if left != right:
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            # Add rule to row and search left branch
            row[-1].append(f"{name} <= {threshold}")
            recurse(left, row, ret)
            # Add rule to row and search right branch
            row[-1].append(f"{name} > {threshold}")
            recurse(right, row, ret)
        else:
            # Add output rule and start a new row
            label = tree_.value[node]
            ret.append("return " + str(label[0][0]))
            row.append([])

    # Initialize: one open (empty) row of rules, no leaf values yet
    rules = [[]]
    vals = []

    # Call recursive function starting at the root node
    recurse(0, rules, vals)

    # Convert to table; the walk always leaves one trailing empty row, which
    # dropna(how='all') removes before the leaf values are attached.
    df = pd.DataFrame(rules).dropna(how='all')
    df['Return'] = pd.Series(vals)
    return df
這將返回一個 pandas 數據框:
                     0                   1                   2                 3          Return
0   feature <= 20750.0   feature <= 7000.0   feature <= 1000.0  feature <= 300.0   return 1000.0
1      feature > 300.0                None                None              None   return 3000.0
2     feature > 1000.0   feature <= 2500.0                None              None   return 5000.0
3     feature > 2500.0   feature <= 4250.0                None              None   return 8000.0
4     feature > 4250.0   feature <= 5500.0                None              None   return 6500.0
5     feature > 5500.0                None                None              None   return 7000.0
6     feature > 7000.0  feature <= 13000.0   feature <= 8750.0              None  return 15000.0
7     feature > 8750.0  feature <= 10750.0                None              None  return 20000.0
8    feature > 10750.0                None                None              None  return 21000.0
9    feature > 13000.0  feature <= 16000.0  feature <= 14750.0              None  return 25000.0
10   feature > 14750.0                None                None              None  return 27000.0
11   feature > 16000.0                None                None              None  return 30000.0
12   feature > 20750.0  feature <= 27500.0                None              None  return 52000.0
13   feature > 27500.0                None                None              None  return 80000.0

TA貢獻1852條經驗 獲得超7個贊
如果您正在處理分類決策樹,您可以嘗試一下
import pandas as pd
# Sample input for tree_parser: a text rendering of a classification tree in
# which each '|' marks one level of depth and leaf lines contain "class".
# (Presumably the output of sklearn's export_text -- confirm against the
# producer.) The newline right after the opening quotes yields an empty first
# line, which tree_parser files under 'Column0'.
text="""
|--- Age <= 0.63
| |--- EstimatedSalary <= 0.61
| | |--- Age <= -0.16
| | | |--- class: 0
| | |--- Age > -0.16
| | | |--- EstimatedSalary <= -0.06
| | | | |--- class: 0
| | | |--- EstimatedSalary > -0.06
| | | | |--- EstimatedSalary <= 0.40
| | | | | |--- EstimatedSalary <= 0.03
| | | | | | |--- class: 1
"""
def tree_parser(text):
    """Turn a pipe-indented text tree into equal-length columns, one per depth.

    Every line's depth is the number of '|' characters in its leading
    tree-drawing prefix. The line's label is stored in the column for that
    depth, and shallower columns repeat their latest value so each leaf ends
    up on a row together with all of its ancestors.

    Args:
        text: Multi-line string such as ``"|--- Age <= 0.63\\n| |--- class: 0"``.
            Lines containing ``"class"`` are treated as leaves.

    Returns:
        A dict mapping ``'Column0'`` .. ``'Column<max depth>'`` to lists of
        equal length (suitable for ``pandas.DataFrame``). Cells deeper than a
        leaf are ``None``. Labels keep the whitespace left over from the
        prefix. Empty input returns ``{'Column0': []}``.
    """
    parsed = []
    for line in text.splitlines():
        # Separate the tree-drawing prefix ('|', '-', spaces) from the label so
        # that only the prefix is cleaned; a blanket replace('-', '') would
        # also delete the minus sign of negative thresholds in the label.
        label = line.lstrip('|- ')
        prefix = line[:len(line) - len(label)]
        level = prefix.count('|')
        cell = prefix.replace('|', '').replace('-', '') + label
        parsed.append((level, cell, 'class' in label))

    # default=0 keeps empty input from raising on max() of an empty sequence.
    max_levels = max((level for level, _, _ in parsed), default=0)
    result = {'Column' + str(i): [] for i in range(max_levels + 1)}

    for level, cell, is_leaf in parsed:
        current = result['Column' + str(level)]
        current.append(cell)
        for i in range(max_levels + 1):
            column = result['Column' + str(i)]
            if i > level and is_leaf:
                # Nothing exists below a leaf: pad the deeper columns.
                column.append(None)
            elif i < level and len(column) != len(current):
                # Carry the ancestor's value down to the new row. A column
                # that is still empty (input without a leading blank line) has
                # nothing to repeat, so it is padded with None instead.
                column.append(column[-1] if column else None)
    return result
# Parse the sample tree, discard the depth-0 placeholder column, and save the
# remaining columns as a CSV file without the index.
result = tree_parser(text)
df = pd.DataFrame(result).drop(columns=['Column0'])
df.to_csv('treeout1.csv', index=False)
添加回答
舉報