summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
Diffstat (limited to 'tesseract/src/training/common/sampleiterator.cpp')
-rw-r--r--tesseract/src/training/common/sampleiterator.cpp275
1 files changed, 275 insertions, 0 deletions
diff --git a/tesseract/src/training/common/sampleiterator.cpp b/tesseract/src/training/common/sampleiterator.cpp
new file mode 100644
index 00000000..e2047b2f
--- /dev/null
+++ b/tesseract/src/training/common/sampleiterator.cpp
@@ -0,0 +1,275 @@
+// Copyright 2011 Google Inc. All Rights Reserved.
+// Author: rays@google.com (Ray Smith)
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+// http://www.apache.org/licenses/LICENSE-2.0
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+///////////////////////////////////////////////////////////////////////
+
+#include "sampleiterator.h"
+
+#include "intfeaturemap.h"
+
+#include "indexmapbidi.h"
+#include "shapetable.h"
+#include "trainingsample.h"
+#include "trainingsampleset.h"
+
+namespace tesseract {
+
+// ================== SampleIterator Implementation =================
+
+SampleIterator::SampleIterator()
+ : charset_map_(nullptr),
+ shape_table_(nullptr),
+ sample_set_(nullptr),
+ randomize_(false),
+ owned_shape_table_(nullptr) {
+ num_shapes_ = 0;
+ Begin();
+}
+
+SampleIterator::~SampleIterator() {
+ Clear();
+}
+
+void SampleIterator::Clear() {
+ delete owned_shape_table_;
+ owned_shape_table_ = nullptr;
+}
+
+// See class comment for arguments.
+void SampleIterator::Init(const IndexMapBiDi* charset_map,
+ const ShapeTable* shape_table,
+ bool randomize,
+ TrainingSampleSet* sample_set) {
+ Clear();
+ charset_map_ = charset_map;
+ shape_table_ = shape_table;
+ sample_set_ = sample_set;
+ randomize_ = randomize;
+ if (shape_table_ == nullptr && charset_map_ != nullptr) {
+ // The caller wishes to iterate by class. The easiest way to do this
+ // is to create a dummy shape_table_ that we will own.
+ int num_fonts = sample_set_->NumFonts();
+ owned_shape_table_ = new ShapeTable(sample_set_->unicharset());
+ int charsetsize = sample_set_->unicharset().size();
+ for (int c = 0; c < charsetsize; ++c) {
+ // We always add a shape for each character to keep the index in sync
+ // with the unichar_id.
+ int shape_id = owned_shape_table_->AddShape(c, 0);
+ for (int f = 1; f < num_fonts; ++f) {
+ if (sample_set_->NumClassSamples(f, c, true) > 0) {
+ owned_shape_table_->AddToShape(shape_id, c, f);
+ }
+ }
+ }
+ shape_table_ = owned_shape_table_;
+ }
+ if (shape_table_ != nullptr) {
+ num_shapes_ = shape_table_->NumShapes();
+ } else {
+ num_shapes_ = randomize ? sample_set_->num_samples()
+ : sample_set_->num_raw_samples();
+ }
+ Begin();
+}
+
+// Iterator functions designed for use with a simple for loop:
+// for (it.Begin(); !it.AtEnd(); it.Next()) {
+// const TrainingSample& sample = it.GetSample();
+// }
+void SampleIterator::Begin() {
+ shape_index_ = -1;
+ shape_char_index_ = 0;
+ num_shape_chars_ = 0;
+ shape_font_index_ = 0;
+ num_shape_fonts_ = 0;
+ sample_index_ = 0;
+ num_samples_ = 0;
+ // Find the first indexable sample.
+ Next();
+}
+
+bool SampleIterator::AtEnd() const {
+ return shape_index_ >= num_shapes_;
+}
+
+const TrainingSample& SampleIterator::GetSample() const {
+ if (shape_table_ != nullptr) {
+ const UnicharAndFonts* shape_entry = GetShapeEntry();
+ int char_id = shape_entry->unichar_id;
+ int font_id = shape_entry->font_ids[shape_font_index_];
+ return *sample_set_->GetSample(font_id, char_id, sample_index_);
+ } else {
+ return *sample_set_->GetSample(shape_index_);
+ }
+}
+
+TrainingSample* SampleIterator::MutableSample() const {
+ if (shape_table_ != nullptr) {
+ const UnicharAndFonts* shape_entry = GetShapeEntry();
+ int char_id = shape_entry->unichar_id;
+ int font_id = shape_entry->font_ids[shape_font_index_];
+ return sample_set_->MutableSample(font_id, char_id, sample_index_);
+ } else {
+ return sample_set_->mutable_sample(shape_index_);
+ }
+}
+
+// Returns the total index (from the original set of samples) of the current
+// sample.
+int SampleIterator::GlobalSampleIndex() const {
+ if (shape_table_ != nullptr) {
+ const UnicharAndFonts* shape_entry = GetShapeEntry();
+ int char_id = shape_entry->unichar_id;
+ int font_id = shape_entry->font_ids[shape_font_index_];
+ return sample_set_->GlobalSampleIndex(font_id, char_id, sample_index_);
+ } else {
+ return shape_index_;
+ }
+}
+
+// Returns the index of the current sample in compact charset space, so
+// in a 2-class problem between x and y, the returned indices will all be
+// 0 or 1, and have nothing to do with the unichar_ids.
+// If the charset_map_ is nullptr, then this is equal to GetSparseClassID().
+int SampleIterator::GetCompactClassID() const {
+ return charset_map_ != nullptr ? charset_map_->SparseToCompact(shape_index_)
+ : GetSparseClassID();
+}
+// Returns the index of the current sample in sparse charset space, so
+// in a 2-class problem between x and y, the returned indices will all be
+// x or y, where x and y may be unichar_ids (no shape_table_) or shape_ids
+// with a shape_table_.
+int SampleIterator::GetSparseClassID() const {
+ return shape_table_ != nullptr ? shape_index_ : GetSample().class_id();
+}
+
+// Moves on to the next indexable sample. If the end is reached, leaves
+// the state such that AtEnd() is true.
+void SampleIterator::Next() {
+ if (shape_table_ != nullptr) {
+ // Next sample in this class/font combination.
+ ++sample_index_;
+ if (sample_index_ < num_samples_)
+ return;
+ // Next font in this class in this shape.
+ sample_index_ = 0;
+ do {
+ ++shape_font_index_;
+ if (shape_font_index_ >= num_shape_fonts_) {
+ // Next unichar in this shape.
+ shape_font_index_ = 0;
+ ++shape_char_index_;
+ if (shape_char_index_ >= num_shape_chars_) {
+ // Find the next shape that is mapped in the charset_map_.
+ shape_char_index_ = 0;
+ do {
+ ++shape_index_;
+ } while (shape_index_ < num_shapes_ &&
+ charset_map_ != nullptr &&
+ charset_map_->SparseToCompact(shape_index_) < 0);
+ if (shape_index_ >= num_shapes_)
+ return; // The end.
+ num_shape_chars_ = shape_table_->GetShape(shape_index_).size();
+ }
+ }
+ const UnicharAndFonts* shape_entry = GetShapeEntry();
+ num_shape_fonts_ = shape_entry->font_ids.size();
+ int char_id = shape_entry->unichar_id;
+ int font_id = shape_entry->font_ids[shape_font_index_];
+ num_samples_ = sample_set_->NumClassSamples(font_id, char_id, randomize_);
+ } while (num_samples_ == 0);
+ } else {
+ // We are just iterating over the samples.
+ ++shape_index_;
+ }
+}
+
+// Returns the size of the compact charset space.
+int SampleIterator::CompactCharsetSize() const {
+ return charset_map_ != nullptr ? charset_map_->CompactSize()
+ : SparseCharsetSize();
+}
+
+// Returns the size of the sparse charset space.
+int SampleIterator::SparseCharsetSize() const {
+ return charset_map_ != nullptr
+ ? charset_map_->SparseSize()
+ : (shape_table_ != nullptr ? shape_table_->NumShapes()
+ : sample_set_->charsetsize());
+}
+
+
+// Sets the mapped_features_ from the features using the provided
+// feature_map.
+static void MapFeatures(TrainingSample &s, const IntFeatureMap& feature_map) {
+ GenericVector<int> indexed_features;
+ feature_map.feature_space().IndexAndSortFeatures(s.features(), s.num_features(),
+ &indexed_features);
+ feature_map.MapIndexedFeatures(indexed_features, &s.mapped_features_);
+ s.features_are_indexed_ = false;
+ s.features_are_mapped_ = true;
+}
+
+// Apply the supplied feature_space/feature_map transform to all samples
+// accessed by this iterator.
+void SampleIterator::MapSampleFeatures(const IntFeatureMap& feature_map) {
+ for (Begin(); !AtEnd(); Next()) {
+ TrainingSample* sample = MutableSample();
+ MapFeatures(*sample, feature_map);
+ }
+}
+
+// Adjust the weights of all the samples to be uniform in the given charset.
+// Returns the number of samples in the iterator.
+int SampleIterator::UniformSamples() {
+ int num_good_samples = 0;
+ for (Begin(); !AtEnd(); Next()) {
+ TrainingSample* sample = MutableSample();
+ sample->set_weight(1.0);
+ ++num_good_samples;
+ }
+ NormalizeSamples();
+ return num_good_samples;
+}
+
+// Normalize the weights of all the samples in the charset_map so they sum
+// to 1. Returns the minimum assigned sample weight.
+double SampleIterator::NormalizeSamples() {
+ double total_weight = 0.0;
+ int sample_count = 0;
+ for (Begin(); !AtEnd(); Next()) {
+ const TrainingSample& sample = GetSample();
+ total_weight += sample.weight();
+ ++sample_count;
+ }
+ // Normalize samples.
+ double min_assigned_sample_weight = 1.0;
+ if (total_weight > 0.0) {
+ for (Begin(); !AtEnd(); Next()) {
+ TrainingSample* sample = MutableSample();
+ double weight = sample->weight() / total_weight;
+ if (weight < min_assigned_sample_weight)
+ min_assigned_sample_weight = weight;
+ sample->set_weight(weight);
+ }
+ }
+ return min_assigned_sample_weight;
+}
+
+// Helper returns the current UnicharAndFont shape_entry.
+const UnicharAndFonts* SampleIterator::GetShapeEntry() const {
+ const Shape& shape = shape_table_->GetShape(shape_index_);
+ return &shape[shape_char_index_];
+}
+
+} // namespace tesseract.