diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3fcc7e3c..cee17bf9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -7,11 +7,11 @@ concurrency: on: push: branches: - - master + - main - ci pull_request: branches: - - master + - main schedule: - cron: 0 0 * * * diff --git a/http-client-tls/Network/HTTP/Client/TLS.hs b/http-client-tls/Network/HTTP/Client/TLS.hs index 965d3fda..538d21d0 100644 --- a/http-client-tls/Network/HTTP/Client/TLS.hs +++ b/http-client-tls/Network/HTTP/Client/TLS.hs @@ -1,4 +1,5 @@ {-# LANGUAGE CPP #-} +{-# LANGUAGE PackageImports #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DeriveDataTypeable #-} @@ -27,8 +28,8 @@ import Control.Applicative ((<|>)) import Control.Arrow (first) import System.Environment (getEnvironment) import Data.Default -import Network.HTTP.Client hiding (host, port) -import Network.HTTP.Client.Internal hiding (host, port) +import Network.HTTP.Client hiding (withConnection, host, port) +import Network.HTTP.Client.Internal hiding (withConnection, host, port) import Control.Exception import qualified Network.Connection as NC import Network.Socket (HostAddress) @@ -44,7 +45,7 @@ import Network.HTTP.Types (status401) import qualified Crypto.Hash.MD5 as MD5 import Control.Arrow ((***)) import Data.Base16.Types (extractBase16) -import Data.ByteString.Base16 (encodeBase16') +import "base16" Data.ByteString.Base16 (encodeBase16') import Data.Typeable (Typeable) import Control.Monad.Catch (MonadThrow, throwM) import qualified Data.Map as Map @@ -124,15 +125,15 @@ getTlsConnection :: Maybe NC.ConnectionContext -> IO (Maybe HostAddress -> String -> Int -> IO Connection) getTlsConnection mcontext tls sock = do context <- maybe NC.initConnectionContext return mcontext - return $ \_ha host port -> bracketOnError - (NC.connectTo context NC.ConnectionParams - { NC.connectionHostname = strippedHostName host - , NC.connectionPort = fromIntegral port - , NC.connectionUseSecure = tls - , NC.connectionUseSocks = sock - }) - NC.connectionClose - convertConnection + return $ \_ha host port -> do + let params = NC.ConnectionParams + { NC.connectionHostname = strippedHostName host + , NC.connectionPort = fromIntegral port + , NC.connectionUseSecure = tls + , NC.connectionUseSocks = sock + } + withConnection context params + convertConnection getTlsProxyConnection :: Maybe NC.ConnectionContext @@ -141,18 +142,17 @@ getTlsProxyConnection -> IO (S.ByteString -> (Connection -> IO ()) -> String -> Maybe HostAddress -> String -> Int -> IO Connection) getTlsProxyConnection mcontext tls sock = do context <- maybe NC.initConnectionContext return mcontext - return $ \connstr checkConn serverName _ha host port -> bracketOnError - (NC.connectTo context NC.ConnectionParams - { NC.connectionHostname = strippedHostName serverName - , NC.connectionPort = fromIntegral port - , NC.connectionUseSecure = Nothing - , NC.connectionUseSocks = - case sock of - Just _ -> error "Cannot use SOCKS and TLS proxying together" - Nothing -> Just $ NC.OtherProxy (strippedHostName host) $ fromIntegral port - }) - NC.connectionClose - $ \conn -> do + return $ \connstr checkConn serverName _ha host port -> do + let params = NC.ConnectionParams + { NC.connectionHostname = strippedHostName serverName + , NC.connectionPort = fromIntegral port + , NC.connectionUseSecure = Nothing + , NC.connectionUseSocks = + case sock of + Just _ -> error "Cannot use SOCKS and TLS proxying together" + Nothing -> Just $ NC.OtherProxy (strippedHostName host) $ fromIntegral port + } + withConnection context params $ \conn -> do NC.connectionPut conn connstr conn' <- convertConnection conn @@ -162,6 +162,9 @@ getTlsProxyConnection mcontext tls sock = do return conn' +withConnection :: NC.ConnectionContext -> NC.ConnectionParams -> (NC.Connection -> IO a) -> IO a +withConnection context params = bracketOnError (NC.connectTo context params) NC.connectionClose + convertConnection :: NC.Connection -> IO Connection convertConnection conn = makeConnection (NC.connectionGetChunk conn)