BlackCat_Tensors
A GPU-supported autograd and linear algebra library, designed for neural network construction
io.h
Go to the documentation of this file.
1 /* Project: BlackCat_Tensors
2  * Author: JosephJaspers
3  * Copyright 2018
4  *
5  * This Source Code Form is subject to the terms of the Mozilla Public
6  * License, v. 2.0. If a copy of the MPL was not distributed with this
7  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
8 
9 #ifndef BLACKCAT_IO_H_
10 #define BLACKCAT_IO_H_
11 
12 #include "common.h"
13 #include "tensors.h"
14 #include "string.h"
15 #include <algorithm>
16 #include <sstream>
17 #include <fstream>
18 #include <string>
19 #include <vector>
20 #include <assert.h>
21 
22 namespace bc {
23 namespace io {
24 
/**
 * Converts the text of a single CSV cell to a value of type T.
 *
 * The primary template is only declared. Just the explicit specializations
 * generated below are defined, so using from_string with any other T fails
 * at link time.
 *
 * The numeric specializations forward to the std::sto* family. They therefore
 * throw std::invalid_argument when no conversion can be performed and
 * std::out_of_range when the value does not fit. They stop at the first
 * character that cannot be part of the number, so "12abc" yields 12.
 */
template<class T>
static T from_string(const std::string& str);

/// Defines the explicit specialization from_string<dtype>. Its body returns
/// the expression passed as the variadic argument, which may refer to the
/// parameter `str`.
#define from_string_def(dtype, ...)\
template<>\
inline dtype from_string(const std::string& str) {\
	return __VA_ARGS__;\
}

from_string_def(double, std::stod(str))
from_string_def(float, std::stof(str))
from_string_def(int, std::stoi(str))
from_string_def(long, std::stol(str))
from_string_def(long long, std::stoll(str))
from_string_def(long double, std::stold(str))
from_string_def(std::string, str)
37 
/**
 * Minimal aggregate that stores two positions and hands them out through
 * begin()/end(). That is all a range-based for loop needs, so any pair of
 * iterators can be iterated directly.
 */
template<class T>
struct Range {
	T begin_;  // starting position, returned (by value) from begin()
	T end_;    // terminating position, returned (by value) from end()

	auto begin() const -> T { return this->begin_; }
	auto end() const -> T { return this->end_; }
};

/**
 * Wraps two positions in a Range.
 *
 * @param begin  starting position.
 * @param end    terminating position. It defaults to a value-initialized T().
 *               For std::istreambuf_iterator that is the end-of-stream
 *               iterator, so range(std::istreambuf_iterator<char>(in)) walks
 *               the whole stream.
 */
template<class T>
auto range(T begin, T end=T()) -> Range<T> {
	return {begin, end};
}
/**
 * Options describing how a delimited text file is read by parse() and
 * read_uniform().
 *
 * Every option is generated by FORWARDED_PARAM and gets two accessors:
 *  - `opt(value)` stores the value and returns *this, so calls can be chained;
 *  - `opt()` returns a const reference to the stored value.
 *
 *     csv_descriptor desc("data.csv");
 *     desc.header(false).delim(';').skip_cols(0);
 */
struct csv_descriptor {

/// Declares the member `name##_` (initialised to default_value), a chaining
/// setter `name(dtype)` and a const getter `name()`.
#define FORWARDED_PARAM(dtype, name, default_value) \
dtype name##_ = default_value; \
csv_descriptor& name(dtype name) { \
	name##_ = name; \
	return *this; \
} \
const dtype& name() const { \
	return name##_; \
} \

	csv_descriptor(std::string fname) : filename_(fname) {}

	FORWARDED_PARAM(std::string, filename, "")  // path opened by parse()
	FORWARDED_PARAM(bool, header, true)         // read_uniform drops the first parsed row when true
	FORWARDED_PARAM(bool, index, false)         // read_uniform drops the first parsed column when true
	FORWARDED_PARAM(char, mode, 'r')            // stored only; not read by parse()/read_uniform()
	FORWARDED_PARAM(char, delim, ',')           // separator between cells of a row
	FORWARDED_PARAM(char, row_delim, '\n')      // separator between rows
	FORWARDED_PARAM(bool, transpose, false)     // read_uniform currently rejects true (throws)

	// 0-based row positions (in split order, counting skipped rows) that
	// parse() drops, and 0-based column indices it removes from kept rows.
	FORWARDED_PARAM(std::vector<int>, skip_rows, {})
	FORWARDED_PARAM(std::vector<int>, skip_cols, {})

	/// Variadic form, e.g. skip_rows(0, 3, 7); it replaces the current list.
	/// Extra arguments go through static_cast<int>. That lets non-int
	/// integers (e.g. size_t) be passed, which a braced vector<int>{...}
	/// would reject as a narrowing conversion.
	template<class... Integers>
	csv_descriptor& skip_rows(int x, Integers... args_) {
		skip_rows_ = std::vector<int> {x, static_cast<int>(args_)...};
		return *this;
	}

