kaggle 预测房价

article/2025/8/25 19:22:45

利用简单的线性模型,训练kaggle 房屋数据集:

import os
import random
import tarfile
import time
import zipfile
import pandas as pd
import requests
import torch
from torch import nn
from torch.utils import data
from matplotlib import pyplot as plt
from matplotlib_inline import backend_inlinedef plot(X, Y=None, xlabel=None, ylabel=None, legend=None, xlim=None,ylim=None, xscale='linear', yscale='linear',fmts=('-', 'm--', 'g-.', 'r:'), figsize=(3.5, 2.5), axes=None):if legend is None:legend = []set_figsize(figsize)axes = axes if axes else plt.gca()# 如果X有一个轴,输出Truedef has_one_axis(X):return (hasattr(X, "ndim") and X.ndim == 1 or isinstance(X, list)and not hasattr(X[0], "__len__"))if has_one_axis(X):X = [X]if Y is None:X, Y = [[]] * len(X), Xelif has_one_axis(Y):Y = [Y]if len(X) != len(Y):X = X * len(Y)axes.cla()for x, y, fmt in zip(X, Y, fmts):if len(x):axes.plot(x, y, fmt)else:axes.plot(y, fmt)set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend)def use_svg_display():backend_inline.set_matplotlib_formats('svg')def set_figsize(figsize=(3.5, 2.5)):use_svg_display()plt.rcParams['figure.figsize'] = figsizedef set_axes(axes, xlabel, ylabel, xlim, ylim, xscale, yscale, legend):axes.set_xlabel(xlabel)axes.set_ylabel(ylabel)axes.set_xscale(xscale)axes.set_yscale(yscale)axes.set_xlim(xlim)axes.set_ylim(ylim)if legend:axes.legend(legend)axes.grid()def download(name, cache_dir=os.path.join('..', 'data')):assert name in DATA_HUB, f"{name} 不存在于 {DATA_HUB}"url, sha1_hash = DATA_HUB[name]os.makedirs(cache_dir, exist_ok=True)fname = os.path.join(cache_dir, url.split('/')[-1])if os.path.exists(fname):sha1 = hashlib.sha1()with open(fname, 'rb') as f:while True:data = f.read(1048576)if not data:breaksha1.update(data)if sha1.hexdigest() == sha1_hash:return fname  # 命中缓存print(f'正在从{url}下载{fname}...')r = requests.get(url, stream=True, verify=True)with open(fname, 'wb') as f:f.write(r.content)return fnamedef download_extract(name, folder=None):fname = download(name)base_dir = os.path.dirname(fname)data_dir, ext = os.path.splitext(fname)if ext == '.zip':fp = zipfile.ZipFile(fname, 'r')elif ext in ('.tar', '.gz'):fp = tarfile.open(fname, 'r')else:assert False, '只有zip/tar文件可以被解压缩'fp.extractall(base_dir)return os.path.join(base_dir, folder) if folder else data_dirdef download_all():for name in DATA_HUB:download(name)DATA_HUB = dict()
DATA_URL = 'http://d2l-data.s3-accelerate.amazonaws.com/'DATA_HUB['kaggle_house_train'] = (DATA_URL + 'kaggle_house_pred_train.csv','585e9cc93e70b39160e7921475f9bcd7d31219ce')DATA_HUB['kaggle_house_test'] = (DATA_URL + 'kaggle_house_pred_test.csv','fa19780a7b011d9b009e8bff8e99922a8ee2eb90')train_data = pd.read_csv(download('kaggle_house_train'))
test_data = pd.read_csv(download('kaggle_house_test'))all_features = pd.concat((train_data.iloc[:,1:-1], test_data.iloc[:,1:])) #测试数据集没有标签numeric_features_index_name = all_features.dtypes[all_features.dtypes != 'object'].indexall_features[numeric_features_index_name] = all_features[numeric_features_index_name].apply(lambda x : (x-x.mean())/(x.std()))all_features[numeric_features_index_name] = all_features[numeric_features_index_name].fillna(0)all_features = pd.get_dummies(all_features, dummy_na=True)
bool_columns = all_features.select_dtypes(include=['bool']).columns
all_features[bool_columns] = all_features[bool_columns].astype(int)sample_count_train = train_data.shape[0]train_features = torch.tensor(all_features[:sample_count_train].values, dtype=torch.float32)
test_features = torch.tensor(all_features[sample_count_train:].values, dtype=torch.float32)
train_labels = torch.tensor(train_data.SalePrice.values.reshape(-1, 1), dtype=torch.float32)loss = nn.MSELoss()
in_features = train_features.shape[1]def get_net():net = nn.Sequential(nn.Linear(in_features,1))return netdef log_rmse(net, features, labels):clipped_preds = torch.clamp(net(features), 1, float('inf'))rmse = torch.sqrt(loss(torch.log(clipped_preds), torch.log(labels)))return rmse.item()def load_array(data_arrays, batch_size, is_train=True):dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)def train(net, train_features, train_labels, test_features, test_labels,num_epochs, learning_rate, weight_decay, batch_size):train_ls, test_ls = [], []train_iter = load_array((train_features, train_labels), batch_size)optimizer = torch.optim.Adam(net.parameters(), lr = learning_rate, weight_decay = weight_decay)for epoch in range(num_epochs):for X, y in train_iter:optimizer.zero_grad()l = loss(net(X), y)l.backward()optimizer.step()train_ls.append(log_rmse(net, train_features, train_labels))if test_labels is not None:test_ls.append(log_rmse(net, test_features, test_labels))return train_ls, test_lsdef get_k_fold_data(k, i, X, y):assert k > 1fold_size = X.shape[0] // kX_train, y_train = None, Nonefor j in range(k):idx = slice(j * fold_size, (j + 1) * fold_size)X_part, y_part = X[idx, :], y[idx]if j == i:X_valid, y_valid = X_part, y_partelif X_train is None:X_train, y_train = X_part, y_partelse:X_train = torch.cat([X_train, X_part], 0)y_train = torch.cat([y_train, y_part], 0)return X_train, y_train, X_valid, y_validdef k_fold(k, X_train, y_train, num_epochs, learning_rate, weight_decay, batch_size):train_l_sum, valid_l_sum = 0, 0for i in range(k):data = get_k_fold_data(k, i, X_train, y_train)# 检查训练集是否为空if data[0].shape[0] == 0:raise ValueError(f"折 {i+1} 的训练集为空!")net = get_net()train_ls, valid_ls = train(net, *data, num_epochs, learning_rate, weight_decay, batch_size)train_l_sum += train_ls[-1]valid_l_sum += valid_ls[-1]if i == 0:plot(list(range(1, num_epochs + 1)), [train_ls, valid_ls], xlabel='epoch', ylabel='rmse',xlim=[1, num_epochs], legend=['train','valid'], yscale='log')print(f'折{i + 1},训练log rmse{float(train_ls[-1]):f}, 'f'验证log rmse{float(valid_ls[-1]):f}')return train_l_sum / k, valid_l_sum / kk, num_epochs, lr, weight_decay, batch_size = 5, 100, 5, 0, 64
train_l, valid_l = k_fold(k, train_features, train_labels, num_epochs, lr, weight_decay, batch_size)
print(f'{k}-折验证:平均训练 log rmse: {float(train_l):f}, 平均验证log rmse: {float(valid_l):f}')

