Java自学者论坛

 找回密码
 立即注册

手机号码,快捷登录

恭喜Java自学者论坛(https://www.javazxz.com)已经为数万Java学习者服务超过8年了!积累会员资料超过10000G+
成为本站VIP会员,下载本站10000G+会员资源,会员资料板块,购买链接:点击进入购买VIP会员

JAVA高级面试进阶训练营视频教程

Java架构师系统进阶VIP课程

分布式高可用全栈开发微服务教程Go语言视频零基础入门到精通Java架构师3期(课件+源码)
Java开发全终端实战租房项目视频教程SpringBoot2.X入门到高级使用教程大数据培训第六期全套视频教程深度学习(CNN RNN GAN)算法原理Java亿级流量电商系统视频教程
互联网架构师视频教程年薪50万Spark2.0从入门到精通年薪50万!人工智能学习路线教程年薪50万大数据入门到精通学习路线年薪50万机器学习入门到精通教程
仿小米商城类app和小程序视频教程深度学习数据分析基础到实战最新黑马javaEE2.1就业课程从 0到JVM实战高手教程MySQL入门到精通教程
查看: 940|回复: 0

梯度下降算法解决多元线性回归问题 c++实现

[复制链接]
  • TA的每日心情
    奋斗
    2024-4-6 11:05
  • 签到天数: 748 天

    [LV.9]以坛为家II

    2034

    主题

    2092

    帖子

    70万

    积分

    管理员

    Rank: 9Rank: 9Rank: 9

    积分
    705612
    发表于 2021-6-26 23:51:23 | 显示全部楼层 |阅读模式

    没有数据标准化的版本,效率非常低,而且训练结果并不好。

    #include <iostream>
    #define maxn 105
    #include <cstdio>
    #include <cmath>
    using namespace std;
    int n,m;  //n个特征,m个数据
    double theta[maxn];//参数集
    double temp[maxn];
    double data[maxn][maxn];//数据集
    double Y[maxn];//结果集
    double hx[maxn];
    const double eps=1e-9;
    double alpha=0.00001;
    double h(int x)//计算假设函数
    {
        double res=0;
        for(int i=0;i<=n;++i)
        {
            res+=theta*data[x];
        }
        return res;
    }
    double J_theta()//计算cost function
    {
        double res=0;
        for(int i=1;i<=m;++i)
        {
            res+=(h(i)-Y)*(h(i)-Y);
        }
        res=res/(2*m);
        return res;
    }
    double f(int x)//求偏导数
    {
        double res=0;
        for(int i=1;i<=m;++i)
        {
            res+=hx*data[x];
        }
        res/=m;
        return res;
    }
    void Gradient_Descent()//梯度下降
    {
        for(int i=1;i<=m;++i)
        {
            data[0]=1;
        }
        for(int i=0;i<=n;++i)
        {
            theta=1;//初始化
        }
        double now,nex;
        do
        {
            now=J_theta();
            for(int i=1;i<=m;++i)
            {
                hx=h(i)-Y;
            }
            for(int i=0;i<=n;++i)
            {
                temp=theta-alpha*f(i);
            }
            for(int i=0;i<=n;++i)
            {
                theta=temp;
            }
            nex=J_theta();
            //cout<<J_theta()<<endl;
        }while (abs(nex-now)>eps);
    }
    int main()
    {
        freopen("in.txt","r",stdin);
        cin>>n>>m;
        for(int i=1;i<=m;++i)
        {
            for(int j=1;j<=n;++j)
            {
                cin>>data[j];
            }
        }
        for(int i=1;i<=m;++i)
        {
            cin>>Y;
        }
        Gradient_Descent();
        for(int i=0;i<=n;++i)
        {
            printf("%.2lf\n",theta);
        }
        return 0;
    }
    

     

    下面是将数据归一化之后的版本,效率较高:

    #include <iostream>
    #define maxn 105
    
    #include <cmath>
    #include <algorithm>
    #include <cstdio>
    using namespace std;
    int n,m;  //n个特征,m个数据
    double theta[maxn];//参数集
    double temp[maxn];
    double data[maxn][maxn];//数据集
    double Y[maxn];//结果集
    double hx[maxn];
    const double eps=1e-9;
    double alpha=0.001;
    double ave[maxn];
    void Mean_Normaliazation()
    {
    
        for(int i=0;i<=n;++i)
        {
            double maxim=-1e9;
            double minum=1e9;
            double tmp=0;
            for(int j=1;j<=m;++j)
            {
                tmp+=data[j];
            }
            tmp/=m;
            double mb=0;
            for (int j=1;j<=m;++j)
            {
                mb+=(data[j]-tmp)*(data[j]-tmp);
            }
            mb/=m;
            mb=sqrt(mb);
            for(int j=1;j<=m;++j)
            {
                data[j]=(data[j]-tmp)/mb;
            }
        }
        double maxim=-1e9;
        /*double tmp=0;
        for(int i=1;i<=m;++i)
        {
            maxim=max(Y,maxim);
            tmp+=Y;
        }
        tmp/=m;
        for(int i=1;i<=m;++i)
        {
            Y=(Y-tmp)/maxim;
        }*/
    }
    double h(int x)//计算假设函数
    {
        double res=0;
        for(int i=0;i<=n;++i)
        {
            res+=theta*data[x];
        }
        return res;
    }
    double J_theta()//计算cost function
    {
        double res=0;
        for(int i=1;i<=m;++i)
        {
            res+=(h(i)-Y)*(h(i)-Y);
        }
        res=res/(2*m);
        return res;
    }
    double f(int x)//求偏导数
    {
        double res=0;
        for(int i=1;i<=m;++i)
        {
            res+=hx*data[x];
        }
        res/=m;
        return res;
    }
    void Gradient_Descent()//梯度下降
    {
        for(int i=1;i<=m;++i)
        {
            data[0]=1;
        }
        for(int i=0;i<=n;++i)
        {
            theta=1;//初始化
        }
        double now,nex;
        do
        {
            now=J_theta();
            for(int i=1;i<=m;++i)
            {
                hx=h(i)-Y;
            }
            for(int i=0;i<=n;++i)
            {
                temp=theta-alpha*f(i);
            }
            for(int i=0;i<=n;++i)
            {
                theta=temp;
            }
            nex=J_theta();
            //cout<<J_theta()<<endl;
        }while (abs(nex-now)>eps);
    }
    int main()
    {
        freopen("in.txt","r",stdin);
        cin>>n>>m;
        for(int i=1;i<=m;++i)
        {
            for(int j=1;j<=n;++j)
            {
                cin>>data[j];
            }
        }
        for(int i=1;i<=m;++i)
        {
            cin>>Y;
        }
        Mean_Normaliazation();
        Gradient_Descent();
        for(int i=0;i<=n;++i)
        {
            printf("%.2lf\n",theta);
        }
        return 0;
    }
    

     

     

    训练数据在这里:

    2 10
    100 4
    50 3
    100 4
    100 2
    50 2
    80 2
    75 3
    65 4
    90 3
    90 2
    9.3 4.8 8.9 6.5 4.2 6.2 7.4 6.0 7.6 9.3 4.8 8.9 6.5
    

     

    哎...今天够累的,签到来了1...
    回复

    使用道具 举报

    您需要登录后才可以回帖 登录 | 立即注册

    本版积分规则

    QQ|手机版|小黑屋|Java自学者论坛 ( 声明:本站文章及资料整理自互联网,用于Java自学者交流学习使用,对资料版权不负任何法律责任,若有侵权请及时联系客服屏蔽删除 )

    GMT+8, 2024-4-18 17:48 , Processed in 0.070436 second(s), 29 queries .

    Powered by Discuz! X3.4

    Copyright © 2001-2021, Tencent Cloud.

    快速回复 返回顶部 返回列表