linear_regression.hpp 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. /*
  2. * Copyright Nick Thompson, 2019
  3. * Use, modification and distribution are subject to the
  4. * Boost Software License, Version 1.0. (See accompanying file
  5. * LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
  6. */
  7. #ifndef BOOST_MATH_STATISTICS_LINEAR_REGRESSION_HPP
  8. #define BOOST_MATH_STATISTICS_LINEAR_REGRESSION_HPP
  #include <algorithm>
  #include <cmath>
  #include <stdexcept>
  #include <tuple>
  #include <type_traits>
  #include <utility>
  #include <boost/math/statistics/univariate_statistics.hpp>
  #include <boost/math/statistics/bivariate_statistics.hpp>
  14. namespace boost::math::statistics {
  15. template<class RandomAccessContainer>
  16. auto simple_ordinary_least_squares(RandomAccessContainer const & x,
  17. RandomAccessContainer const & y)
  18. {
  19. using Real = typename RandomAccessContainer::value_type;
  20. if (x.size() <= 1)
  21. {
  22. throw std::domain_error("At least 2 samples are required to perform a linear regression.");
  23. }
  24. if (x.size() != y.size())
  25. {
  26. throw std::domain_error("The same number of samples must be in the independent and dependent variable.");
  27. }
  28. auto [mu_x, mu_y, cov_xy] = boost::math::statistics::means_and_covariance(x, y);
  29. auto var_x = boost::math::statistics::variance(x);
  30. if (var_x <= 0) {
  31. throw std::domain_error("Independent variable has no variance; this breaks linear regression.");
  32. }
  33. Real c1 = cov_xy/var_x;
  34. Real c0 = mu_y - c1*mu_x;
  35. return std::make_pair(c0, c1);
  36. }
  37. template<class RandomAccessContainer>
  38. auto simple_ordinary_least_squares_with_R_squared(RandomAccessContainer const & x,
  39. RandomAccessContainer const & y)
  40. {
  41. using Real = typename RandomAccessContainer::value_type;
  42. if (x.size() <= 1)
  43. {
  44. throw std::domain_error("At least 2 samples are required to perform a linear regression.");
  45. }
  46. if (x.size() != y.size())
  47. {
  48. throw std::domain_error("The same number of samples must be in the independent and dependent variable.");
  49. }
  50. auto [mu_x, mu_y, cov_xy] = boost::math::statistics::means_and_covariance(x, y);
  51. auto var_x = boost::math::statistics::variance(x);
  52. if (var_x <= 0) {
  53. throw std::domain_error("Independent variable has no variance; this breaks linear regression.");
  54. }
  55. Real c1 = cov_xy/var_x;
  56. Real c0 = mu_y - c1*mu_x;
  57. Real squared_residuals = 0;
  58. Real squared_mean_deviation = 0;
  59. for(decltype(y.size()) i = 0; i < y.size(); ++i) {
  60. squared_mean_deviation += (y[i] - mu_y)*(y[i]-mu_y);
  61. Real ei = (c0 + c1*x[i]) - y[i];
  62. squared_residuals += ei*ei;
  63. }
  64. Real Rsquared;
  65. if (squared_mean_deviation == 0) {
  66. // Then y = constant, so the linear regression is perfect.
  67. Rsquared = 1;
  68. } else {
  69. Rsquared = 1 - squared_residuals/squared_mean_deviation;
  70. }
  71. return std::make_tuple(c0, c1, Rsquared);
  72. }
  73. }
  74. #endif