结果如下:

折1,训练log rmse0.170301, 验证log rmse0.156731
折2,训练log rmse0.162383, 验证log rmse0.191084
折3,训练log rmse0.164671, 验证log rmse0.168909
折4,训练log rmse0.167769, 验证log rmse0.154724
折5,训练log rmse0.162871, 验证log rmse0.182890
5-折验证:平均训练 log rmse: 0.165599, 平均验证log rmse: 0.17086

可以看到,随着训练,损失越来越小。


http://www.hkcw.cn/article/RtRUmplHYy.shtml

相关文章

ASP.NET Core SignalR的基本使用

文章目录 前言一、SignalR是什么?在 ASP.NET Core 中的关键特性:SignalR 工作原理简图: 二、使用步骤1.创建ASP.NET Core web Api 项目2.添加 SignalR 包3.创建 SignalR Hub4.配置服务与中间件5.创建控制器(模拟服务器向客户端发送消息)6.创建…

AI书签管理工具开发全记录(七):页面编写与接口对接

文章目录 AI书签管理工具开发全记录(七):页面编写与接口对接前言 📝1. 页面功能规划 📌2. 接口api编写 📡2.1 创建.env,设置环境变量2.2 增加axios拦截器2.3 创建接口 2. 页面编写 📄2.1 示例代…

“AI 编程三国杀” Google Jules, OpenAl Codex, Claude Code,人类开始沦为AI编程发展的瓶颈?

AI 编程三国杀:Google Jules, OpenAI Codex, Claude code “AI 编程三国杀”是一个形象的比喻,借指当前 AI 编程领域中几个主要参与者之间的激烈竞争与并存的局面。这其中,Google、OpenAI 以及 Anthropic (Claude 的开发者) 是重要的“国家”,而它们各自的 AI 编程工具则是…

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 文件事件处理部分)

分析客户端和服务端网络诵信交互实现 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 命令请求的执行过程案例分析介绍发送命令请求读取命令请求客户端状态的argv属性和argc属性命令执行器第…

第29次CCF计算机软件能力认证-3-LDAP

LDAP 刷新 时间限制: 10.0 秒 空间限制: 512 MiB 下载题目目录(样例文件) 题目背景 西西艾弗岛运营公司是一家负责维护和运营岛上基础设施的大型企业,拥有数千名员工。公司内有很多 IT 系统。 为了能够实现这些…