	/// Variadic form, e.g. skip_cols(0, 2); same conversion rules as skip_rows.
	template<class... Integers>
	csv_descriptor& skip_cols(int x, Integers... args_) {
		skip_cols_ = std::vector<int> {x, static_cast<int>(args_)...};
		return *this;
	}

};
88 
89 static std::vector<std::vector<bc::string>> parse(csv_descriptor desc)
90 {
91  using bc::string;
92  using std::vector;
93 
94  std::ifstream ifs(desc.filename());
95 
96  auto find = [](auto& collection, auto var) -> bool {
97  return std::find(
98  collection.begin(),
99  collection.end(),
100  var) != collection.end();
101  };
102 
103  if (!ifs.good()) {
104  bc::print("Unable to open `", desc.filename(), '`');
105  throw 1;
106  }
107 
108  string csv_text = string(
109  std::istreambuf_iterator<char>(ifs),
110  std::istreambuf_iterator<char>());
111 
112  vector<string> rows = csv_text.split(desc.row_delim());
113  vector<vector<string>> split_rows;
114 
115  int curr_col = 0;
116 
117  for (string& row : rows) {
118  auto cells = row.split(desc.delim());
119 
120  if (!split_rows.empty() &&
121  split_rows.back().size() != cells.size()) {
122  bc::printerr("Column length mismatch."
123  "\nExpected: ", split_rows.back().size(),
124  "\nReceived: ", cells.size(),
125  "\nRow index: ", split_rows.size());
126  throw 1;
127  }
128 
129  if (!cells.empty() && !find(desc.skip_rows(), curr_col)) {
130  if (desc.skip_cols().empty())
131  split_rows.push_back(cells);
132  else {
133  vector<string> curr_row;
134  for (std::size_t i = 0; i < cells.size(); ++i) {
135  if (!find(desc.skip_cols(), i))
136  curr_row.push_back(std::move(cells[i]));
137  }
138  split_rows.push_back(curr_row);
139  }
140  }
141  ++curr_col;
142  }
143  return split_rows;
144 }
145 
146 template<
147  class ValueType,
149 static bc::Matrix<ValueType, Allocator> read_uniform(
150  csv_descriptor desc,
151  Allocator alloc=Allocator()) {
152 
153  using bc::string;
154  using std::vector;
155 
156  if (desc.transpose()){
157  bc::print("Transpose is not supported for read_uniform");
158  bc::print("TODO implement transposition");
159  throw 1;
160  }
161 
162  vector<vector<string>> data = parse(desc);
163 
164  int rows = data.size() - desc.header();
165  int cols = data[0].size() - desc.index();
166 
167  if (desc.transpose())
168  std::swap(rows, cols);
169 
170  bc::Matrix<ValueType, Allocator> matrix(rows, cols);
171  for (int i = 0; i < rows; ++i) {
172  for (int j = 0; j < cols; ++j) {
173  int d_i = i + desc.header();
174  int d_j = j + desc.index();
175 
176  if (desc.transpose())
177  matrix[i][j] = from_string<ValueType>(data[d_i][d_j]);
178  else
179  matrix[j][i] = from_string<ValueType>(data[d_i][d_j]);
180  }
181  }
182 
183  return matrix;
184 }
185 
186 }
187 }
188 
189 
190 #endif /* BLACKCAT_IO_H_ */
#define from_string_def(dtype,...)
Definition: io.h:28
T begin_
Definition: io.h:40
const std::vector< int > & skip_rows() const
Definition: io.h:72
csv_descriptor & skip_cols(int x, Integers... args_)
Definition: io.h:82
csv_descriptor & filename(std::string filename)
Definition: io.h:64
csv_descriptor & skip_rows(int x, Integers... args_)
Definition: io.h:76
const bool & transpose() const
Definition: io.h:70
T end() const
Definition: io.h:43
csv_descriptor & index(bool index)
Definition: io.h:66
csv_descriptor & row_delim(char row_delim)
Definition: io.h:69
const char & delim() const
Definition: io.h:68
std::vector< bc::string > split(char delim) const
Definition: string.h:60
class::::::Args static auto swap(bc::streams::Stream< bc::host_tag > stream, Begin begin, End end, Args... args)
Definition: algorithms.h:140
Definition: io.h:50
const char & mode() const
Definition: io.h:67
csv_descriptor(std::string fname)
Definition: io.h:62
T end_
Definition: io.h:41
csv_descriptor & header(bool header)
Definition: io.h:65
class::::::Args static auto find(bc::streams::Stream< bc::host_tag > stream, Begin begin, End end, Args... args)
Definition: algorithms.h:124
std::string filename_
Definition: io.h:64
csv_descriptor & skip_rows(std::vector< int > skip_rows)
Definition: io.h:72
const bool & index() const
Definition: io.h:66
std::vector< int > skip_cols_
Definition: io.h:73
int size_t
Definition: common.h:283
Definition: allocators.h:20
const bool & header() const
Definition: io.h:65
#define FORWARDED_PARAM(dtype, name, default_value)
Definition: io.h:52
std::vector< int > skip_rows_
Definition: io.h:72
const char & row_delim() const
Definition: io.h:69
auto range(T begin, T end=T())
Definition: io.h:46
void print(const Ts &... args)
Definition: common.h:165
void printerr(const Ts &... args)
Definition: common.h:170
Inherits from std::string.
Definition: string.h:21
const std::vector< int > & skip_cols() const
Definition: io.h:73
csv_descriptor & skip_cols(std::vector< int > skip_cols)
Definition: io.h:73
Definition: common.h:25
const std::string & filename() const
Definition: io.h:64
csv_descriptor & transpose(bool transpose)
Definition: io.h:70
csv_descriptor & delim(char delim)
Definition: io.h:68
T begin() const
Definition: io.h:42
The Evaluator determines if an expression needs to be greedily optimized.
Definition: algorithms.h:22
Definition: io.h:39