View Javadoc
1   package org.argeo.ssh;
2   
3   import java.io.Console;
4   import java.io.IOException;
5   import java.net.URI;
6   import java.net.URISyntaxException;
7   import java.util.Arrays;
8   import java.util.HashSet;
9   import java.util.Scanner;
10  import java.util.Set;
11  
12  import org.apache.commons.logging.Log;
13  import org.apache.commons.logging.LogFactory;
14  import org.apache.sshd.client.SshClient;
15  import org.apache.sshd.client.channel.ClientChannel;
16  import org.apache.sshd.client.channel.ClientChannelEvent;
17  import org.apache.sshd.client.future.ConnectFuture;
18  import org.apache.sshd.client.session.ClientSession;
19  import org.apache.sshd.client.subsystem.sftp.fs.SftpFileSystemProvider;
20  import org.apache.sshd.common.util.io.NoCloseInputStream;
21  import org.apache.sshd.common.util.io.NoCloseOutputStream;
22  
23  @SuppressWarnings("restriction")
24  abstract class AbstractSsh {
25  	private final static Log log = LogFactory.getLog(AbstractSsh.class);
26  
27  	private static SshClient sshClient;
28  	private static SftpFileSystemProvider sftpFileSystemProvider;
29  
30  	private boolean passwordSet = false;
31  	private ClientSession session;
32  
33  	private SshKeyPair sshKeyPair;
34  
35  	synchronized SshClient getSshClient() {
36  		if (sshClient == null) {
37  			long begin = System.currentTimeMillis();
38  			sshClient = SshClient.setUpDefaultClient();
39  			sshClient.start();
40  			long duration = System.currentTimeMillis() - begin;
41  			if (log.isDebugEnabled())
42  				log.debug("SSH client started in " + duration + " ms");
43  			Runtime.getRuntime().addShutdownHook(new Thread(() -> sshClient.stop(), "Stop SSH client"));
44  		}
45  		return sshClient;
46  	}
47  
48  	synchronized SftpFileSystemProvider getSftpFileSystemProvider() {
49  		if (sftpFileSystemProvider == null) {
50  			sftpFileSystemProvider = new SftpFileSystemProvider(sshClient);
51  		}
52  		return sftpFileSystemProvider;
53  	}
54  
55  	void authenticate() {
56  		try {
57  			if (sshKeyPair != null) {
58  				session.addPublicKeyIdentity(sshKeyPair.asKeyPair());
59  			} else {
60  
61  				if (!passwordSet) {
62  					String password;
63  					Console console = System.console();
64  					if (console == null) {// IDE
65  						System.out.print("Password: ");
66  						try (Scanner s = new Scanner(System.in)) {
67  							password = s.next();
68  						}
69  					} else {
70  						console.printf("Password: ");
71  						char[] pwd = console.readPassword();
72  						password = new String(pwd);
73  						Arrays.fill(pwd, ' ');
74  					}
75  					session.addPasswordIdentity(password);
76  					passwordSet = true;
77  				}
78  			}
79  			session.auth().verify(1000l);
80  		} catch (IOException e) {
81  			throw new IllegalStateException(e);
82  		}
83  	}
84  
85  	void addPassword(String password) {
86  		session.addPasswordIdentity(password);
87  	}
88  
89  	void loadKey(String password) {
90  		loadKey(password, System.getProperty("user.home") + "/.ssh/id_rsa");
91  	}
92  
93  	void loadKey(String password, String keyPath) {
94  //		try {
95  //			KeyPair keyPair = ClientIdentityLoader.DEFAULT.loadClientIdentity(keyPath,
96  //					FilePasswordProvider.of(password));
97  //			session.addPublicKeyIdentity(keyPair);
98  //		} catch (IOException | GeneralSecurityException e) {
99  //			throw new IllegalStateException(e);
100 //		}
101 	}
102 
103 	void openSession(URI uri) {
104 		openSession(uri.getUserInfo(), uri.getHost(), uri.getPort() > 0 ? uri.getPort() : null);
105 	}
106 
107 	void openSession(String login, String host, Integer port) {
108 		if (session != null)
109 			throw new IllegalStateException("Session is already open");
110 
111 		if (host == null)
112 			host = "localhost";
113 		if (port == null)
114 			port = 22;
115 		if (login == null)
116 			login = System.getProperty("user.name");
117 		String password = null;
118 		int sepIndex = login.indexOf(':');
119 		if (sepIndex > 0)
120 			if (sepIndex + 1 < login.length()) {
121 				password = login.substring(sepIndex + 1);
122 				login = login.substring(0, sepIndex);
123 			} else {
124 				throw new IllegalArgumentException("Illegal authority: " + login);
125 			}
126 		try {
127 			ConnectFuture connectFuture = getSshClient().connect(login, host, port);
128 			connectFuture.await();
129 			ClientSession session = connectFuture.getSession();
130 			if (password != null) {
131 				session.addPasswordIdentity(password);
132 				passwordSet = true;
133 			}
134 			this.session = session;
135 		} catch (IOException e) {
136 			throw new IllegalStateException("Cannot connect to " + host + ":" + port);
137 		}
138 	}
139 
140 	void closeSession() {
141 		if (session == null)
142 			throw new IllegalStateException("No session is open");
143 		try {
144 			session.close();
145 		} catch (IOException e) {
146 			e.printStackTrace();
147 		} finally {
148 			session = null;
149 		}
150 	}
151 
152 	ClientSession getSession() {
153 		return session;
154 	}
155 
156 	public void setSshKeyPair(SshKeyPair sshKeyPair) {
157 		this.sshKeyPair = sshKeyPair;
158 	}
159 
160 	public static void openShell(ClientSession session) {
161 		try (ClientChannel channel = session.createChannel(ClientChannel.CHANNEL_SHELL)) {
162 			channel.setIn(new NoCloseInputStream(System.in));
163 			channel.setOut(new NoCloseOutputStream(System.out));
164 			channel.setErr(new NoCloseOutputStream(System.err));
165 			channel.open();
166 
167 			Set<ClientChannelEvent> events = new HashSet<>();
168 			events.add(ClientChannelEvent.CLOSED);
169 			channel.waitFor(events, 0);
170 		} catch (IOException e) {
171 			// TODO Auto-generated catch block
172 			e.printStackTrace();
173 		} finally {
174 			session.close(false);
175 		}
176 	}
177 
178 	static URI toUri(String username, String host, int port) {
179 		try {
180 			if (username == null)
181 				username = "root";
182 			return new URI("ssh://" + username + "@" + host + ":" + port);
183 		} catch (URISyntaxException e) {
184 			throw new IllegalArgumentException("Cannot generate SSH URI to " + host + ":" + port + " for " + username,
185 					e);
186 		}
187 	}
188 
189 }