diff --git a/linkchecker.go b/linkchecker.go index 4695622..7d08b80 100644 --- a/linkchecker.go +++ b/linkchecker.go @@ -10,12 +10,14 @@ import ( ) type LinkChecker struct { - client *http.Client + client *http.Client + visited map[string]bool } func NewLinkChecker() *LinkChecker { return &LinkChecker{ - client: &http.Client{}, + client: &http.Client{}, + visited: make(map[string]bool), } } @@ -25,16 +27,35 @@ type BrokenLink struct { Error string } +func (lc *LinkChecker) isSameDomain(baseURL, link string) bool { + base, err := url.Parse(baseURL) + if err != nil { + return false + } + + target, err := url.Parse(link) + if err != nil { + return false + } + + return base.Host == target.Host +} + func (lc *LinkChecker) CheckLinks(baseURL string) ([]BrokenLink, error) { - // Get all links from the page - links, err := lc.getLinks(baseURL) + return lc.checkLinksRecursive(baseURL, make([]BrokenLink, 0)) +} + +func (lc *LinkChecker) checkLinksRecursive(pageURL string, brokenLinks []BrokenLink) ([]BrokenLink, error) { + if lc.visited[pageURL] { + return brokenLinks, nil + } + lc.visited[pageURL] = true + + links, err := lc.getLinks(pageURL) if err != nil { return nil, fmt.Errorf("error getting links: %w", err) } - var brokenLinks []BrokenLink - - // Check each link for _, link := range links { if status, err := lc.isLinkValid(link); status >= 400 || err != nil { broken := BrokenLink{URL: link} @@ -45,6 +66,15 @@ func (lc *LinkChecker) CheckLinks(baseURL string) ([]BrokenLink, error) { } brokenLinks = append(brokenLinks, broken) } + + // Recursively check links from the same domain + if lc.isSameDomain(pageURL, link) && !lc.visited[link] { + recursiveLinks, err := lc.checkLinksRecursive(link, brokenLinks) + if err != nil { + continue // Skip this page if there's an error, but continue checking others + } + brokenLinks = recursiveLinks + } } return brokenLinks, nil