mirror of
https://github.com/KevinMidboe/linguist.git
synced 2025-10-29 17:50:22 +00:00
Add simple classifier
This commit is contained in:
55
lib/linguist/classifier.rb
Normal file
55
lib/linguist/classifier.rb
Normal file
@@ -0,0 +1,55 @@
|
||||
require 'linguist/tokenizer'
|
||||
|
||||
module Linguist
|
||||
# Language bayesian classifier.
|
||||
class Classifier
|
||||
def initialize
|
||||
@tokens = Hash.new { |h, k| h[k] = Hash.new(0) }
|
||||
@language_tokens = Hash.new(0)
|
||||
@languages = Hash.new(0)
|
||||
@languages_total = 0
|
||||
@tokens_total = 0
|
||||
end
|
||||
|
||||
def train(language, data)
|
||||
tokens = Tokenizer.new(data).tokens
|
||||
|
||||
tokens.each do |token|
|
||||
@tokens[language][token] += 1
|
||||
@language_tokens[language] += 1
|
||||
@tokens_total += 1
|
||||
end
|
||||
@languages[language] += 1
|
||||
@languages_total += 1
|
||||
end
|
||||
|
||||
def classify(data)
|
||||
tokens = Tokenizer.new(data).tokens
|
||||
|
||||
scores = {}
|
||||
@languages.keys.each do |language|
|
||||
scores[language] = tokens_probability(tokens, language) * language_probability(language)
|
||||
end
|
||||
|
||||
scores.sort { |a, b| b[1] <=> a[1] }
|
||||
end
|
||||
|
||||
def tokens_probability(tokens, language)
|
||||
tokens.inject(1.0) do |sum, token|
|
||||
sum *= token_probability(token, language)
|
||||
end
|
||||
end
|
||||
|
||||
def token_probability(token, language)
|
||||
if @tokens[language][token] == 0
|
||||
1 / @tokens_total.to_f
|
||||
else
|
||||
@tokens[language][token].to_f / @languages[language].to_f
|
||||
end
|
||||
end
|
||||
|
||||
def language_probability(language)
|
||||
@languages[language].to_f / @languages_total.to_f
|
||||
end
|
||||
end
|
||||
end
|
||||
26
test/test_classifier.rb
Normal file
26
test/test_classifier.rb
Normal file
@@ -0,0 +1,26 @@
|
||||
require 'linguist/classifier'
|
||||
require 'linguist/language'
|
||||
|
||||
require 'test/unit'
|
||||
|
||||
class TestClassifier < Test::Unit::TestCase
|
||||
include Linguist
|
||||
|
||||
def fixtures_path
|
||||
File.expand_path("../fixtures", __FILE__)
|
||||
end
|
||||
|
||||
def fixture(name)
|
||||
File.read(File.join(fixtures_path, name))
|
||||
end
|
||||
|
||||
def test_truth
|
||||
classifier = Classifier.new
|
||||
classifier.train Language["Ruby"], fixture("ruby/foo.rb")
|
||||
classifier.train Language["Objective-C"], fixture("objective-c/Foo.h")
|
||||
classifier.train Language["Objective-C"], fixture("objective-c/Foo.m")
|
||||
|
||||
results = classifier.classify(fixture("objective-c/hello.m"))
|
||||
assert_equal Language["Objective-C"], results.first[0]
|
||||
end
|
||||
end
|
||||
Reference in New Issue
Block a user