Make classify a function on the Classifier

This commit is contained in:
Joshua Peek
2012-07-23 13:47:15 -05:00
parent b9779e805e
commit bf944f6d1a
5 changed files with 85 additions and 79 deletions

View File

@@ -10,7 +10,8 @@ end
file 'lib/linguist/samples.yml' => Dir['samples/**/*'] do |f| file 'lib/linguist/samples.yml' => Dir['samples/**/*'] do |f|
require 'linguist/samples' require 'linguist/samples'
File.open(f.name, 'w') { |io| Linguist::Samples.serialize_to_yaml(Linguist::Samples::DATA, io) } yaml = Linguist::Samples.serialize_to_yaml(Linguist::Samples.data)
File.open(f.name, 'w') { |io| io.write yaml }
end end
CLOBBER.include 'lib/linguist/samples.yml' CLOBBER.include 'lib/linguist/samples.yml'
@@ -31,7 +32,7 @@ namespace :classifier do
next if file_language.nil? || file_language == 'Text' next if file_language.nil? || file_language == 'Text'
begin begin
data = open(file_url).read data = open(file_url).read
guessed_language, score = Linguist::Classifier.new(Samples::DATA).classify(data).first guessed_language, score = Linguist::Classifier.classify(Samples::DATA, data).first
total += 1 total += 1
guessed_language == file_language ? correct += 1 : incorrect += 1 guessed_language == file_language ? correct += 1 : incorrect += 1

View File

@@ -442,7 +442,7 @@ module Linguist
if Language.ambiguous?(extname) if Language.ambiguous?(extname)
possible_languages = Language.all.select { |l| l.extensions.include?(extname) }.map(&:name) possible_languages = Language.all.select { |l| l.extensions.include?(extname) }.map(&:name)
if possible_languages.any? if possible_languages.any?
if result = Classifier.new(Samples::DATA).classify(data, possible_languages).first if result = Classifier.classify(Samples::DATA, data, possible_languages).first
Language[result[0]] Language[result[0]]
end end
end end

View File

@@ -3,56 +3,76 @@ require 'linguist/tokenizer'
module Linguist module Linguist
# Language bayesian classifier. # Language bayesian classifier.
class Classifier class Classifier
# Public: Initialize a Classifier.
def initialize(attrs = {})
@tokens_total = attrs['tokens_total'] || 0
@languages_total = attrs['languages_total'] || 0
@tokens = attrs['tokens'] || {}
@language_tokens = attrs['language_tokens'] || {}
@languages = attrs['languages'] || {}
end
# Public: Train classifier that data is a certain language. # Public: Train classifier that data is a certain language.
# #
# db - Hash classifier database object
# language - String language of data # language - String language of data
# data - String contents of file # data - String contents of file
# #
# Examples # Examples
# #
# train('Ruby', "def hello; end") # Classifier.train(db, 'Ruby', "def hello; end")
# #
# Returns nothing. # Returns nothing.
def train(language, data) def self.train!(db, language, data)
tokens = Tokenizer.tokenize(data) tokens = Tokenizer.tokenize(data)
db['tokens_total'] ||= 0
db['languages_total'] ||= 0
db['tokens'] ||= {}
db['language_tokens'] ||= {}
db['languages'] ||= {}
tokens.each do |token| tokens.each do |token|
@tokens[language] ||= {} db['tokens'][language] ||= {}
@tokens[language][token] ||= 0 db['tokens'][language][token] ||= 0
@tokens[language][token] += 1 db['tokens'][language][token] += 1
@language_tokens[language] ||= 0 db['language_tokens'][language] ||= 0
@language_tokens[language] += 1 db['language_tokens'][language] += 1
@tokens_total += 1 db['tokens_total'] += 1
end end
@languages[language] ||= 0 db['languages'][language] ||= 0
@languages[language] += 1 db['languages'][language] += 1
@languages_total += 1 db['languages_total'] += 1
nil nil
end end
# Public: Guess language of data. # Public: Guess language of data.
# #
# db - Hash of classifer tokens database.
# data - Array of tokens or String data to analyze. # data - Array of tokens or String data to analyze.
# languages - Array of language name Strings to restrict to. # languages - Array of language name Strings to restrict to.
# #
# Examples # Examples
# #
# classify("def hello; end") # Classifier.classify(db, "def hello; end")
# # => [ 'Ruby', 0.90], ['Python', 0.2], ... ] # # => [ 'Ruby', 0.90], ['Python', 0.2], ... ]
# #
# Returns sorted Array of result pairs. Each pair contains the # Returns sorted Array of result pairs. Each pair contains the
# String language name and a Float score. # String language name and a Float score.
def classify(tokens, languages = @languages.keys) def self.classify(db, tokens, languages = nil)
languages ||= db['languages'].keys
new(db).classify(tokens, languages)
end
# Internal: Initialize a Classifier.
def initialize(db = {})
@tokens_total = db['tokens_total']
@languages_total = db['languages_total']
@tokens = db['tokens']
@language_tokens = db['language_tokens']
@languages = db['languages']
end
# Internal: Guess language of data
#
# data - Array of tokens or String data to analyze.
# languages - Array of language name Strings to restrict to.
#
# Returns sorted Array of result pairs. Each pair contains the
# String language name and a Float score.
def classify(tokens, languages)
return [] if tokens.nil? return [] if tokens.nil?
tokens = Tokenizer.tokenize(tokens) if tokens.is_a?(String) tokens = Tokenizer.tokenize(tokens) if tokens.is_a?(String)
@@ -99,18 +119,5 @@ module Linguist
def language_probability(language) def language_probability(language)
Math.log(@languages[language].to_f / @languages_total.to_f) Math.log(@languages[language].to_f / @languages_total.to_f)
end end
# Public: Returns serializable hash representation.
#
# Returns Hash.
def to_hash
{
'tokens_total' => @tokens_total,
'languages_total' => @languages_total,
'tokens' => @tokens,
'language_tokens' => @language_tokens,
'languages' => @languages
}
end
end end
end end