2025年- H63-Lc171--33.搜索旋转排序数组(2次二分查找,需二刷)--Java版

1.题目描述 2.思路 输入:旋转后的数组 nums,和一个整数 target 输出:target 在 nums 中的下标,如果不存在,返回 -1 限制:时间复杂度为 O(log n),所以不能用遍历,必须使用 二分查找…

HomeKit 基本理解

概括 HomeKit 将用户的家庭自动化信息存储在数据库中,该数据库由苹果的内置iOS家庭应用程序、支持HomeKit的应用程序和其他开发人员的应用程序共享。所有这些应用程序都使用HomeKit框架作为对等程序访问数据库. Home 只是相当于 HomeKit 的表现层,其他应用在实现 …

秒杀系统—5.第二版升级优化的技术文档三

大纲 8.秒杀系统的秒杀库存服务实现 9.秒杀系统的秒杀抢购服务实现 10.秒杀系统的秒杀下单服务实现 11.秒杀系统的页面渲染服务实现 12.秒杀系统的页面发布服务实现 8.秒杀系统的秒杀库存服务实现 (1)秒杀商品的库存在Redis中的结构 (2)库存分片并同步到Redis的实现 (3…

尚硅谷-尚庭公寓知识点

文章目录 尚庭公寓知识点1、转换器(Converter)2、全局异常3、定时任务1. 核心步骤(1) 启用定时任务(2) 创建定时任务 2. Scheduled 参数详解3. Cron 表达式语法4. 配置线程池(避免阻塞)5. 动态控制任务(高级用法)6. 注意事项 4、M…

字符串~~~

字符串~~ KMP例题1.无线传输2.删除字符串3.二叉树中的链表 AC自动机Manacher例题 扩展KMP字符串哈希 KMP (1) (2) (3) 经典例题 https://leetcode.cn/problems/find-the-index-of-the-first-occurre…

WEB3——简易NFT铸造平台之nft.storage

🧠 1. nft.storage 是什么? https://nft.storage 是 一个免费的去中心化存储平台,由 Filecoin 背后的 Protocol Labs 推出。 它的作用是: ✅ 接收用户上传的文件(图片、JSON 等) ✅ 把它们永久存储到 IPFS…

MCP架构全解析:从核心原理到企业级实践

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storms…

注销微软账户

若你需要注销微软账号,请点击下方超链接。 点击此处

华为OD机试真题——生成哈夫曼树(2025A卷:100分)Java/python/JavaScript/C/C++/GO六种最佳实现

2025 A卷 100分 题型 本文涵盖详细的问题分析、解题思路、代码实现、代码详解、测试用例以及综合分析; 并提供Java、python、JavaScript、C++、C语言、GO六种语言的最佳实现方式! 本文收录于专栏:《2025华为OD真题目录+全流程解析/备考攻略/经验分享》 华为OD机试真题《生成…

Python实现P-PSO优化算法优化BP神经网络分类模型项目实战

说明:这是一个机器学习实战项目(附带数据代码文档),如需数据代码文档可以直接到文章最后关注获取。 1.项目背景 随着人工智能技术的快速发展,神经网络在分类任务中展现了强大的性能。BP(Back Propagation&…

学习海康VisionMaster之表面缺陷滤波

一:进一步学习了 今天学习下VisionMaster中的表面缺陷滤波:简单、无纹理背景的表面缺陷检测,可以检测表面的异物,缺陷,划伤等 二:开始学习 1:什么表面缺陷滤波? 表面缺陷滤波的核心…

34.x64汇编写法(一)

免责声明:内容仅供学习参考,请合法利用知识,禁止进行违法犯罪活动! 本次游戏没法给 内容参考于:微尘网络安全 上一个内容:33.第二阶段x64游戏实战-InLineHook 首先打开 Visual Studio,然后创…

Java网络编程实战:TCP/UDP Socket通信详解与高并发服务器设计

🔍 开发者资源导航 🔍🏷️ 博客主页: 个人主页📚 专栏订阅: JavaEE全栈专栏 内容: socket(套接字)TCP和UDP差别UDP编程方法使用简单服务器实现 TCP编程方法Socket和ServerSocket之间的关系使用简…

算法:滑动窗口

1.长度最小的子数组 209. 长度最小的子数组 - 力扣(LeetCode) 运用滑动窗口(同向双指针)来解决,因为这些数字全是正整数,在left位置确定的下,right这个总sum会越大,所以我们先让num…

AI笔记 - 网络模型 - mobileNet

网络模型 mobileNet mobileNet V1网络结构深度可分离卷积空间可分![在这里插入图片描述](https://i-blog.csdnimg.cn/direct/aff06377feac40b787cfc882be7c6e5d.png) 参考 mobileNet V1 网络结构 MobileNetV1可以理解为VGG中的标准卷积层换成深度可分离卷积 可分离卷积主要有…