View File

@@ -22,7 +22,7 @@ module Linguist
# #
# Returns Boolean. # Returns Boolean.
def self.outdated? def self.outdated?
MD5.hexdigest(DATA) != MD5.hexdigest(classifier.to_hash) MD5.hexdigest(DATA) != MD5.hexdigest(data)
end end
# Public: Iterate over each sample. # Public: Iterate over each sample.
@@ -98,52 +98,50 @@ module Linguist
# Public: Build Classifier from all samples. # Public: Build Classifier from all samples.
# #
# Returns trained Classifier. # Returns trained Classifier.
def self.classifier def self.data
require 'linguist/classifier' require 'linguist/classifier'
require 'linguist/language' require 'linguist/language'
classifier = Classifier.new db = {}
each { |sample| each do |sample|
language = Language.find_by_alias(sample[:language]) language = Language.find_by_alias(sample[:language])
data = File.read(sample[:path]) data = File.read(sample[:path])
classifier.train(language.name, data) Classifier.train!(db, language.name, data)
} end
classifier db
end end
# Public: Serialize samples data to YAML. # Public: Serialize samples data to YAML.
# #
# data - Hash # db - Hash
# io - IO object to write to
# #
# Returns nothing. # Returns String.
def self.serialize_to_yaml(data, io) def self.serialize_to_yaml(db)
data = "" out = ""
escape = lambda { |s| s.inspect.gsub(/\\#/, "\#") } escape = lambda { |s| s.inspect.gsub(/\\#/, "\#") }
data << "languages_total: #{data['languages_total']}\n" out << "languages_total: #{db['languages_total']}\n"
data << "tokens_total: #{data['tokens_total']}\n" out << "tokens_total: #{db['tokens_total']}\n"
data << "languages:\n" out << "languages:\n"
data['languages'].sort.each do |language, count| db['languages'].sort.each do |language, count|
data << " #{escape.call(language)}: #{count}\n" out << " #{escape.call(language)}: #{count}\n"
end end
data << "language_tokens:\n" out << "language_tokens:\n"
data['language_tokens'].sort.each do |language, count| db['language_tokens'].sort.each do |language, count|
data << " #{escape.call(language)}: #{count}\n" out << " #{escape.call(language)}: #{count}\n"
end end
data << "tokens:\n" out << "tokens:\n"
data['tokens'].sort.each do |language, tokens| db['tokens'].sort.each do |language, tokens|
data << " #{escape.call(language)}:\n" out << " #{escape.call(language)}:\n"
tokens.sort.each do |token, count| tokens.sort.each do |token, count|
data << " #{escape.call(token)}: #{count}\n" out << " #{escape.call(token)}: #{count}\n"
end end
end end
io.write data out
nil
end end
end end
end end

View File

@@ -24,39 +24,39 @@ class TestClassifier < Test::Unit::TestCase
end end
def test_classify def test_classify
classifier = Classifier.new db = {}
classifier.train "Ruby", fixture("ruby/foo.rb") Classifier.train! db, "Ruby", fixture("ruby/foo.rb")
classifier.train "Objective-C", fixture("objective-c/Foo.h") Classifier.train! db, "Objective-C", fixture("objective-c/Foo.h")
classifier.train "Objective-C", fixture("objective-c/Foo.m") Classifier.train! db, "Objective-C", fixture("objective-c/Foo.m")
results = classifier.classify(fixture("objective-c/hello.m")) results = Classifier.classify(db, fixture("objective-c/hello.m"))
assert_equal "Objective-C", results.first[0] assert_equal "Objective-C", results.first[0]
tokens = Tokenizer.tokenize(fixture("objective-c/hello.m")) tokens = Tokenizer.tokenize(fixture("objective-c/hello.m"))
results = classifier.classify(tokens) results = Classifier.classify(db, tokens)
assert_equal "Objective-C", results.first[0] assert_equal "Objective-C", results.first[0]
end end
def test_restricted_classify def test_restricted_classify
classifier = Classifier.new db = {}
classifier.train "Ruby", fixture("ruby/foo.rb") Classifier.train! db, "Ruby", fixture("ruby/foo.rb")
classifier.train "Objective-C", fixture("objective-c/Foo.h") Classifier.train! db, "Objective-C", fixture("objective-c/Foo.h")
classifier.train "Objective-C", fixture("objective-c/Foo.m") Classifier.train! db, "Objective-C", fixture("objective-c/Foo.m")
results = classifier.classify(fixture("objective-c/hello.m"), ["Objective-C"]) results = Classifier.classify(db, fixture("objective-c/hello.m"), ["Objective-C"])
assert_equal "Objective-C", results.first[0] assert_equal "Objective-C", results.first[0]
results = classifier.classify(fixture("objective-c/hello.m"), ["Ruby"]) results = Classifier.classify(db, fixture("objective-c/hello.m"), ["Ruby"])
assert_equal "Ruby", results.first[0] assert_equal "Ruby", results.first[0]
end end
def test_instance_classify_empty def test_instance_classify_empty
results = Classifier.new(Samples::DATA).classify("") results = Classifier.classify(Samples::DATA, "")
assert results.first[1] < 0.5, results.first.inspect assert results.first[1] < 0.5, results.first.inspect
end end
def test_instance_classify_nil def test_instance_classify_nil
assert_equal [], Classifier.new(Samples::DATA).classify(nil) assert_equal [], Classifier.classify(Samples::DATA, nil)
end end
def test_verify def test_verify
@@ -76,7 +76,7 @@ class TestClassifier < Test::Unit::TestCase
languages = Language.all.select { |l| l.extensions.include?(extname) }.map(&:name) languages = Language.all.select { |l| l.extensions.include?(extname) }.map(&:name)
next unless languages.length > 1 next unless languages.length > 1
results = Classifier.new(Samples::DATA).classify(File.read(sample[:path]), languages) results = Classifier.classify(Samples::DATA, File.read(sample[:path]), languages)
assert_equal language.name, results.first[0], "#{sample[:path]}\n#{results.inspect}" assert_equal language.name, results.first[0], "#{sample[:path]}\n#{results.inspect}"
end end
end end